Remove streaming transport as a top level API (#110)

- Remove Streaming* classes from Sockets. The main
API will be channels based and streaming transports
will use the PipelineChannel (formerly FramingChannel) to
access messages.
- Added WriteAsync and ReadAsync to Connection and hid
the IChannelConnection from public API.
- Also fixed the fact that unknown methods caused server side
exceptions.
- Changed the consumption pattern to WaitToReadAsync/TryRead to avoid
exceptions.
- React to API changes
This commit is contained in:
David Fowler 2017-01-11 04:01:49 -08:00 committed by GitHub
parent 9dbb3742c8
commit cd9ed9228a
45 changed files with 544 additions and 1003 deletions

View File

@ -16,7 +16,7 @@ namespace ChatSample.Hubs
{
if (!Context.User.Identity.IsAuthenticated)
{
Context.Connection.Transport.Dispose();
Context.Connection.Dispose();
}
return Task.CompletedTask;

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Threading.Tasks;
@ -13,19 +14,19 @@ namespace SocialWeather
public class PersistentConnectionLifeTimeManager
{
private readonly FormatterResolver _formatterResolver;
private readonly ConnectionList<StreamingConnection> _connectionList = new ConnectionList<StreamingConnection>();
private readonly ConnectionList _connectionList = new ConnectionList();
public PersistentConnectionLifeTimeManager(FormatterResolver formatterResolver)
{
_formatterResolver = formatterResolver;
}
public void OnConnectedAsync(StreamingConnection connection)
public void OnConnectedAsync(Connection connection)
{
_connectionList.Add(connection);
}
public void OnDisconnectedAsync(StreamingConnection connection)
public void OnDisconnectedAsync(Connection connection)
{
_connectionList.Remove(connection);
}
@ -35,7 +36,10 @@ namespace SocialWeather
foreach (var connection in _connectionList)
{
var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>("formatType"));
await formatter.WriteAsync(data, connection.Transport.GetStream());
var ms = new MemoryStream();
await formatter.WriteAsync(data, ms);
var buffer = ReadableBuffer.Create(ms.ToArray()).Preserve();
await connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true));
}
}
@ -54,7 +58,7 @@ namespace SocialWeather
throw new NotImplementedException();
}
public void AddGroupAsync(StreamingConnection connection, string groupName)
public void AddGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
lock (groups)
@ -63,7 +67,7 @@ namespace SocialWeather
}
}
public void RemoveGroupAsync(StreamingConnection connection, string groupName)
public void RemoveGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
if (groups != null)

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
@ -8,7 +9,7 @@ using Microsoft.Extensions.Logging;
namespace SocialWeather
{
public class SocialWeatherEndPoint : StreamingEndPoint
public class SocialWeatherEndPoint : EndPoint
{
private readonly PersistentConnectionLifeTimeManager _lifetimeManager;
private readonly FormatterResolver _formatterResolver;
@ -22,22 +23,24 @@ namespace SocialWeather
_logger = logger;
}
public async override Task OnConnectedAsync(StreamingConnection connection)
public async override Task OnConnectedAsync(Connection connection)
{
_lifetimeManager.OnConnectedAsync(connection);
await ProcessRequests(connection);
_lifetimeManager.OnDisconnectedAsync(connection);
}
public async Task ProcessRequests(StreamingConnection connection)
public async Task ProcessRequests(Connection connection)
{
var stream = connection.Transport.GetStream();
var formatter = _formatterResolver.GetFormatter<WeatherReport>(
connection.Metadata.Get<string>("formatType"));
WeatherReport weatherReport;
while ((weatherReport = await formatter.ReadAsync(stream)) != null)
while (true)
{
Message message = await connection.Transport.Input.ReadAsync();
var stream = new MemoryStream();
await message.Payload.Buffer.CopyToAsync(stream);
WeatherReport weatherReport = await formatter.ReadAsync(stream);
await _lifetimeManager.SendToAllAsync(weatherReport);
}
}

View File

@ -1,67 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
namespace SocketsSample
{
public class ChatEndPoint : StreamingEndPoint
{
public ConnectionList<StreamingConnection> Connections { get; } = new ConnectionList<StreamingConnection>();
public override async Task OnConnectedAsync(StreamingConnection connection)
{
Connections.Add(connection);
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
while (true)
{
var result = await connection.Transport.Input.ReadAsync();
var input = result.Buffer;
try
{
if (input.IsEmpty && result.IsCompleted)
{
break;
}
// We can avoid the copy here but we'll deal with that later
await Broadcast(input.ToArray());
}
finally
{
connection.Transport.Input.Advance(input.End);
}
}
Connections.Remove(connection);
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
}
private Task Broadcast(string text)
{
return Broadcast(Encoding.UTF8.GetBytes(text));
}
private Task Broadcast(byte[] payload)
{
var tasks = new List<Task>(Connections.Count);
foreach (var c in Connections)
{
tasks.Add(c.Transport.Output.WriteAsync(payload));
}
return Task.WhenAll(tasks);
}
}
}

View File

@ -12,11 +12,11 @@ using Microsoft.AspNetCore.Sockets;
namespace SocketsSample.EndPoints
{
public class MessagesEndPoint : MessagingEndPoint
public class MessagesEndPoint : EndPoint
{
public ConnectionList<MessagingConnection> Connections { get; } = new ConnectionList<MessagingConnection>();
public ConnectionList Connections { get; } = new ConnectionList();
public override async Task OnConnectedAsync(MessagingConnection connection)
public override async Task OnConnectedAsync(Connection connection)
{
Connections.Add(connection);
@ -24,19 +24,19 @@ namespace SocketsSample.EndPoints
try
{
while (true)
while (await connection.Transport.Input.WaitToReadAsync())
{
using (var message = await connection.Transport.Input.ReadAsync())
Message message;
if (connection.Transport.Input.TryRead(out message))
{
// We can avoid the copy here but we'll deal with that later
await Broadcast(message.Payload.Buffer, message.MessageFormat, message.EndOfMessage);
using (message)
{
// We can avoid the copy here but we'll deal with that later
await Broadcast(message.Payload.Buffer, message.MessageFormat, message.EndOfMessage);
}
}
}
}
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
}
finally
{
Connections.Remove(connection);

View File

@ -29,7 +29,6 @@ namespace SocketsSample
});
// .AddRedis();
services.AddSingleton<ChatEndPoint>();
services.AddSingleton<MessagesEndPoint>();
services.AddSingleton<ProtobufSerializer>();
}
@ -53,8 +52,7 @@ namespace SocketsSample
app.UseSockets(routes =>
{
routes.MapEndpoint<ChatEndPoint>("/chat");
routes.MapEndpoint<MessagesEndPoint>("/msgs");
routes.MapEndpoint<MessagesEndPoint>("/chat");
});
}
}

View File

@ -6,18 +6,12 @@
</head>
<body>
<h1>ASP.NET Sockets</h1>
<h2>Streaming</h2>
<h2>Messaging</h2>
<ul>
<li><a href="sse.html#/chat">Server Sent Events</a></li>
<li><a href="polling.html#/chat">Long polling</a></li>
<li><a href="ws.html#/chat">Web Sockets</a></li>
</ul>
<h2>Messaging</h2>
<ul>
<li><a href="sse.html#/msgs">Server Sent Events</a></li>
<li><a href="polling.html#/msgs">Long polling</a></li>
<li><a href="ws.html#/msgs">Web Sockets</a></li>
</ul>
<h1>ASP.NET SignalR</h1>
<ul>
<li><a href="hubs.html">Hubs</a></li>

View File

