More cleanup + TaskQueue

- Introduced the TaskQueue
- Added IHttpTransport so abort callbacks can happen outside of the
transport implementation
This commit is contained in:
David Fowler 2016-10-01 03:03:20 -07:00
parent ad2724b22c
commit 813222b406
6 changed files with 153 additions and 115 deletions

View File

@ -51,6 +51,8 @@ namespace Microsoft.AspNetCore.Sockets
var connectionState = GetOrCreateConnection(context);
var sse = new ServerSentEvents((HttpChannel)connectionState.Connection.Channel);
RegisterDisconnect(context, sse);
var ignore = endpoint.OnConnected(connectionState.Connection);
await sse.ProcessRequest(context);
@ -64,6 +66,8 @@ namespace Microsoft.AspNetCore.Sockets
var connectionState = GetOrCreateConnection(context);
var ws = new WebSockets((HttpChannel)connectionState.Connection.Channel);
RegisterDisconnect(context, ws);
var ignore = endpoint.OnConnected(connectionState.Connection);
await ws.ProcessRequest(context);
@ -107,6 +111,8 @@ namespace Microsoft.AspNetCore.Sockets
var longPolling = new LongPolling((HttpChannel)connectionState.Connection.Channel);
RegisterDisconnect(context, longPolling);
await longPolling.ProcessRequest(context);
_manager.MarkConnectionInactive(connectionState.Connection.ConnectionId);
@ -114,6 +120,11 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private static void RegisterDisconnect(HttpContext context, IHttpTransport transport)
{
context.RequestAborted.Register(state => ((IHttpTransport)state).Abort(), transport);
}
private Task ProcessGetId(HttpContext context)
{
// Reserve an id for this connection

View File

@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Sockets
{
public interface IHttpTransport
{
Task ProcessRequest(HttpContext context);
void Abort();
}
}

View File

@ -1,47 +1,27 @@
using System;
using System.Text;
using System.Threading.Tasks;
using Channels;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Sockets
{
public class LongPolling
public class LongPolling : IHttpTransport
{
private Task _lastTask;
private object _lockObj = new object();
private bool _completed;
private TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
private TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
private HttpContext _context;
private readonly TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
private readonly TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
private readonly HttpChannel _channel;
private readonly TaskQueue _queue;
private HttpContext _context;
public LongPolling(HttpChannel channel)
{
_lastTask = _initTcs.Task;
_queue = new TaskQueue(_initTcs.Task);
_channel = channel;
}
private Task Post(Func<object, Task> work, object state)
{
if (_completed)
{
return _lastTask;
}
lock (_lockObj)
{
_lastTask = _lastTask.ContinueWith((t, s1) => work(s1), state).Unwrap();
}
return _lastTask;
}
public async Task ProcessRequest(HttpContext context)
{
// End the connection if the client goes away
context.RequestAborted.Register(state => OnConnectionAborted(state), this);
_context = context;
_initTcs.TrySetResult(null);
@ -50,8 +30,6 @@ namespace Microsoft.AspNetCore.Sockets
var ignore = ProcessMessages(context);
await _lifetime.Task;
_completed = true;
}
private async Task ProcessMessages(HttpContext context)
@ -60,7 +38,7 @@ namespace Microsoft.AspNetCore.Sockets
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
{
CompleteRequest();
Abort();
return;
}
@ -74,31 +52,25 @@ namespace Microsoft.AspNetCore.Sockets
}
CompleteRequest();
Abort();
}
private static void OnConnectionAborted(object state)
public async void Abort()
{
((LongPolling)state).CompleteRequest();
}
// Drain the queue and don't let any new work enter
await _queue.Drain();
private void CompleteRequest()
{
Post(state =>
{
((TaskCompletionSource<object>)state).TrySetResult(null);
return Task.CompletedTask;
},
_lifetime);
// Complete the lifetime task
_lifetime.TrySetResult(null);
}
public Task Send(ReadableBuffer value)
{
return Post(async state =>
return _queue.Enqueue(state =>
{
var data = (ReadableBuffer)state;
_context.Response.ContentLength = data.Length;
await data.CopyToAsync(_context.Response.Body);
return data.CopyToAsync(_context.Response.Body);
},
value);
}

View File

@ -5,68 +5,42 @@ using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Sockets
{
public class ServerSentEvents
public class ServerSentEvents : IHttpTransport
{
private Task _lastTask;
private object _lockObj = new object();
private bool _completed;
private TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
private TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
private HttpContext _context;
private readonly TaskQueue _queue;
private readonly TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
private readonly TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
private readonly HttpChannel _channel;
private HttpContext _context;
public ServerSentEvents(HttpChannel channel)
{
_queue = new TaskQueue(_initTcs.Task);
_channel = channel;
_lastTask = _initTcs.Task;
var ignore = StartSending();
}
private Task Post(Func<object, Task> work, object state)
{
if (_completed)
{
return _lastTask;
}
lock (_lockObj)
{
_lastTask = _lastTask.ContinueWith((t, s1) => work(s1), state).Unwrap();
}
return _lastTask;
}
public async Task ProcessRequest(HttpContext context)
{
context.Response.ContentType = "text/event-stream";
context.Response.Headers["Cache-Control"] = "no-cache";
// End the connection if the client goes away
context.RequestAborted.Register(state => OnConnectionAborted(state), this);
_context = context;
// Set the initial TCS when everything is setup
_initTcs.TrySetResult(null);
await _lifetime.Task;
_completed = true;
}
private static void OnConnectionAborted(object state)
public async void Abort()
{
((ServerSentEvents)state).OnConnectedAborted();
}
// Drain the queue so no new work can enter
await _queue.Drain();
private void OnConnectedAborted()
{
Post(state =>
{
((TaskCompletionSource<object>)state).TrySetResult(null);
return Task.CompletedTask;
},
_lifetime);
// Complete the lifetime task
_lifetime.TrySetResult(null);
}
private async Task StartSending()
@ -92,7 +66,7 @@ namespace Microsoft.AspNetCore.Sockets
private Task Send(ReadableBuffer value)
{
return Post(async state =>
return _queue.Enqueue(state =>
{
var data = (ReadableBuffer)state;
// TODO: Pooled buffers
@ -109,7 +83,7 @@ namespace Microsoft.AspNetCore.Sockets
at += data.Length;
buffer[at++] = (byte)'\n';
buffer[at++] = (byte)'\n';
await _context.Response.Body.WriteAsync(buffer, 0, at);
return _context.Response.Body.WriteAsync(buffer, 0, at);
},
value);
}

View File

@ -0,0 +1,60 @@
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
// Allows serial queuing of Task instances
// The tasks are not called on the current synchronization context
internal sealed class TaskQueue
{
private readonly object _lockObj = new object();
private Task _lastQueuedTask;
private volatile bool _drained;
private long _size;
public TaskQueue()
: this(Task.CompletedTask)
{
}
public TaskQueue(Task initialTask)
{
_lastQueuedTask = initialTask;
}
public bool IsDrained
{
get
{
return _drained;
}
}
public Task Enqueue(Func<object, Task> taskFunc, object state)
{
// Lock the object for as short amount of time as possible
lock (_lockObj)
{
if (_drained)
{
return _lastQueuedTask;
}
var newTask = _lastQueuedTask.ContinueWith((t, s1) => taskFunc(s1), state).Unwrap();
_lastQueuedTask = newTask;
return newTask;
}
}
public Task Drain()
{
lock (_lockObj)
{
_drained = true;
return _lastQueuedTask;
}
}
}
}

View File

@ -7,11 +7,11 @@ using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Sockets
{
public class WebSockets
public class WebSockets : IHttpTransport
{
private readonly HttpChannel _channel;
private readonly TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
private WebSocket _ws;
private HttpChannel _channel;
private TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
public WebSockets(HttpChannel channel)
{
@ -19,34 +19,6 @@ namespace Microsoft.AspNetCore.Sockets
var ignore = StartSending();
}
private async Task StartSending()
{
await _tcs.Task;
while (true)
{
var buffer = await _channel.Output.ReadAsync();
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
{
break;
}
foreach (var memory in buffer)
{
ArraySegment<byte> data;
if (memory.TryGetArray(out data))
{
await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
}
}
_channel.Output.Advance(buffer.End);
}
_channel.Output.CompleteReader();
}
public async Task ProcessRequest(HttpContext context)
{
if (!context.WebSockets.IsWebSocketRequest)
@ -83,5 +55,40 @@ namespace Microsoft.AspNetCore.Sockets
}
}
}
public async void Abort()
{
await _tcs.Task;
await _ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
}
private async Task StartSending()
{
await _tcs.Task;
while (true)
{
var buffer = await _channel.Output.ReadAsync();
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
{
break;
}
foreach (var memory in buffer)
{
ArraySegment<byte> data;
if (memory.TryGetArray(out data))
{
await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
}
}
_channel.Output.Advance(buffer.End);
}
_channel.Output.CompleteReader();
}
}
}