diff --git a/samples/SocketsSample/Bus.cs b/samples/SocketsSample/Bus.cs index f583f96f63..95fb8f2fed 100644 --- a/samples/SocketsSample/Bus.cs +++ b/samples/SocketsSample/Bus.cs @@ -20,11 +20,17 @@ namespace Microsoft.AspNetCore.Sockets public IDisposable Subscribe(string key, Func observer) { var connections = _subscriptions.GetOrAdd(key, _ => new List>()); - connections.Add(observer); + lock (connections) + { + connections.Add(observer); + } return new DisposableAction(() => { - connections.Remove(observer); + lock (connections) + { + connections.Remove(observer); + } }); } @@ -33,10 +39,17 @@ namespace Microsoft.AspNetCore.Sockets List> connections; if (_subscriptions.TryGetValue(key, out connections)) { - foreach (var c in connections) + Task[] tasks = null; + lock (connections) { - await c(message); + tasks = new Task[connections.Count]; + for (int i = 0; i < connections.Count; i++) + { + tasks[i] = connections[i](message); + } } + + await Task.WhenAll(tasks); } } diff --git a/samples/SocketsSample/EndPoints/ChatEndPoint.cs b/samples/SocketsSample/EndPoints/ChatEndPoint.cs index 6e5174a06f..6478e5d7c1 100644 --- a/samples/SocketsSample/EndPoints/ChatEndPoint.cs +++ b/samples/SocketsSample/EndPoints/ChatEndPoint.cs @@ -47,8 +47,6 @@ namespace SocketsSample { Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})") }); - - connection.Channel.Input.Complete(); } private async Task OnMessage(Message message, Connection connection) diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 182e330f0d..e4d696a962 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -54,16 +54,6 @@ namespace Microsoft.AspNetCore.Sockets return state; } - public void MarkConnectionInactive(string id) - { - ConnectionState state; - if (_connections.TryGetValue(id, out state)) - { - // Mark the connection as active so the background thread can look at it - state.Active = false; - } - } - public void RemoveConnection(string id) { ConnectionState state; @@ -88,7 +78,7 @@ namespace Microsoft.AspNetCore.Sockets // Scan the registered connections looking for ones that have timed out foreach (var c in _connections) { - if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeen).TotalSeconds > 30) + if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeen).TotalSeconds > 5) { ConnectionState s; if (_connections.TryRemove(c.Key, out s)) diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 3170af0be8..f93504bb17 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -56,7 +56,7 @@ namespace Microsoft.AspNetCore.Sockets var sse = new ServerSentEvents(connectionState.Connection); // Register this transport for disconnect - RegisterDisconnect(context, sse); + RegisterDisconnect(context, connectionState.Connection); // Call into the end point passing the connection var endpointTask = endpoint.OnConnected(connectionState.Connection); @@ -80,7 +80,7 @@ namespace Microsoft.AspNetCore.Sockets var ws = new WebSockets(connectionState.Connection); // Register this transport for disconnect - RegisterDisconnect(context, ws); + RegisterDisconnect(context, connectionState.Connection); // Call into the end point passing the connection var endpointTask = endpoint.OnConnected(connectionState.Connection); @@ -108,8 +108,6 @@ namespace Microsoft.AspNetCore.Sockets if (connectionState.Connection.Channel == null) { // REVIEW: The connection manager should encapsulate this... - connectionState.Active = true; - connectionState.LastSeen = DateTimeOffset.UtcNow; connectionState.Connection.Channel = new HttpChannel(_channelFactory); isNewConnection = true; } @@ -121,6 +119,9 @@ namespace Microsoft.AspNetCore.Sockets isNewConnection = true; } + // Mark the connection as active + connectionState.Active = true; + // Raise OnConnected for new connections only since polls happen all the time if (isNewConnection) { @@ -130,22 +131,30 @@ namespace Microsoft.AspNetCore.Sockets var ignore = endpoint.OnConnected(connectionState.Connection); } - var longPolling = new LongPolling(connectionState.Connection); + RegisterLongPollingDisconnect(context, connectionState.Connection); - // Register this transport for disconnect - RegisterDisconnect(context, longPolling); + var longPolling = new LongPolling(connectionState.Connection); // Start the transport await longPolling.ProcessRequest(context); - _manager.MarkConnectionInactive(connectionState.Connection.ConnectionId); + // Mark the connection as inactive + connectionState.LastSeen = DateTimeOffset.UtcNow; + connectionState.Active = false; } } } - private static void RegisterDisconnect(HttpContext context, IHttpTransport transport) + private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection) { - context.RequestAborted.Register(state => ((IHttpTransport)state).CloseAsync(), transport); + // For long polling, we need to end the transport but not the overall connection so we write 0 bytes + context.RequestAborted.Register(state => ((HttpChannel)state).Output.WriteAsync(Span.Empty), connection.Channel); + } + + private static void RegisterDisconnect(HttpContext context, Connection connection) + { + // We just kill the output writing as a signal to the transport that it is done + context.RequestAborted.Register(state => ((HttpChannel)state).Output.CompleteWriter(), connection.Channel); } private Task ProcessGetId(HttpContext context) diff --git a/src/Microsoft.AspNetCore.Sockets/IHttpTransport.cs b/src/Microsoft.AspNetCore.Sockets/IHttpTransport.cs index b8aa396fc9..ffe8fa2da7 100644 --- a/src/Microsoft.AspNetCore.Sockets/IHttpTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/IHttpTransport.cs @@ -14,11 +14,5 @@ namespace Microsoft.AspNetCore.Sockets /// /// A that completes when the transport has finished processing Task ProcessRequest(HttpContext context); - - /// - /// Completes the Task returned from ProcessRequest if not already complete - /// - /// - Task CloseAsync(); } } diff --git a/src/Microsoft.AspNetCore.Sockets/LongPolling.cs b/src/Microsoft.AspNetCore.Sockets/LongPolling.cs index bbcc701251..a61f11a90b 100644 --- a/src/Microsoft.AspNetCore.Sockets/LongPolling.cs +++ b/src/Microsoft.AspNetCore.Sockets/LongPolling.cs @@ -7,88 +7,37 @@ namespace Microsoft.AspNetCore.Sockets { public class LongPolling : IHttpTransport { - private readonly TaskCompletionSource _initTcs = new TaskCompletionSource(); - private readonly TaskCompletionSource _lifetime = new TaskCompletionSource(); private readonly HttpChannel _channel; private readonly Connection _connection; - private readonly TaskQueue _queue; - - private HttpContext _context; public LongPolling(Connection connection) { - _queue = new TaskQueue(_initTcs.Task); _connection = connection; _channel = (HttpChannel)connection.Channel; } public async Task ProcessRequest(HttpContext context) - { - _context = context; - - _initTcs.TrySetResult(null); - - // Send queue messages to the connection - var ignore = ProcessMessages(context); - - await _lifetime.Task; - } - - private async Task ProcessMessages(HttpContext context) { var buffer = await _channel.Output.ReadAsync(); if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted) { - await CloseAsync(); + // REVIEW: Set the status code here so the client doesn't reconnect return; } - try + if (!buffer.IsEmpty) { - await Send(buffer); + try + { + context.Response.ContentLength = buffer.Length; + await buffer.CopyToAsync(context.Response.Body); + } + finally + { + _channel.Output.Advance(buffer.End); + } } - finally - { - _channel.Output.Advance(buffer.End); - } - - await EndRequest(); - } - - public async Task CloseAsync() - { - await _queue.Enqueue(state => - { - var context = (HttpContext)state; - // REVIEW: What happens if header was already? - context.Response.Headers["X-ASPNET-SOCKET-DISCONNECT"] = "1"; - return Task.CompletedTask; - }, - _context); - - await EndRequest(); - } - - private async Task EndRequest() - { - // Drain the queue and don't let any new work enter - await _queue.Drain(); - - // Complete the lifetime task - _lifetime.TrySetResult(null); - } - - private Task Send(ReadableBuffer value) - { - // REVIEW: Can we avoid the closure here? - return _queue.Enqueue(state => - { - var data = (ReadableBuffer)state; - _context.Response.ContentLength = data.Length; - return data.CopyToAsync(_context.Response.Body); - }, - value); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/ServerSentEvents.cs b/src/Microsoft.AspNetCore.Sockets/ServerSentEvents.cs index 8f83eee91c..e753503c5d 100644 --- a/src/Microsoft.AspNetCore.Sockets/ServerSentEvents.cs +++ b/src/Microsoft.AspNetCore.Sockets/ServerSentEvents.cs @@ -7,20 +7,13 @@ namespace Microsoft.AspNetCore.Sockets { public class ServerSentEvents : IHttpTransport { - private readonly TaskCompletionSource _initTcs = new TaskCompletionSource(); - private readonly TaskCompletionSource _lifetime = new TaskCompletionSource(); private readonly HttpChannel _channel; private readonly Connection _connection; - private readonly TaskQueue _queue; - - private HttpContext _context; public ServerSentEvents(Connection connection) { - _queue = new TaskQueue(_initTcs.Task); _connection = connection; _channel = (HttpChannel)connection.Channel; - var ignore = StartSending(); } public async Task ProcessRequest(HttpContext context) @@ -28,27 +21,6 @@ namespace Microsoft.AspNetCore.Sockets context.Response.ContentType = "text/event-stream"; context.Response.Headers["Cache-Control"] = "no-cache"; - _context = context; - - // Set the initial TCS when everything is setup - _initTcs.TrySetResult(null); - - await _lifetime.Task; - } - - public async Task CloseAsync() - { - // Drain the queue so no new work can enter - await _queue.Drain(); - - // Complete the lifetime task - _lifetime.TrySetResult(null); - } - - private async Task StartSending() - { - await _initTcs.Task; - while (true) { var buffer = await _channel.Output.ReadAsync(); @@ -58,7 +30,7 @@ namespace Microsoft.AspNetCore.Sockets break; } - await Send(buffer); + await Send(context, buffer); _channel.Output.Advance(buffer.End); } @@ -66,29 +38,24 @@ namespace Microsoft.AspNetCore.Sockets _channel.Output.CompleteReader(); } - private Task Send(ReadableBuffer value) + private async Task Send(HttpContext context, ReadableBuffer data) { - return _queue.Enqueue(async state => - { - var data = (ReadableBuffer)state; - // TODO: Pooled buffers - // 8 = 6(data: ) + 2 (\n\n) - var buffer = new byte[8 + data.Length]; - var at = 0; - buffer[at++] = (byte)'d'; - buffer[at++] = (byte)'a'; - buffer[at++] = (byte)'t'; - buffer[at++] = (byte)'a'; - buffer[at++] = (byte)':'; - buffer[at++] = (byte)' '; - data.CopyTo(new Span(buffer, at, data.Length)); - at += data.Length; - buffer[at++] = (byte)'\n'; - buffer[at++] = (byte)'\n'; - await _context.Response.Body.WriteAsync(buffer, 0, at); - await _context.Response.Body.FlushAsync(); - }, - value); + // TODO: Pooled buffers + // 8 = 6(data: ) + 2 (\n\n) + var buffer = new byte[8 + data.Length]; + var at = 0; + buffer[at++] = (byte)'d'; + buffer[at++] = (byte)'a'; + buffer[at++] = (byte)'t'; + buffer[at++] = (byte)'a'; + buffer[at++] = (byte)':'; + buffer[at++] = (byte)' '; + data.CopyTo(new Span(buffer, at, data.Length)); + at += data.Length; + buffer[at++] = (byte)'\n'; + buffer[at++] = (byte)'\n'; + await context.Response.Body.WriteAsync(buffer, 0, at); + await context.Response.Body.FlushAsync(); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/TaskQueue.cs b/src/Microsoft.AspNetCore.Sockets/TaskQueue.cs deleted file mode 100644 index 09423769eb..0000000000 --- a/src/Microsoft.AspNetCore.Sockets/TaskQueue.cs +++ /dev/null @@ -1,68 +0,0 @@ -using System; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.Sockets -{ - // Allows serial queuing of Task instances - // The tasks are not called on the current synchronization context - - internal sealed class TaskQueue - { - private readonly object _lockObj = new object(); - private Task _lastQueuedTask; - private volatile bool _drained; - private long _size; - - public TaskQueue() - : this(Task.CompletedTask) - { - } - - public TaskQueue(Task initialTask) - { - _lastQueuedTask = initialTask; - } - - public bool IsDrained - { - get - { - return _drained; - } - } - - public Task Enqueue(Func taskFunc, object state) - { - // Lock the object for as short amount of time as possible - lock (_lockObj) - { - if (_drained) - { - return _lastQueuedTask; - } - - var newTask = _lastQueuedTask.ContinueWith((t, s1) => - { - if (t.IsFaulted || t.IsCanceled) - { - return t; - } - return taskFunc(s1); - }, - state).Unwrap(); - _lastQueuedTask = newTask; - return newTask; - } - } - - public Task Drain() - { - lock (_lockObj) - { - _drained = true; - - return _lastQueuedTask; - } - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/WebSockets.cs b/src/Microsoft.AspNetCore.Sockets/WebSockets.cs index 0ff34a5fc7..906c36fced 100644 --- a/src/Microsoft.AspNetCore.Sockets/WebSockets.cs +++ b/src/Microsoft.AspNetCore.Sockets/WebSockets.cs @@ -11,14 +11,11 @@ namespace Microsoft.AspNetCore.Sockets { private readonly HttpChannel _channel; private readonly Connection _connection; - private readonly TaskCompletionSource _tcs = new TaskCompletionSource(); - private WebSocket _ws; public WebSockets(Connection connection) { _connection = connection; _channel = (HttpChannel)connection.Channel; - var ignore = StartSending(); } public async Task ProcessRequest(HttpContext context) @@ -31,9 +28,11 @@ namespace Microsoft.AspNetCore.Sockets var ws = await context.WebSockets.AcceptWebSocketAsync(); - _ws = ws; - - _tcs.TrySetResult(null); + // REVIEW: Should we track this task? Leaving things like this alive usually causes memory leaks :) + // The reason we don't await this is because the channel is disposed after this loop returns + // and the sending loop is waiting for the channel to end before doing anything + // We could do a 2 stage shutdown but that could complicate the code... + var sending = StartSending(ws); var outputBuffer = _channel.Input.Alloc(); @@ -66,46 +65,81 @@ namespace Microsoft.AspNetCore.Sockets } else { - await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); break; } } + + await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); } - public async Task CloseAsync() + private async Task StartSending(WebSocket ws) { - await _tcs.Task; - - // REVIEW: Close output vs Close? - await _ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); - } - - private async Task StartSending() - { - await _tcs.Task; - while (true) { var buffer = await _channel.Output.ReadAsync(); - if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted) + try { + if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted) + { + break; + } + + foreach (var memory in buffer) + { + ArraySegment data; + if (memory.TryGetArray(out data)) + { + if (IsClosedOrClosedSent(ws)) + { + break; + } + + await ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None); + } + } + + } + catch (Exception) + { + // Error writing, probably closed break; } - - foreach (var memory in buffer) + finally { - ArraySegment data; - if (memory.TryGetArray(out data)) - { - await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None); - } + _channel.Output.Advance(buffer.End); } - - _channel.Output.Advance(buffer.End); } _channel.Output.CompleteReader(); + + // REVIEW: Should this ever happen? + if (!IsClosedOrClosedSent(ws)) + { + // Close the output + await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + } + } + + private static bool IsClosedOrClosedSent(WebSocket webSocket) + { + var webSocketState = GetWebSocketState(webSocket); + + return webSocketState == WebSocketState.Closed || + webSocketState == WebSocketState.CloseSent || + webSocketState == WebSocketState.Aborted; + } + + private static WebSocketState GetWebSocketState(WebSocket webSocket) + { + try + { + return webSocket.State; + } + catch (ObjectDisposedException) + { + return WebSocketState.Closed; + } } } }