@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable
{
private readonly ConnectionList<StreamingConnection> _connections = new ConnectionList<StreamingConnection>();
private readonly ConnectionList _connections = new ConnectionList();
// TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
private readonly InvocationAdapterRegistry _registry;
@ -51,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
foreach (var connection in _connections)
{
tasks.Add(connection.Transport.Output.WriteAsync((byte[])data));
tasks.Add(WriteAsync(connection, data));
}
previousBroadcastTask = Task.WhenAll(tasks);
@ -116,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public override Task OnConnectedAsync(StreamingConnection connection)
public override Task OnConnectedAsync(Connection connection)
{
var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet<string>());
var connectionTask = TaskCache.CompletedTask;
@ -133,7 +133,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
await previousConnectionTask;
previousConnectionTask = connection.Transport.Output.WriteAsync((byte[])data);
previousConnectionTask = WriteAsync(connection, data);
});
@ -149,14 +149,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
await previousUserTask;
previousUserTask = connection.Transport.Output.WriteAsync((byte[])data);
previousUserTask = WriteAsync(connection, data);
});
}
return Task.WhenAll(connectionTask, userTask);
}
public override Task OnDisconnectedAsync(StreamingConnection connection)
public override Task OnDisconnectedAsync(Connection connection)
{
_connections.Remove(connection);
@ -186,7 +186,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return Task.WhenAll(tasks);
}
public override async Task AddGroupAsync(StreamingConnection connection, string groupName)
public override async Task AddGroupAsync(Connection connection, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
@ -220,9 +220,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
await previousTask;
var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections.Cast<StreamingConnection>())
foreach (var groupConnection in group.Connections.Cast<Connection>())
{
tasks.Add(groupConnection.Transport.Output.WriteAsync((byte[])data));
tasks.Add(WriteAsync(groupConnection, data));
}
previousTask = Task.WhenAll(tasks);
@ -234,7 +234,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public override async Task RemoveGroupAsync(StreamingConnection connection, string groupName)
public override async Task RemoveGroupAsync(Connection connection, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
@ -275,6 +275,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_redisServerConnection.Dispose();
}
private Task WriteAsync(Connection connection, byte[] data)
{
var buffer = ReadableBuffer.Create(data).Preserve();
return connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true));
}
private class LoggerTextWriter : TextWriter
{
private readonly ILogger _logger;
@ -300,7 +306,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
private class GroupData
{
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public ConnectionList<StreamingConnection> Connections = new ConnectionList<StreamingConnection>();
public ConnectionList Connections = new ConnectionList();
}
}
}

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
@ -12,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR
{
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
{
private readonly ConnectionList<StreamingConnection> _connections = new ConnectionList<StreamingConnection>();
private readonly ConnectionList _connections = new ConnectionList();
private readonly InvocationAdapterRegistry _registry;
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
@ -20,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR
_registry = registry;
}
public override Task AddGroupAsync(StreamingConnection connection, string groupName)
public override Task AddGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
@ -32,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR
return TaskCache.CompletedTask;
}
public override Task RemoveGroupAsync(StreamingConnection connection, string groupName)
public override Task RemoveGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
@ -54,7 +55,7 @@ namespace Microsoft.AspNetCore.SignalR
return InvokeAllWhere(methodName, args, c => true);
}
private Task InvokeAllWhere(string methodName, object[] args, Func<StreamingConnection, bool> include)
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
{
var tasks = new List<Task>(_connections.Count);
var message = new InvocationDescriptor
@ -73,7 +74,7 @@ namespace Microsoft.AspNetCore.SignalR
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
tasks.Add(invocationAdapter.WriteMessageAsync(message, connection.Transport.GetStream()));
tasks.Add(WriteAsync(connection, invocationAdapter, message));
}
return Task.WhenAll(tasks);
@ -91,7 +92,7 @@ namespace Microsoft.AspNetCore.SignalR
Arguments = args
};
return invocationAdapter.WriteMessageAsync(message, connection.Transport.GetStream());
return WriteAsync(connection, invocationAdapter, message);
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
@ -111,17 +112,24 @@ namespace Microsoft.AspNetCore.SignalR
});
}
public override Task OnConnectedAsync(StreamingConnection connection)
public override Task OnConnectedAsync(Connection connection)
{
_connections.Add(connection);
return TaskCache.CompletedTask;
}
public override Task OnDisconnectedAsync(StreamingConnection connection)
public override Task OnDisconnectedAsync(Connection connection)
{
_connections.Remove(connection);
return TaskCache.CompletedTask;
}
}
private static Task WriteAsync(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor message)
{
var stream = new MemoryStream();
invocationAdapter.WriteMessageAsync(message, stream);
var buffer = ReadableBuffer.Create(stream.ToArray()).Preserve();
return connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true));
}
}
}

View File

