From 813222b406af752ae4bdd3f3980d1f30bae9ffb3 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Sat, 1 Oct 2016 03:03:20 -0700 Subject: [PATCH] More cleanup + TaskQueue - Introduced the TaskQueue - Added IHttpTransport so abort callbacks can happen outside of the transport implementation --- .../HttpConnectionDispatcher.cs | 11 +++ .../IHttpTransport.cs | 14 ++++ .../LongPolling.cs | 60 +++++----------- .../ServerSentEvents.cs | 54 ++++----------- src/Microsoft.AspNetCore.Sockets/TaskQueue.cs | 60 ++++++++++++++++ .../WebSockets.cs | 69 ++++++++++--------- 6 files changed, 153 insertions(+), 115 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Sockets/IHttpTransport.cs create mode 100644 src/Microsoft.AspNetCore.Sockets/TaskQueue.cs diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 2d3bf15e40..2b1b30ca92 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -51,6 +51,8 @@ namespace Microsoft.AspNetCore.Sockets var connectionState = GetOrCreateConnection(context); var sse = new ServerSentEvents((HttpChannel)connectionState.Connection.Channel); + RegisterDisconnect(context, sse); + var ignore = endpoint.OnConnected(connectionState.Connection); await sse.ProcessRequest(context); @@ -64,6 +66,8 @@ namespace Microsoft.AspNetCore.Sockets var connectionState = GetOrCreateConnection(context); var ws = new WebSockets((HttpChannel)connectionState.Connection.Channel); + RegisterDisconnect(context, ws); + var ignore = endpoint.OnConnected(connectionState.Connection); await ws.ProcessRequest(context); @@ -107,6 +111,8 @@ namespace Microsoft.AspNetCore.Sockets var longPolling = new LongPolling((HttpChannel)connectionState.Connection.Channel); + RegisterDisconnect(context, longPolling); + await longPolling.ProcessRequest(context); _manager.MarkConnectionInactive(connectionState.Connection.ConnectionId); @@ -114,6 +120,11 @@ namespace Microsoft.AspNetCore.Sockets } } + private static void RegisterDisconnect(HttpContext context, IHttpTransport transport) + { + context.RequestAborted.Register(state => ((IHttpTransport)state).Abort(), transport); + } + private Task ProcessGetId(HttpContext context) { // Reserve an id for this connection diff --git a/src/Microsoft.AspNetCore.Sockets/IHttpTransport.cs b/src/Microsoft.AspNetCore.Sockets/IHttpTransport.cs new file mode 100644 index 0000000000..dcc7e79765 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/IHttpTransport.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Sockets +{ + public interface IHttpTransport + { + Task ProcessRequest(HttpContext context); + void Abort(); + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/LongPolling.cs b/src/Microsoft.AspNetCore.Sockets/LongPolling.cs index f33ae036a9..c09c43c0c0 100644 --- a/src/Microsoft.AspNetCore.Sockets/LongPolling.cs +++ b/src/Microsoft.AspNetCore.Sockets/LongPolling.cs @@ -1,47 +1,27 @@ using System; -using System.Text; using System.Threading.Tasks; using Channels; using Microsoft.AspNetCore.Http; namespace Microsoft.AspNetCore.Sockets { - public class LongPolling + public class LongPolling : IHttpTransport { - private Task _lastTask; - private object _lockObj = new object(); - private bool _completed; - private TaskCompletionSource _initTcs = new TaskCompletionSource(); - private TaskCompletionSource _lifetime = new TaskCompletionSource(); - private HttpContext _context; + private readonly TaskCompletionSource _initTcs = new TaskCompletionSource(); + private readonly TaskCompletionSource _lifetime = new TaskCompletionSource(); private readonly HttpChannel _channel; + private readonly TaskQueue _queue; + + private HttpContext _context; public LongPolling(HttpChannel channel) { - _lastTask = _initTcs.Task; + _queue = new TaskQueue(_initTcs.Task); _channel = channel; } - private Task Post(Func work, object state) - { - if (_completed) - { - return _lastTask; - } - - lock (_lockObj) - { - _lastTask = _lastTask.ContinueWith((t, s1) => work(s1), state).Unwrap(); - } - - return _lastTask; - } - public async Task ProcessRequest(HttpContext context) { - // End the connection if the client goes away - context.RequestAborted.Register(state => OnConnectionAborted(state), this); - _context = context; _initTcs.TrySetResult(null); @@ -50,8 +30,6 @@ namespace Microsoft.AspNetCore.Sockets var ignore = ProcessMessages(context); await _lifetime.Task; - - _completed = true; } private async Task ProcessMessages(HttpContext context) @@ -60,7 +38,7 @@ namespace Microsoft.AspNetCore.Sockets if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted) { - CompleteRequest(); + Abort(); return; } @@ -74,31 +52,25 @@ namespace Microsoft.AspNetCore.Sockets } - CompleteRequest(); + Abort(); } - private static void OnConnectionAborted(object state) + public async void Abort() { - ((LongPolling)state).CompleteRequest(); - } + // Drain the queue and don't let any new work enter + await _queue.Drain(); - private void CompleteRequest() - { - Post(state => - { - ((TaskCompletionSource)state).TrySetResult(null); - return Task.CompletedTask; - }, - _lifetime); + // Complete the lifetime task + _lifetime.TrySetResult(null); } public Task Send(ReadableBuffer value) { - return Post(async state => + return _queue.Enqueue(state => { var data = (ReadableBuffer)state; _context.Response.ContentLength = data.Length; - await data.CopyToAsync(_context.Response.Body); + return data.CopyToAsync(_context.Response.Body); }, value); } diff --git a/src/Microsoft.AspNetCore.Sockets/ServerSentEvents.cs b/src/Microsoft.AspNetCore.Sockets/ServerSentEvents.cs index 3d5740a5b4..883da25c76 100644 --- a/src/Microsoft.AspNetCore.Sockets/ServerSentEvents.cs +++ b/src/Microsoft.AspNetCore.Sockets/ServerSentEvents.cs @@ -5,68 +5,42 @@ using Microsoft.AspNetCore.Http; namespace Microsoft.AspNetCore.Sockets { - public class ServerSentEvents + public class ServerSentEvents : IHttpTransport { - private Task _lastTask; - private object _lockObj = new object(); - private bool _completed; - private TaskCompletionSource _initTcs = new TaskCompletionSource(); - private TaskCompletionSource _lifetime = new TaskCompletionSource(); - private HttpContext _context; + private readonly TaskQueue _queue; + private readonly TaskCompletionSource _initTcs = new TaskCompletionSource(); + private readonly TaskCompletionSource _lifetime = new TaskCompletionSource(); private readonly HttpChannel _channel; + private HttpContext _context; + public ServerSentEvents(HttpChannel channel) { + _queue = new TaskQueue(_initTcs.Task); _channel = channel; - _lastTask = _initTcs.Task; var ignore = StartSending(); } - private Task Post(Func work, object state) - { - if (_completed) - { - return _lastTask; - } - - lock (_lockObj) - { - _lastTask = _lastTask.ContinueWith((t, s1) => work(s1), state).Unwrap(); - } - - return _lastTask; - } - public async Task ProcessRequest(HttpContext context) { context.Response.ContentType = "text/event-stream"; context.Response.Headers["Cache-Control"] = "no-cache"; - // End the connection if the client goes away - context.RequestAborted.Register(state => OnConnectionAborted(state), this); _context = context; // Set the initial TCS when everything is setup _initTcs.TrySetResult(null); await _lifetime.Task; - - _completed = true; } - private static void OnConnectionAborted(object state) + public async void Abort() { - ((ServerSentEvents)state).OnConnectedAborted(); - } + // Drain the queue so no new work can enter + await _queue.Drain(); - private void OnConnectedAborted() - { - Post(state => - { - ((TaskCompletionSource)state).TrySetResult(null); - return Task.CompletedTask; - }, - _lifetime); + // Complete the lifetime task + _lifetime.TrySetResult(null); } private async Task StartSending() @@ -92,7 +66,7 @@ namespace Microsoft.AspNetCore.Sockets private Task Send(ReadableBuffer value) { - return Post(async state => + return _queue.Enqueue(state => { var data = (ReadableBuffer)state; // TODO: Pooled buffers @@ -109,7 +83,7 @@ namespace Microsoft.AspNetCore.Sockets at += data.Length; buffer[at++] = (byte)'\n'; buffer[at++] = (byte)'\n'; - await _context.Response.Body.WriteAsync(buffer, 0, at); + return _context.Response.Body.WriteAsync(buffer, 0, at); }, value); } diff --git a/src/Microsoft.AspNetCore.Sockets/TaskQueue.cs b/src/Microsoft.AspNetCore.Sockets/TaskQueue.cs new file mode 100644 index 0000000000..5fbbb704cb --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/TaskQueue.cs @@ -0,0 +1,60 @@ +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets +{ + // Allows serial queuing of Task instances + // The tasks are not called on the current synchronization context + + internal sealed class TaskQueue + { + private readonly object _lockObj = new object(); + private Task _lastQueuedTask; + private volatile bool _drained; + private long _size; + + public TaskQueue() + : this(Task.CompletedTask) + { + } + + public TaskQueue(Task initialTask) + { + _lastQueuedTask = initialTask; + } + + public bool IsDrained + { + get + { + return _drained; + } + } + + public Task Enqueue(Func taskFunc, object state) + { + // Lock the object for as short amount of time as possible + lock (_lockObj) + { + if (_drained) + { + return _lastQueuedTask; + } + + var newTask = _lastQueuedTask.ContinueWith((t, s1) => taskFunc(s1), state).Unwrap(); + _lastQueuedTask = newTask; + return newTask; + } + } + + public Task Drain() + { + lock (_lockObj) + { + _drained = true; + + return _lastQueuedTask; + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/WebSockets.cs b/src/Microsoft.AspNetCore.Sockets/WebSockets.cs index 6b24c3ccbc..cbba8774c5 100644 --- a/src/Microsoft.AspNetCore.Sockets/WebSockets.cs +++ b/src/Microsoft.AspNetCore.Sockets/WebSockets.cs @@ -7,11 +7,11 @@ using Microsoft.AspNetCore.Http; namespace Microsoft.AspNetCore.Sockets { - public class WebSockets + public class WebSockets : IHttpTransport { + private readonly HttpChannel _channel; + private readonly TaskCompletionSource _tcs = new TaskCompletionSource(); private WebSocket _ws; - private HttpChannel _channel; - private TaskCompletionSource _tcs = new TaskCompletionSource(); public WebSockets(HttpChannel channel) { @@ -19,34 +19,6 @@ namespace Microsoft.AspNetCore.Sockets var ignore = StartSending(); } - private async Task StartSending() - { - await _tcs.Task; - - while (true) - { - var buffer = await _channel.Output.ReadAsync(); - - if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted) - { - break; - } - - foreach (var memory in buffer) - { - ArraySegment data; - if (memory.TryGetArray(out data)) - { - await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None); - } - } - - _channel.Output.Advance(buffer.End); - } - - _channel.Output.CompleteReader(); - } - public async Task ProcessRequest(HttpContext context) { if (!context.WebSockets.IsWebSocketRequest) @@ -83,5 +55,40 @@ namespace Microsoft.AspNetCore.Sockets } } } + + public async void Abort() + { + await _tcs.Task; + + await _ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + } + + private async Task StartSending() + { + await _tcs.Task; + + while (true) + { + var buffer = await _channel.Output.ReadAsync(); + + if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted) + { + break; + } + + foreach (var memory in buffer) + { + ArraySegment data; + if (memory.TryGetArray(out data)) + { + await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None); + } + } + + _channel.Output.Advance(buffer.End); + } + + _channel.Output.CompleteReader(); + } } }