@ -8,12 +8,12 @@ namespace Microsoft.AspNetCore.SignalR
{
public class HubCallerContext
{
public HubCallerContext(StreamingConnection connection)
public HubCallerContext(Connection connection)
{
Connection = connection;
}
public StreamingConnection Connection { get; }
public Connection Connection { get; }
public ClaimsPrincipal User => Connection.User;

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Reflection;
@ -25,10 +26,10 @@ namespace Microsoft.AspNetCore.SignalR
}
}
public class HubEndPoint<THub, TClient> : StreamingEndPoint, IInvocationBinder where THub : Hub<TClient>
public class HubEndPoint<THub, TClient> : EndPoint, IInvocationBinder where THub : Hub<TClient>
{
private readonly Dictionary<string, Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks
= new Dictionary<string, Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks
= new Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>();
private readonly HubLifetimeManager<THub> _lifetimeManager;
@ -52,7 +53,7 @@ namespace Microsoft.AspNetCore.SignalR
DiscoverHubMethods();
}
public override async Task OnConnectedAsync(StreamingConnection connection)
public override async Task OnConnectedAsync(Connection connection)
{
// TODO: Dispatch from the caller
await Task.Yield();
@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task RunHubAsync(StreamingConnection connection)
private async Task RunHubAsync(Connection connection)
{
await HubOnConnectedAsync(connection);
@ -86,7 +87,7 @@ namespace Microsoft.AspNetCore.SignalR
await HubOnDisconnectedAsync(connection, null);
}
private async Task HubOnConnectedAsync(StreamingConnection connection)
private async Task HubOnConnectedAsync(Connection connection)
{
try
{
@ -112,7 +113,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task HubOnDisconnectedAsync(StreamingConnection connection, Exception exception)
private async Task HubOnDisconnectedAsync(Connection connection, Exception exception)
{
try
{
@ -138,15 +139,26 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task DispatchMessagesAsync(StreamingConnection connection)
private async Task DispatchMessagesAsync(Connection connection)
{
var stream = connection.Transport.GetStream();
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
while (true)
while (await connection.Transport.Input.WaitToReadAsync())
{
// TODO: Handle receiving InvocationResultDescriptor
var invocationDescriptor = await invocationAdapter.ReadMessageAsync(stream, this) as InvocationDescriptor;
Message message;
if (!connection.Transport.Input.TryRead(out message))
{
continue;
}
InvocationDescriptor invocationDescriptor;
using (message)
{
var inputStream = new MemoryStream(message.Payload.Buffer.ToArray());
// TODO: Handle receiving InvocationResultDescriptor
invocationDescriptor = await invocationAdapter.ReadMessageAsync(inputStream, this) as InvocationDescriptor;
}
// Is there a better way of detecting that a connection was closed?
if (invocationDescriptor == null)
@ -160,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR
}
InvocationResultDescriptor result;
Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback;
Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback;
if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback))
{
result = await callback(connection, invocationDescriptor);
@ -177,11 +189,19 @@ namespace Microsoft.AspNetCore.SignalR
_logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method);
}
await invocationAdapter.WriteMessageAsync(result, stream);
// TODO: Pool memory
var outStream = new MemoryStream();
await invocationAdapter.WriteMessageAsync(result, outStream);
var buffer = ReadableBuffer.Create(outStream.ToArray()).Preserve();
if (await connection.Transport.Output.WaitToWriteAsync())
{
connection.Transport.Output.TryWrite(new Message(buffer, Format.Binary, endOfMessage: true));
}
}
}
private void InitializeHub(THub hub, StreamingConnection connection)
private void InitializeHub(THub hub, Connection connection)
{
hub.Clients = _hubContext.Clients;
hub.Context = new HubCallerContext(connection);
@ -290,7 +310,7 @@ namespace Microsoft.AspNetCore.SignalR
Type[] types;
if (!_paramTypes.TryGetValue(methodName, out types))
{
throw new InvalidOperationException($"The hub method '{methodName}' could not be resolved.");
return Type.EmptyTypes;
}
return types;
}

View File

@ -8,9 +8,9 @@ namespace Microsoft.AspNetCore.SignalR
{
public abstract class HubLifetimeManager<THub>
{
public abstract Task OnConnectedAsync(StreamingConnection connection);
public abstract Task OnConnectedAsync(Connection connection);
public abstract Task OnDisconnectedAsync(StreamingConnection connection);
public abstract Task OnDisconnectedAsync(Connection connection);
public abstract Task InvokeAllAsync(string methodName, object[] args);
@ -20,9 +20,9 @@ namespace Microsoft.AspNetCore.SignalR
public abstract Task InvokeUserAsync(string userId, string methodName, object[] args);
public abstract Task AddGroupAsync(StreamingConnection connection, string groupName);
public abstract Task AddGroupAsync(Connection connection, string groupName);
public abstract Task RemoveGroupAsync(StreamingConnection connection, string groupName);
public abstract Task RemoveGroupAsync(Connection connection, string groupName);
}
}

View File

@ -75,10 +75,10 @@ namespace Microsoft.AspNetCore.SignalR
public class GroupManager<THub> : IGroupManager
{
private readonly StreamingConnection _connection;
private readonly Connection _connection;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public GroupManager(StreamingConnection connection, HubLifetimeManager<THub> lifetimeManager)
public GroupManager(Connection connection, HubLifetimeManager<THub> lifetimeManager)
{
_connection = connection;
_lifetimeManager = lifetimeManager;

View File

@ -3,24 +3,28 @@
using System;
using System.Security.Claims;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public abstract class Connection : IDisposable
public class Connection : IDisposable
{
public abstract ConnectionMode Mode { get; }
public string ConnectionId { get; }
public ClaimsPrincipal User { get; set; }
public ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
protected Connection(string id)
public IChannelConnection<Message> Transport { get; }
public Connection(string id, IChannelConnection<Message> transport)
{
Transport = transport;
ConnectionId = id;
}
public virtual void Dispose()
public void Dispose()
{
Transport.Dispose();
}
}
}

View File

@ -8,15 +8,15 @@ using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionList<T> : IReadOnlyCollection<T> where T: Connection
public class ConnectionList : IReadOnlyCollection<Connection>
{
private readonly ConcurrentDictionary<string, T> _connections = new ConcurrentDictionary<string, T>();
private readonly ConcurrentDictionary<string, Connection> _connections = new ConcurrentDictionary<string, Connection>();
public T this[string connectionId]
public Connection this[string connectionId]
{
get
{
T connection;
Connection connection;
if (_connections.TryGetValue(connectionId, out connection))
{
return connection;
@ -27,18 +27,18 @@ namespace Microsoft.AspNetCore.Sockets
public int Count => _connections.Count;
public void Add(T connection)
public void Add(Connection connection)
{
_connections.TryAdd(connection.ConnectionId, connection);
}
public void Remove(T connection)
public void Remove(Connection connection)
{
T dummy;
Connection dummy;
_connections.TryRemove(connection.ConnectionId, out dummy);
}
public IEnumerator<T> GetEnumerator()
public IEnumerator<Connection> GetEnumerator()
{
foreach (var item in _connections)
{

View File

@ -3,8 +3,6 @@
using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets.Internal;
@ -15,11 +13,9 @@ namespace Microsoft.AspNetCore.Sockets
{
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new ConcurrentDictionary<string, ConnectionState>();
private readonly Timer _timer;
private readonly PipelineFactory _pipelineFactory;
public ConnectionManager(PipelineFactory pipelineFactory)
public ConnectionManager()
{
_pipelineFactory = pipelineFactory;
_timer = new Timer(Scan, this, 0, 1000);
}
@ -28,8 +24,23 @@ namespace Microsoft.AspNetCore.Sockets
return _connections.TryGetValue(id, out state);
}
public ConnectionState CreateConnection(ConnectionMode mode) =>
mode == ConnectionMode.Streaming ? CreateStreamingConnection() : CreateMessagingConnection();
public ConnectionState CreateConnection()
{
var id = MakeNewConnectionId();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var state = new ConnectionState(
new Connection(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
}
public void RemoveConnection(string id)
{
@ -92,41 +103,5 @@ namespace Microsoft.AspNetCore.Sockets
}
}
}
private ConnectionState CreateMessagingConnection()
{
var id = MakeNewConnectionId();
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var state = new MessagingConnectionState(
new MessagingConnection(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
}
private ConnectionState CreateStreamingConnection()
{
var id = MakeNewConnectionId();
var transportToApplication = _pipelineFactory.Create();
var applicationToTransport = _pipelineFactory.Create();
var transportSide = new PipelineConnection(applicationToTransport, transportToApplication);
var applicationSide = new PipelineConnection(transportToApplication, applicationToTransport);
var state = new StreamingConnectionState(
new StreamingConnection(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
}
}
}

View File

@ -1,11 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.Sockets
{
public enum ConnectionMode
{
Streaming,
Messaging
}
}

View File

@ -11,14 +11,6 @@ namespace Microsoft.AspNetCore.Sockets
// REVIEW: This doesn't have any members any more... marker interface? Still even necessary?
public abstract class EndPoint
{
/// <summary>
/// Gets the connection mode supported by this endpoint.
/// </summary>
/// <remarks>
/// This maps directly to whichever of <see cref="MessagingEndPoint"/> or <see cref="StreamingEndPoint"/> the end point subclasses.
/// </remarks>
public abstract ConnectionMode Mode { get; }
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>

View File

@ -18,14 +18,12 @@ namespace Microsoft.AspNetCore.Sockets
public class HttpConnectionDispatcher
{
private readonly ConnectionManager _manager;
private readonly PipelineFactory _pipelineFactory;
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
public HttpConnectionDispatcher(ConnectionManager manager, PipelineFactory factory, ILoggerFactory loggerFactory)
public HttpConnectionDispatcher(ConnectionManager manager, ILoggerFactory loggerFactory)
{
_manager = manager;
_pipelineFactory = factory;
_loggerFactory = loggerFactory;
_logger = _loggerFactory.CreateLogger<HttpConnectionDispatcher>();
}
@ -37,7 +35,7 @@ namespace Microsoft.AspNetCore.Sockets
if (context.Request.Path.StartsWithSegments(path + "/getid"))
{
await ProcessGetId(context, endpoint.Mode);
await ProcessGetId(context);
}
else if (context.Request.Path.StartsWithSegments(path + "/send"))
{
@ -56,10 +54,10 @@ namespace Microsoft.AspNetCore.Sockets
? Format.Binary
: Format.Text;
var state = GetOrCreateConnection(context, endpoint.Mode);
var state = GetOrCreateConnection(context);
// Adapt the connection to a message-based transport if necessary, since all the HTTP transports are message-based.
var application = GetMessagingChannel(state, format);
var application = state.Application;
// Server sent events transport
if (context.Request.Path.StartsWithSegments(path + "/sse"))
@ -137,7 +135,7 @@ namespace Microsoft.AspNetCore.Sockets
// Notify the long polling transport to end
if (endpointTask.IsFaulted)
{
state.TerminateTransport(endpointTask.Exception.InnerException);
state.Connection.Transport.Output.TryComplete(endpointTask.Exception.InnerException);
}
state.Connection.Dispose();
@ -151,19 +149,6 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private static IChannelConnection<Message> GetMessagingChannel(ConnectionState state, Format format)
{
if (state.Connection.Mode == ConnectionMode.Messaging)
{
return ((MessagingConnectionState)state).Application;
}
else
{
// We need to build an adapter
return new FramingChannel(((StreamingConnectionState)state).Application, format);
}
}
private ConnectionState InitializePersistentConnection(ConnectionState state, string transport, HttpContext context, EndPoint endpoint, Format format)
{
state.Connection.User = context.User;
@ -197,10 +182,10 @@ namespace Microsoft.AspNetCore.Sockets
await Task.WhenAll(endpointTask, transportTask);
}
private Task ProcessGetId(HttpContext context, ConnectionMode mode)
private Task ProcessGetId(HttpContext context)
{
// Establish the connection
var state = _manager.CreateConnection(mode);
var state = _manager.CreateConnection();
// Get the bytes for the connection id
var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId);
@ -221,34 +206,27 @@ namespace Microsoft.AspNetCore.Sockets
ConnectionState state;
if (_manager.TryGetConnection(connectionId, out state))
{
if (state.Connection.Mode == ConnectionMode.Streaming)
// Collect the message and write it to the channel
// TODO: Need to use some kind of pooled memory here.
byte[] buffer;
using (var stream = new MemoryStream())
{
var streamingState = (StreamingConnectionState)state;
await context.Request.Body.CopyToAsync(streamingState.Application.Output);
await context.Request.Body.CopyToAsync(stream);
buffer = stream.ToArray();
}
else
{
// Collect the message and write it to the channel
// TODO: Need to use some kind of pooled memory here.
byte[] buffer;
using (var strm = new MemoryStream())
{
await context.Request.Body.CopyToAsync(strm);
await strm.FlushAsync();
buffer = strm.ToArray();
}
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var message = new Message(
ReadableBuffer.Create(buffer).Preserve(),
format,
endOfMessage: true);
await ((MessagingConnectionState)state).Application.Output.WriteAsync(message);
}
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var message = new Message(
ReadableBuffer.Create(buffer).Preserve(),
format,
endOfMessage: true);
await state.Application.Output.WriteAsync(message);
}
else
{
@ -256,7 +234,7 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private ConnectionState GetOrCreateConnection(HttpContext context, ConnectionMode mode)
private ConnectionState GetOrCreateConnection(HttpContext context)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
@ -264,7 +242,7 @@ namespace Microsoft.AspNetCore.Sockets
// There's no connection id so this is a brand new connection
if (StringValues.IsNullOrEmpty(connectionId))
{
connectionState = _manager.CreateConnection(mode);
connectionState = _manager.CreateConnection();
}
else if (!_manager.TryGetConnection(connectionId, out connectionState))
{

View File

@ -2,9 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
namespace Microsoft.AspNetCore.Sockets.Internal
@ -15,7 +12,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal
public IChannel<T> Output { get; }
IReadableChannel<T> IChannelConnection<T>.Input => Input;
IWritableChannel<T> IChannelConnection<T>.Output => Output;
public ChannelConnection(IChannel<T> input, IChannel<T> output)

View File

@ -5,24 +5,27 @@ using System;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public abstract class ConnectionState : IDisposable
public class ConnectionState : IDisposable
{
public Connection Connection { get; set; }
public ConnectionMode Mode => Connection.Mode;
public IChannelConnection<Message> Application { get; }
// These are used for long polling mostly
public Action Close { get; set; }
public DateTime LastSeenUtc { get; set; }
public bool Active { get; set; } = true;
protected ConnectionState(Connection connection)
public ConnectionState(Connection connection, IChannelConnection<Message> application)
{
Connection = connection;
Application = application;
LastSeenUtc = DateTime.UtcNow;
}
public abstract void Dispose();
public abstract void TerminateTransport(Exception innerException);
public void Dispose()
{
Connection.Dispose();
Application.Dispose();
}
}
}

View File

@ -1,142 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
namespace Microsoft.AspNetCore.Sockets.Internal
{
/// <summary>
/// Creates a <see cref="IChannelConnection{Message}"/> out of a <see cref="IPipelineConnection"/> by framing data
/// read out of the Pipeline and flattening out frames to write them to the Pipeline when received.
/// </summary>
public class FramingChannel : IChannelConnection<Message>, IReadableChannel<Message>, IWritableChannel<Message>
{
private readonly IPipelineConnection _connection;
private readonly TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
private readonly Format _format;
Task IReadableChannel<Message>.Completion => _tcs.Task;
public IReadableChannel<Message> Input => this;
public IWritableChannel<Message> Output => this;
public FramingChannel(IPipelineConnection connection, Format format)
{
_connection = connection;
_format = format;
}
ValueTask<Message> IReadableChannel<Message>.ReadAsync(CancellationToken cancellationToken)
{
var awaiter = _connection.Input.ReadAsync();
if (awaiter.IsCompleted)
{
return new ValueTask<Message>(ReadSync(awaiter.GetResult(), cancellationToken));
}
else
{
return new ValueTask<Message>(AwaitReadAsync(awaiter, cancellationToken));
}
}
private void CancelRead()
{
// We need to fake cancellation support until we get a newer build of pipelines that has CancelPendingRead()
// HACK: from hell, we attempt to cast the input to a pipeline writer and write 0 bytes so it so that we can
// force yielding the awaiter, this is buggy because overlapping writes can be a problem.
(_connection.Input as IPipelineWriter)?.WriteAsync(Span<byte>.Empty);
}
bool IReadableChannel<Message>.TryRead(out Message item)
{
// We need to think about how we do this. There's no way to check if there is data available in a Pipeline... though maybe there should be
// We could ReadAsync and check IsCompleted, but then we'd also need to stash that Awaitable for later since we can't call ReadAsync a second time...
// CancelPendingReads could help here.
item = default(Message);
return false;
}
Task<bool> IReadableChannel<Message>.WaitToReadAsync(CancellationToken cancellationToken)
{
// See above for TryRead. Same problems here.
throw new NotSupportedException();
}
Task IWritableChannel<Message>.WriteAsync(Message item, CancellationToken cancellationToken)
{
// Just dump the message on to the pipeline
var buffer = _connection.Output.Alloc();
buffer.Append(item.Payload.Buffer);
return buffer.FlushAsync();
}
Task<bool> IWritableChannel<Message>.WaitToWriteAsync(CancellationToken cancellationToken)
{
// We need to think about how we do this. We don't have a wait to synchronously check for back-pressure in the Pipeline.
throw new NotSupportedException();
}
bool IWritableChannel<Message>.TryWrite(Message item)
{
// We need to think about how we do this. We don't have a wait to synchronously check for back-pressure in the Pipeline.
return false;
}
bool IWritableChannel<Message>.TryComplete(Exception error)
{
_connection.Output.Complete(error);
_connection.Input.Complete(error);
return true;
}
private async Task<Message> AwaitReadAsync(ReadableBufferAwaitable awaiter, CancellationToken cancellationToken)
{
using (cancellationToken.Register(state => ((FramingChannel)state).CancelRead(), this))
{
// Just await and then call ReadSync
var result = await awaiter;
return ReadSync(result, cancellationToken);
}
}
private Message ReadSync(ReadResult result, CancellationToken cancellationToken)
{
var buffer = result.Buffer;
// Preserve the buffer and advance the pipeline past it
var preserved = buffer.Preserve();
_connection.Input.Advance(buffer.End);
var msg = new Message(preserved, _format, endOfMessage: true);
if (result.IsCompleted)
{
// Complete the task
_tcs.TrySetResult(null);
}
if (cancellationToken.IsCancellationRequested)
{
_tcs.TrySetCanceled();
msg.Dispose();
// In order to keep the behavior consistent between the transports, we throw if the token was cancelled
throw new OperationCanceledException();
}
return msg;
}
public void Dispose()
{
_tcs.TrySetResult(null);
_connection.Dispose();
}
}
}

View File

@ -1,29 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class MessagingConnectionState : ConnectionState
{
public new MessagingConnection Connection => (MessagingConnection)base.Connection;
public IChannelConnection<Message> Application { get; }
public MessagingConnectionState(MessagingConnection connection, IChannelConnection<Message> application) : base(connection)
{
Application = application;
}
public override void Dispose()
{
Connection.Dispose();
Application.Dispose();
}
public override void TerminateTransport(Exception innerException)
{
Connection.Transport.Output.TryComplete(innerException);
}
}
}

View File

@ -1,31 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class PipelineConnection : IPipelineConnection
{
public PipelineReaderWriter Input { get; }
public PipelineReaderWriter Output { get; }
IPipelineReader IPipelineConnection.Input => Input;
IPipelineWriter IPipelineConnection.Output => Output;
public PipelineConnection(PipelineReaderWriter input, PipelineReaderWriter output)
{
Input = input;
Output = output;
}
public void Dispose()
{
Input.CompleteReader();
Input.CompleteWriter();
Output.CompleteReader();
Output.CompleteWriter();
}
}
}

View File

@ -1,31 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class StreamingConnectionState : ConnectionState
{
public new StreamingConnection Connection => (StreamingConnection)base.Connection;
public IPipelineConnection Application { get; }
public StreamingConnectionState(StreamingConnection connection, IPipelineConnection application) : base(connection)
{
Application = application;
}
public override void Dispose()
{
Connection.Dispose();
Application.Dispose();
}
public override void TerminateTransport(Exception innerException)
{
Connection.Transport.Output.Complete(innerException);
Connection.Transport.Input.Complete(innerException);
}
}
}

View File

@ -1,23 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Sockets
{
public class MessagingConnection : Connection
{
public override ConnectionMode Mode => ConnectionMode.Messaging;
public IChannelConnection<Message> Transport { get; }
public MessagingConnection(string id, IChannelConnection<Message> transport) : base(id)
{
Transport = transport;
}
public override void Dispose()
{
Transport.Dispose();
}
}
}

View File

@ -1,29 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public abstract class MessagingEndPoint : EndPoint
{
public override ConnectionMode Mode => ConnectionMode.Messaging;
public override Task OnConnectedAsync(Connection connection)
{
if (connection.Mode != Mode)
{
throw new InvalidOperationException($"Connection mode does not match endpoint mode. Connection mode is '{connection.Mode}', endpoint mode is '{Mode}'");
}
return OnConnectedAsync((MessagingConnection)connection);
}
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>
/// <param name="connection">The new <see cref="MessagingConnection"/></param>
/// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns>
public abstract Task OnConnectedAsync(MessagingConnection connection);
}
}

View File

@ -1,24 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets
{
public class StreamingConnection : Connection
{
public override ConnectionMode Mode => ConnectionMode.Streaming;
public IPipelineConnection Transport { get; set; }
public StreamingConnection(string id, IPipelineConnection transport) : base(id)
{
Transport = transport;
}
public override void Dispose()
{
Transport.Dispose();
}
}
}

View File

@ -1,29 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public abstract class StreamingEndPoint : EndPoint
{
public override ConnectionMode Mode => ConnectionMode.Streaming;
public override Task OnConnectedAsync(Connection connection)
{
if(connection.Mode != Mode)
{
throw new InvalidOperationException($"Connection mode does not match endpoint mode. Connection mode is '{connection.Mode}', endpoint mode is '{Mode}'");
}
return OnConnectedAsync((StreamingConnection)connection);
}
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>
/// <param name="connection">The new <see cref="StreamingConnection"/></param>
/// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns>
public abstract Task OnConnectedAsync(StreamingConnection connection);
}
}

View File

@ -13,18 +13,18 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
public class LongPollingTransport : IHttpTransport
{
private readonly IReadableChannel<Message> _connection;
private readonly IReadableChannel<Message> _application;
private readonly ILogger _logger;
public LongPollingTransport(IReadableChannel<Message> connection, ILoggerFactory loggerFactory)
public LongPollingTransport(IReadableChannel<Message> application, ILoggerFactory loggerFactory)
{
_connection = connection;
_application = application;
_logger = loggerFactory.CreateLogger<LongPollingTransport>();
}
public async Task ProcessRequestAsync(HttpContext context)
{
if (_connection.Completion.IsCompleted)
if (_application.Completion.IsCompleted)
{
// Client should stop if it receives a 204
_logger.LogInformation("Terminating Long Polling connection by sending 204 response.");
@ -37,18 +37,19 @@ namespace Microsoft.AspNetCore.Sockets.Transports
// TODO: We need the ability to yield the connection without completing the channel.
// This is to force ReadAsync to yield without data to end to poll but not the entire connection.
// This is for cases when the client reconnects see issue #27
using (var message = await _connection.ReadAsync(context.RequestAborted))
await _application.WaitToReadAsync(context.RequestAborted);
Message message;
if (_application.TryRead(out message))
{
_logger.LogDebug("Writing {0} byte message to response", message.Payload.Buffer.Length);
context.Response.ContentLength = message.Payload.Buffer.Length;
await message.Payload.Buffer.CopyToAsync(context.Response.Body);
using (message)
{
_logger.LogDebug("Writing {0} byte message to response", message.Payload.Buffer.Length);
context.Response.ContentLength = message.Payload.Buffer.Length;
await message.Payload.Buffer.CopyToAsync(context.Response.Body);
}
}
}
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// The Channel was closed, while we were waiting to read. That's fine, just means we're done.
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
}
catch (OperationCanceledException)
{
// Suppress the exception

View File

@ -30,18 +30,18 @@ namespace Microsoft.AspNetCore.Sockets.Transports
try
{
while (true)
while (await _application.WaitToReadAsync(context.RequestAborted))
{
using (var message = await _application.ReadAsync(context.RequestAborted))
Message message;
if (_application.TryRead(out message))
{
await Send(context, message);
using (message)
{
await Send(context, message);
}
}
}
}
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
}
catch (OperationCanceledException)
{
// Closed connection

View File

@ -22,20 +22,20 @@ namespace Microsoft.AspNetCore.Sockets.Transports
private bool _lastFrameIncomplete = false;
private readonly ILogger _logger;
private readonly IChannelConnection<Message> _connection;
private readonly IChannelConnection<Message> _application;
public WebSocketsTransport(IChannelConnection<Message> connection, ILoggerFactory loggerFactory)
public WebSocketsTransport(IChannelConnection<Message> application, ILoggerFactory loggerFactory)
{
if (connection == null)
if (application == null)
{
throw new ArgumentNullException(nameof(connection));
throw new ArgumentNullException(nameof(application));
}
if (loggerFactory == null)
{
throw new ArgumentNullException(nameof(loggerFactory));
}
_connection = connection;
_application = application;
_logger = loggerFactory.CreateLogger<WebSocketsTransport>();
}
@ -84,7 +84,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
// Shutting down because we received a close frame from the client.
// Complete the input writer so that the application knows there won't be any more input.
_logger.LogDebug("Client closed connection with status code '{0}' ({1}). Signaling end-of-input to application", receiving.Result.Status, receiving.Result.Description);
_connection.Output.TryComplete();
_application.Output.TryComplete();
// Wait for the application to finish sending.
_logger.LogDebug("Waiting for the application to finish sending data");
@ -95,7 +95,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
}
else
{
var failed = sending.IsFaulted || sending.IsCompleted;
var failed = sending.IsFaulted || _application.Input.Completion.IsFaulted;
// The application finished sending. Close our end of the connection
_logger.LogDebug(!failed ? "Application finished sending. Sending close frame." : "Application failed during sending. Sending InternalServerError close frame");
@ -109,7 +109,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
// Wait for the client to close.
// TODO: Consider timing out here and cancelling the receive loop.
await receiving;
_connection.Output.TryComplete();
_application.Output.TryComplete();
}
}
@ -138,7 +138,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
var message = new Message(frame.Payload.Preserve(), effectiveOpcode == WebSocketOpcode.Binary ? Format.Binary : Format.Text, frame.EndOfMessage);
// Write the message to the channel
return _connection.Output.WriteAsync(message);
return _application.Output.WriteAsync(message);
}
private void LogFrame(string action, WebSocketFrame frame)
@ -152,12 +152,13 @@ namespace Microsoft.AspNetCore.Sockets.Transports
private async Task StartSending(IWebSocketConnection ws)
{
while (true)
while (await _application.Input.WaitToReadAsync())
{
// Get a frame from the application
try
Message message;
if (_application.Input.TryRead(out message))
{
using (var message = await _connection.Input.ReadAsync())
using (message)
{
if (message.Payload.Buffer.Length > 0)
{
@ -185,11 +186,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports
}
}
}
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
break;
}
}
}
}

View File

@ -20,7 +20,6 @@
"xmlDoc": true
},
"dependencies": {
"System.IO.Pipelines": "0.1.0-*",
"System.Threading.Tasks.Channels": "0.1.0-*",
"System.Security.Claims": "4.4.0-*",

View File

@ -71,7 +71,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
buffer.WriteBigEndian((ushort)Status);
if (!string.IsNullOrEmpty(Description))
{
buffer.Append(Description, EncodingData.InvariantUtf8.TextEncoding);
buffer.Append(Description, EncodingData.InvariantUtf8);
}
}
}

View File

@ -246,7 +246,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
_options.MaskingKeyGenerator.GetBytes(_maskingKeyBuffer);
}
buffer.Set(_maskingKeyBuffer);
_maskingKeyBuffer.CopyTo(buffer);
}
private async Task<WebSocketCloseResult> ReceiveLoop(Func<WebSocketFrame, object, Task> messageHandler, object state, CancellationToken cancellationToken)
@ -550,7 +550,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
{
// TODO: Could use TryGetPointer, GetBytes does take a byte*, but it seems like just waiting until we have a version that uses Span is best.
// Slow path - Allocate a heap buffer for the encoded bytes before writing them out.
payload.Span.Set(Encoding.UTF8.GetBytes(str));
Encoding.UTF8.GetBytes(str).CopyTo(payload.Span);
}
if (maskingKey.Length > 0)

View File

@ -134,10 +134,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
EnsureConnectionEstablished(connection);
var ex = await Assert.ThrowsAnyAsync<InvalidOperationException>(
async () => await connection.Invoke<Task>("!@#$%"));
var ex = await Assert.ThrowsAnyAsync<Exception>(
async () => await connection.Invoke<object>("!@#$%"));
Assert.Equal(ex.Message, "The hub method '!@#$%' could not be resolved.");
Assert.Equal(ex.Message, "Unknown hub method '!@#$%'");
}
}
}

View File

@ -8,11 +8,11 @@ using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR.Test.Server
{
public class EchoEndPoint : StreamingEndPoint
public class EchoEndPoint : EndPoint
{
public async override Task OnConnectedAsync(StreamingConnection connection)
public async override Task OnConnectedAsync(Connection connection)
{
await connection.Transport.Input.CopyToAsync(connection.Transport.Output);
await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync());
}
}
}

View File

@ -4,9 +4,11 @@
using System;
using System.IO;
using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.DependencyInjection;
@ -33,7 +35,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await connectionWrapper.ApplicationStartedReading;
// kill the connection
connectionWrapper.ConnectionState.Dispose();
connectionWrapper.Dispose();
await endPointTask;
@ -41,44 +43,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task OnDisconnectedCalledWithExceptionIfHubMethodNotFound()
{
var hub = Mock.Of<Hub>();
var endPointType = GetEndPointType(hub.GetType());
var serviceProvider = CreateServiceProvider(s =>
{
s.AddSingleton(endPointType);
s.AddTransient(hub.GetType(), sp => hub);
});
dynamic endPoint = serviceProvider.GetService(endPointType);
using (var connectionWrapper = new ConnectionWrapper())
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
await connectionWrapper.ApplicationStartedReading;
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper.Connection.Transport, adapter, "0xdeadbeef");
connectionWrapper.Dispose();
await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Mock.Get(hub).Verify(h => h.OnDisconnectedAsync(It.IsNotNull<InvalidOperationException>()), Times.Once());
}
}
[Fact]
public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnConnectedAsyncThrows()
{
var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
mockLifetimeManager
.Setup(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()))
.Setup(m => m.OnConnectedAsync(It.IsAny<Connection>()))
.Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));
var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();
@ -97,10 +67,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
async () => await endPoint.OnConnectedAsync(connectionWrapper.Connection));
Assert.Equal("Lifetime manager OnConnectedAsync failed.", exception.Message);
connectionWrapper.ConnectionState.Dispose();
connectionWrapper.Dispose();
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
// No hubs should be created since the connection is terminated
mockHubActivator.Verify(m => m.Create(), Times.Never);
mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
@ -121,13 +91,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var connectionWrapper = new ConnectionWrapper())
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
connectionWrapper.ConnectionState.Dispose();
connectionWrapper.Dispose();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnConnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
}
}
@ -145,44 +115,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var connectionWrapper = new ConnectionWrapper())
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
connectionWrapper.ConnectionState.Dispose();
connectionWrapper.Dispose();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnDisconnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
}
}
private static Type GetEndPointType(Type hubType)
{
var endPointType = typeof(HubEndPoint<>);
return endPointType.MakeGenericType(hubType);
}
private static Type GetGenericType(Type genericType, Type hubType)
{
return genericType.MakeGenericType(hubType);
}
public class OnConnectedThrowsHub : Hub
{
public override Task OnConnectedAsync()
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(new InvalidOperationException("Hub OnConnected failed."));
return tcs.Task;
}
}
public class OnDisconnectedThrowsHub : Hub
{
public override Task OnDisconnectedAsync(Exception exception)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(new InvalidOperationException("Hub OnDisconnected failed."));
return tcs.Task;
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
}
}
@ -202,10 +141,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper.Connection.Transport, adapter, "TaskValueMethod");
var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport);
await SendRequest(connectionWrapper, adapter, nameof(MethodHub.TaskValueMethod));
var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper);
// json serializer makes this a long
Assert.Equal(42L, res.Result);
Assert.Equal(42L, result.Result);
// kill the connection
connectionWrapper.Connection.Dispose();
@ -230,10 +170,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper.Connection.Transport, adapter, "ValueMethod");
var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport);
await SendRequest(connectionWrapper, adapter, "ValueMethod");
var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper);
// json serializer makes this a long
Assert.Equal(43L, res.Result);
Assert.Equal(43L, result.Result);
// kill the connection
connectionWrapper.Connection.Dispose();
@ -258,9 +199,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper.Connection.Transport, adapter, "StaticMethod");
var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport);
Assert.Equal("fromStatic", res.Result);
await SendRequest(connectionWrapper, adapter, "StaticMethod");
var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper);
Assert.Equal("fromStatic", result.Result);
// kill the connection
connectionWrapper.Connection.Dispose();
@ -285,9 +227,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper.Connection.Transport, adapter, "VoidMethod");
var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport);
Assert.Equal(null, res.Result);
await SendRequest(connectionWrapper, adapter, "VoidMethod");
var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper);
Assert.Null(result.Result);
// kill the connection
connectionWrapper.Connection.Dispose();
@ -312,9 +255,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper.Connection.Transport, adapter, "ConcatString", (byte)32, 42, 'm', "string");
var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport);
Assert.Equal("32, 42, m, string", res.Result);
await SendRequest(connectionWrapper, adapter, "ConcatString", (byte)32, 42, 'm', "string");
var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper);
Assert.Equal("32, 42, m, string", result.Result);
// kill the connection
connectionWrapper.Connection.Dispose();
@ -339,17 +282,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper.Connection.Transport, adapter, "OnDisconnectedAsync");
await SendRequest(connectionWrapper, adapter, "OnDisconnectedAsync");
var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper);
try
{
await endPointTask;
Assert.True(false);
}
catch (InvalidOperationException ex)
{
Assert.Equal("The hub method 'OnDisconnectedAsync' could not be resolved.", ex.Message);
}
Assert.Equal("Unknown hub method 'OnDisconnectedAsync'", result.Error);
}
}
@ -371,21 +307,21 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(firstConnection.Connection.Transport, adapter, "BroadcastMethod", "test");
await SendRequest(firstConnection, adapter, "BroadcastMethod", "test");
foreach (var res in await Task.WhenAll(
ReadConnectionOutputAsync<InvocationDescriptor>(firstConnection.Connection.Transport),
ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection.Connection.Transport)))
foreach (var result in await Task.WhenAll(
ReadConnectionOutputAsync<InvocationDescriptor>(firstConnection),
ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection)))
{
Assert.Equal("Broadcast", res.Method);
Assert.Equal(1, res.Arguments.Length);
Assert.Equal("test", res.Arguments[0]);
Assert.Equal("Broadcast", result.Method);
Assert.Equal(1, result.Arguments.Length);
Assert.Equal("test", result.Arguments[0]);
}
// kill the connections
firstConnection.Connection.Dispose();
secondConnection.Connection.Dispose();
await Task.WhenAll(firstEndPointTask, secondEndPointTask);
}
}
@ -408,18 +344,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "GroupSendMethod", "testGroup", "test");
await SendRequest_IgnoreReceive(firstConnection, adapter, "GroupSendMethod", "testGroup", "test");
// check that 'secondConnection' hasn't received the group send
Assert.False(((PipelineReaderWriter)secondConnection.Connection.Transport.Output).ReadAsync().IsCompleted);
Message message;
Assert.False(secondConnection.Transport.Output.TryRead(out message));
await SendRequest_IgnoreReceive(secondConnection.Connection.Transport, adapter, "GroupAddMethod", "testGroup");
await SendRequest_IgnoreReceive(secondConnection, adapter, "GroupAddMethod", "testGroup");
await SendRequest(firstConnection, adapter, "GroupSendMethod", "testGroup", "test");
await SendRequest(firstConnection.Connection.Transport, adapter, "GroupSendMethod", "testGroup", "test");
// check that 'firstConnection' hasn't received the group send
Assert.False(((PipelineReaderWriter)firstConnection.Connection.Transport.Output).ReadAsync().IsCompleted);
Assert.False(firstConnection.Transport.Output.TryRead(out message));
// check that 'secondConnection' has received the group send
var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection.Connection.Transport);
var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection);
Assert.Equal("Send", res.Method);
Assert.Equal(1, res.Arguments.Length);
Assert.Equal("test", res.Arguments[0]);
@ -448,7 +386,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var writer = invocationAdapter.GetInvocationAdapter("json");
await SendRequest_IgnoreReceive(connection.Connection.Transport, writer, "GroupRemoveMethod", "testGroup");
await SendRequest_IgnoreReceive(connection, writer, "GroupRemoveMethod", "testGroup");
// kill the connection
connection.Connection.Dispose();
@ -475,10 +413,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "ClientSendMethod", secondConnection.Connection.User.Identity.Name, "test");
await SendRequest_IgnoreReceive(firstConnection, adapter, "ClientSendMethod", secondConnection.Connection.User.Identity.Name, "test");
// check that 'secondConnection' has received the group send
var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection.Connection.Transport);
var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection);
Assert.Equal("Send", res.Method);
Assert.Equal(1, res.Arguments.Length);
Assert.Equal("test", res.Arguments[0]);
@ -509,13 +447,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "ConnectionSendMethod", secondConnection.Connection.ConnectionId, "test");
await SendRequest_IgnoreReceive(firstConnection, adapter, "ConnectionSendMethod", secondConnection.Connection.ConnectionId, "test");
// check that 'secondConnection' has received the group send
var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection.Connection.Transport);
Assert.Equal("Send", res.Method);
Assert.Equal(1, res.Arguments.Length);
Assert.Equal("test", res.Arguments[0]);
var result = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection);
Assert.Equal("Send", result.Method);
Assert.Equal(1, result.Arguments.Length);
Assert.Equal("test", result.Arguments[0]);
// kill the connections
firstConnection.Connection.Dispose();
@ -525,6 +463,84 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
private static Type GetEndPointType(Type hubType)
{
var endPointType = typeof(HubEndPoint<>);
return endPointType.MakeGenericType(hubType);
}
private static Type GetGenericType(Type genericType, Type hubType)
{
return genericType.MakeGenericType(hubType);
}
public async Task SendRequest(ConnectionWrapper connection, IInvocationAdapter writer, string method, params object[] args)
{
if (connection == null)
{
throw new ArgumentNullException();
}
var stream = new MemoryStream();
await writer.WriteMessageAsync(new InvocationDescriptor
{
Arguments = args,
Method = method
},
stream);
var buffer = ReadableBuffer.Create(stream.ToArray()).Preserve();
await connection.Transport.Input.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true));
}
public async Task SendRequest_IgnoreReceive(ConnectionWrapper connection, IInvocationAdapter writer, string method, params object[] args)
{
await SendRequest(connection, writer, method, args);
// Consume the result
await connection.Transport.Output.ReadAsync();
}
private async Task<T> ReadConnectionOutputAsync<T>(ConnectionWrapper connection)
{
// TODO: other formats?
var message = await connection.Transport.Output.ReadAsync();
var serializer = new JsonSerializer();
return serializer.Deserialize<T>(new JsonTextReader(new StreamReader(new MemoryStream(message.Payload.Buffer.ToArray()))));
}
private IServiceProvider CreateServiceProvider(Action<ServiceCollection> addServices = null)
{
var services = new ServiceCollection();
services.AddOptions()
.AddLogging()
.AddSignalR();
addServices?.Invoke(services);
return services.BuildServiceProvider();
}
public class OnConnectedThrowsHub : Hub
{
public override Task OnConnectedAsync()
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(new InvalidOperationException("Hub OnConnected failed."));
return tcs.Task;
}
}
public class OnDisconnectedThrowsHub : Hub
{
public override Task OnDisconnectedAsync(Exception exception)
{
var tcs = new TaskCompletionSource<object>();
tcs.SetException(new InvalidOperationException("Hub OnDisconnected failed."));
return tcs.Task;
}
}
private class MethodHub : Hub
{
public Task GroupRemoveMethod(string groupName)
@ -610,83 +626,91 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public int DisposeCount = 0;
}
public async Task SendRequest(IPipelineConnection connection, IInvocationAdapter writer, string method, params object[] args)
public class ConnectionWrapper : IDisposable
{
if (connection == null)
{
throw new ArgumentNullException();
}
private static int _id;
private readonly TestChannel<Message> _input;
var stream = new MemoryStream();
await writer.WriteMessageAsync(new InvocationDescriptor
{
Arguments = args,
Method = method
}, stream);
public Connection Connection { get; }
var buffer = ((PipelineReaderWriter)connection.Input).Alloc();
buffer.Write(stream.ToArray());
await buffer.FlushAsync();
}
public ChannelConnection<Message> Transport { get; }
public async Task SendRequest_IgnoreReceive(IPipelineConnection connection, IInvocationAdapter writer, string method, params object[] args)
{
await SendRequest(connection, writer, method, args);
var methodResult = await ((PipelineReaderWriter)connection.Output).ReadAsync();
((PipelineReaderWriter)connection.Output).AdvanceReader(methodResult.Buffer.End, methodResult.Buffer.End);
}
private async Task<T> ReadConnectionOutputAsync<T>(IPipelineConnection connection)
{
// TODO: other formats?
var methodResult = await ((PipelineReaderWriter)connection.Output).ReadAsync();
var serializer = new JsonSerializer();
var res = serializer.Deserialize<T>(new JsonTextReader(new StreamReader(new MemoryStream(methodResult.Buffer.ToArray()))));
((PipelineReaderWriter)connection.Output).AdvanceReader(methodResult.Buffer.End, methodResult.Buffer.End);
return res;
}
private IServiceProvider CreateServiceProvider(Action<ServiceCollection> addServices = null)
{
var services = new ServiceCollection();
services.AddOptions()
.AddLogging()
.AddSignalR();
addServices?.Invoke(services);
return services.BuildServiceProvider();
}
private class ConnectionWrapper : IDisposable
{
private static int ID;
private PipelineFactory _factory;
public StreamingConnectionState ConnectionState;
public StreamingConnection Connection => ConnectionState.Connection;
// Still kinda gross...
public Task ApplicationStartedReading => ((PipelineReaderWriter)Connection.Transport.Input).ReadingStarted;
public Task ApplicationStartedReading => _input.ReadingStarted;
public ConnectionWrapper(string format = "json")
{
_factory = new PipelineFactory();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var connectionManager = new ConnectionManager(_factory);
_input = new TestChannel<Message>(transportToApplication);
ConnectionState = (StreamingConnectionState)connectionManager.CreateConnection(ConnectionMode.Streaming);
ConnectionState.Connection.Metadata["formatType"] = format;
ConnectionState.Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref ID).ToString()) }));
Transport = new ChannelConnection<Message>(_input, applicationToTransport);
Connection = new Connection(Guid.NewGuid().ToString(), Transport);
Connection.Metadata["formatType"] = format;
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
}
public void Dispose()
{
ConnectionState.Dispose();
_factory.Dispose();
Connection.Dispose();
}
private class TestChannel<T> : IChannel<T>
{
private IChannel<T> _channel;
private TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
public TestChannel(IChannel<T> channel)
{
_channel = channel;
}
public Task Completion => _channel.Completion;
public Task ReadingStarted => _tcs.Task;
public ValueAwaiter<T> GetAwaiter()
{
return _channel.GetAwaiter();
}
public ValueTask<T> ReadAsync(CancellationToken cancellationToken = default(CancellationToken))
{
_tcs.TrySetResult(null);
return _channel.ReadAsync(cancellationToken);
}
public bool TryComplete(Exception error = null)
{
return _channel.TryComplete(error);
}
public bool TryRead(out T item)
{
return _channel.TryRead(out item);
}
public bool TryWrite(T item)
{
return _channel.TryWrite(item);
}
public Task<bool> WaitToReadAsync(CancellationToken cancellationToken = default(CancellationToken))
{
_tcs.TrySetResult(null);
return _channel.WaitToReadAsync(cancellationToken);
}
public Task<bool> WaitToWriteAsync(CancellationToken cancellationToken = default(CancellationToken))
{
return _channel.WaitToWriteAsync(cancellationToken);
}
public Task WriteAsync(T item, CancellationToken cancellationToken = default(CancellationToken))
{
return _channel.WriteAsync(item, cancellationToken);
}
}
}
}

View File

@ -1,7 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Internal;
using Xunit;
@ -13,100 +13,87 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public void NewConnectionsHaveConnectionId()
{
using (var factory = new PipelineFactory())
{
var connectionManager = new ConnectionManager(factory);
var state = connectionManager.CreateConnection(ConnectionMode.Streaming);
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.True(state.Active);
Assert.Null(state.Close);
Assert.NotNull(((StreamingConnectionState)state).Connection.Transport);
}
var connectionManager = new ConnectionManager();
var state = connectionManager.CreateConnection();
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.True(state.Active);
Assert.Null(state.Close);
Assert.NotNull(state.Connection.Transport);
}
[Fact]
public void NewConnectionsCanBeRetrieved()
{
using (var factory = new PipelineFactory())
{
var connectionManager = new ConnectionManager(factory);
var state = connectionManager.CreateConnection(ConnectionMode.Streaming);
var connectionManager = new ConnectionManager();
var state = connectionManager.CreateConnection();
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
}
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
}
[Fact]
public void AddNewConnection()
{
using (var factory = new PipelineFactory())
{
var connectionManager = new ConnectionManager(factory);
var state = connectionManager.CreateConnection(ConnectionMode.Streaming);
var connectionManager = new ConnectionManager();
var state = connectionManager.CreateConnection();
var transport = ((StreamingConnectionState)state).Connection.Transport;
var transport = state.Connection.Transport;
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(transport);
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(transport);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.Same(transport, ((StreamingConnectionState)newState).Connection.Transport);
}
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.Same(transport, newState.Connection.Transport);
}
[Fact]
public void RemoveConnection()
{
using (var factory = new PipelineFactory())
{
var connectionManager = new ConnectionManager(factory);
var state = connectionManager.CreateConnection(ConnectionMode.Streaming);
var connectionManager = new ConnectionManager();
var state = connectionManager.CreateConnection();
var transport = ((StreamingConnectionState)state).Connection.Transport;
var transport = state.Connection.Transport;
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(transport);
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(transport);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.Same(transport, ((StreamingConnectionState)newState).Connection.Transport);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.Same(transport, newState.Connection.Transport);
connectionManager.RemoveConnection(state.Connection.ConnectionId);
Assert.False(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
}
connectionManager.RemoveConnection(state.Connection.ConnectionId);
Assert.False(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
}
[Fact]
public async Task CloseConnectionsEndsAllPendingConnections()
{
using (var factory = new PipelineFactory())
var connectionManager = new ConnectionManager();
var state = connectionManager.CreateConnection();
var task = Task.Run(async () =>
{
var connectionManager = new ConnectionManager(factory);
var state = (StreamingConnectionState)connectionManager.CreateConnection(ConnectionMode.Streaming);
var connection = state.Connection;
var task = Task.Run(async () =>
{
var result = await state.Connection.Transport.Input.ReadAsync();
Assert.False(await connection.Transport.Input.WaitToReadAsync());
Assert.True(connection.Transport.Input.Completion.IsCompleted);
});
Assert.True(result.IsCompleted);
});
connectionManager.CloseConnections();
connectionManager.CloseConnections();
await task;
}
await task;
}
}
}

View File

@ -22,96 +22,65 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task GetIdReservesConnectionIdAndReturnsIt()
{
using (var factory = new PipelineFactory())
{
var manager = new ConnectionManager(factory);
var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
var ms = new MemoryStream();
context.Request.Path = "/getid";
context.Response.Body = ms;
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
var manager = new ConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
var ms = new MemoryStream();
context.Request.Path = "/getid";
context.Response.Body = ms;
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
var id = Encoding.UTF8.GetString(ms.ToArray());
var id = Encoding.UTF8.GetString(ms.ToArray());
ConnectionState state;
Assert.True(manager.TryGetConnection(id, out state));
Assert.Equal(id, state.Connection.ConnectionId);
}
ConnectionState state;
Assert.True(manager.TryGetConnection(id, out state));
Assert.Equal(id, state.Connection.ConnectionId);
}
// REVIEW: No longer relevant since we establish the connection right away.
//[Fact]
//public async Task SendingToReservedConnectionsThatHaveNotConnectedThrows()
//{
// using (var factory = new PipelineFactory())
// {
// var manager = new ConnectionManager(factory);
// var state = manager.ReserveConnection();
// var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null);
// var context = new DefaultHttpContext();
// context.Request.Path = "/send";
// var values = new Dictionary<string, StringValues>();
// values["id"] = state.Connection.ConnectionId;
// var qs = new QueryCollection(values);
// context.Request.Query = qs;
// await Assert.ThrowsAsync<InvalidOperationException>(async () =>
// {
// await dispatcher.ExecuteAsync<TestEndPoint>("", context);
// });
// }
//}
[Fact]
public async Task SendingToUnknownConnectionIdThrows()
{
using (var factory = new PipelineFactory())
var manager = new ConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
var values = new Dictionary<string, StringValues>();
values["id"] = "unknown";
var qs = new QueryCollection(values);
context.Request.Query = qs;
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
var manager = new ConnectionManager(factory);
var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
var values = new Dictionary<string, StringValues>();
values["id"] = "unknown";
var qs = new QueryCollection(values);
context.Request.Query = qs;
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
});
}
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
});
}
[Fact]
public async Task SendingWithoutConnectionIdThrows()
{
using (var factory = new PipelineFactory())
var manager = new ConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
var manager = new ConnectionManager(factory);
var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
});
}
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
});
}
}
public class TestEndPoint : StreamingEndPoint
public class TestEndPoint : EndPoint
{
public override Task OnConnectedAsync(StreamingConnection connection)
public override Task OnConnectedAsync(Connection connection)
{
throw new NotImplementedException();
}

View File

@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task Set204StatusCodeWhenChannelComplete()
{
var channel = Channel.Create<Message>();
var channel = Channel.CreateUnbounded<Message>();
var context = new DefaultHttpContext();
var poll = new LongPollingTransport(channel, new LoggerFactory());
@ -35,7 +35,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task FrameSentAsSingleResponse()
{
var channel = Channel.Create<Message>();
var channel = Channel.CreateUnbounded<Message>();
var context = new DefaultHttpContext();
var poll = new LongPollingTransport(channel, new LoggerFactory());
var ms = new MemoryStream();

View File

@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task SSESetsContentType()
{
var channel = Channel.Create<Message>();
var channel = Channel.CreateUnbounded<Message>();
var context = new DefaultHttpContext();
var sse = new ServerSentEventsTransport(channel, new LoggerFactory());
@ -36,7 +36,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task SSEAddsAppropriateFraming()
{
var channel = Channel.Create<Message>();
var channel = Channel.CreateUnbounded<Message>();
var context = new DefaultHttpContext();
var sse = new ServerSentEventsTransport(channel, new LoggerFactory());
var ms = new MemoryStream();

View File

@ -22,8 +22,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task ReceivedFramesAreWrittenToChannel(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
@ -70,8 +70,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task MultiFrameMessagesArePropagatedToTheChannel(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
@ -129,8 +129,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task IncompleteMessagesAreWrittenAsMultiFrameWebSocketMessages(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
@ -177,8 +177,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task DataWrittenToOutputPipelineAreSentAsFrames(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
@ -218,8 +218,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task FrameReceivedAfterServerCloseSent(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
@ -261,8 +261,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task TransportFailsWhenClientDisconnectsAbnormally()
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
@ -289,8 +289,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails()
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);

View File

@ -197,7 +197,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
}
}
private static void CompleteChannels(params PipelineReaderWriter[] readerWriters)
private static void CompleteChannels(params Pipe[] readerWriters)
{
foreach (var readerWriter in readerWriters)
{

View File

@ -14,13 +14,13 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
private PipelineFactory _factory;
private readonly bool _ownFactory;
public PipelineReaderWriter ServerToClient { get; }
public PipelineReaderWriter ClientToServer { get; }
public Pipe ServerToClient { get; }
public Pipe ClientToServer { get; }
public IWebSocketConnection ClientSocket { get; }
public IWebSocketConnection ServerSocket { get; }
public WebSocketPair(bool ownFactory, PipelineFactory factory, PipelineReaderWriter serverToClient, PipelineReaderWriter clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket)
public WebSocketPair(bool ownFactory, PipelineFactory factory, Pipe serverToClient, Pipe clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket)
{
_ownFactory = ownFactory;
_factory = factory;