More cleanup + TaskQueue
- Introduced the TaskQueue - Added IHttpTransport so abort callbacks can happen outside of the transport implementation
This commit is contained in:
parent
ad2724b22c
commit
813222b406
|
|
@ -51,6 +51,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var connectionState = GetOrCreateConnection(context);
|
||||
var sse = new ServerSentEvents((HttpChannel)connectionState.Connection.Channel);
|
||||
|
||||
RegisterDisconnect(context, sse);
|
||||
|
||||
var ignore = endpoint.OnConnected(connectionState.Connection);
|
||||
|
||||
await sse.ProcessRequest(context);
|
||||
|
|
@ -64,6 +66,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var connectionState = GetOrCreateConnection(context);
|
||||
var ws = new WebSockets((HttpChannel)connectionState.Connection.Channel);
|
||||
|
||||
RegisterDisconnect(context, ws);
|
||||
|
||||
var ignore = endpoint.OnConnected(connectionState.Connection);
|
||||
|
||||
await ws.ProcessRequest(context);
|
||||
|
|
@ -107,6 +111,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
var longPolling = new LongPolling((HttpChannel)connectionState.Connection.Channel);
|
||||
|
||||
RegisterDisconnect(context, longPolling);
|
||||
|
||||
await longPolling.ProcessRequest(context);
|
||||
|
||||
_manager.MarkConnectionInactive(connectionState.Connection.ConnectionId);
|
||||
|
|
@ -114,6 +120,11 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
}
|
||||
|
||||
private static void RegisterDisconnect(HttpContext context, IHttpTransport transport)
|
||||
{
|
||||
context.RequestAborted.Register(state => ((IHttpTransport)state).Abort(), transport);
|
||||
}
|
||||
|
||||
private Task ProcessGetId(HttpContext context)
|
||||
{
|
||||
// Reserve an id for this connection
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public interface IHttpTransport
|
||||
{
|
||||
Task ProcessRequest(HttpContext context);
|
||||
void Abort();
|
||||
}
|
||||
}
|
||||
|
|
@ -1,47 +1,27 @@
|
|||
using System;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Channels;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class LongPolling
|
||||
public class LongPolling : IHttpTransport
|
||||
{
|
||||
private Task _lastTask;
|
||||
private object _lockObj = new object();
|
||||
private bool _completed;
|
||||
private TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
|
||||
private TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
|
||||
private HttpContext _context;
|
||||
private readonly TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
|
||||
private readonly TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
|
||||
private readonly HttpChannel _channel;
|
||||
private readonly TaskQueue _queue;
|
||||
|
||||
private HttpContext _context;
|
||||
|
||||
public LongPolling(HttpChannel channel)
|
||||
{
|
||||
_lastTask = _initTcs.Task;
|
||||
_queue = new TaskQueue(_initTcs.Task);
|
||||
_channel = channel;
|
||||
}
|
||||
|
||||
private Task Post(Func<object, Task> work, object state)
|
||||
{
|
||||
if (_completed)
|
||||
{
|
||||
return _lastTask;
|
||||
}
|
||||
|
||||
lock (_lockObj)
|
||||
{
|
||||
_lastTask = _lastTask.ContinueWith((t, s1) => work(s1), state).Unwrap();
|
||||
}
|
||||
|
||||
return _lastTask;
|
||||
}
|
||||
|
||||
public async Task ProcessRequest(HttpContext context)
|
||||
{
|
||||
// End the connection if the client goes away
|
||||
context.RequestAborted.Register(state => OnConnectionAborted(state), this);
|
||||
|
||||
_context = context;
|
||||
|
||||
_initTcs.TrySetResult(null);
|
||||
|
|
@ -50,8 +30,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var ignore = ProcessMessages(context);
|
||||
|
||||
await _lifetime.Task;
|
||||
|
||||
_completed = true;
|
||||
}
|
||||
|
||||
private async Task ProcessMessages(HttpContext context)
|
||||
|
|
@ -60,7 +38,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
|
||||
{
|
||||
CompleteRequest();
|
||||
Abort();
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -74,31 +52,25 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
|
||||
|
||||
CompleteRequest();
|
||||
Abort();
|
||||
}
|
||||
|
||||
private static void OnConnectionAborted(object state)
|
||||
public async void Abort()
|
||||
{
|
||||
((LongPolling)state).CompleteRequest();
|
||||
}
|
||||
// Drain the queue and don't let any new work enter
|
||||
await _queue.Drain();
|
||||
|
||||
private void CompleteRequest()
|
||||
{
|
||||
Post(state =>
|
||||
{
|
||||
((TaskCompletionSource<object>)state).TrySetResult(null);
|
||||
return Task.CompletedTask;
|
||||
},
|
||||
_lifetime);
|
||||
// Complete the lifetime task
|
||||
_lifetime.TrySetResult(null);
|
||||
}
|
||||
|
||||
public Task Send(ReadableBuffer value)
|
||||
{
|
||||
return Post(async state =>
|
||||
return _queue.Enqueue(state =>
|
||||
{
|
||||
var data = (ReadableBuffer)state;
|
||||
_context.Response.ContentLength = data.Length;
|
||||
await data.CopyToAsync(_context.Response.Body);
|
||||
return data.CopyToAsync(_context.Response.Body);
|
||||
},
|
||||
value);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,68 +5,42 @@ using Microsoft.AspNetCore.Http;
|
|||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class ServerSentEvents
|
||||
public class ServerSentEvents : IHttpTransport
|
||||
{
|
||||
private Task _lastTask;
|
||||
private object _lockObj = new object();
|
||||
private bool _completed;
|
||||
private TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
|
||||
private TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
|
||||
private HttpContext _context;
|
||||
private readonly TaskQueue _queue;
|
||||
private readonly TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
|
||||
private readonly TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
|
||||
private readonly HttpChannel _channel;
|
||||
|
||||
private HttpContext _context;
|
||||
|
||||
public ServerSentEvents(HttpChannel channel)
|
||||
{
|
||||
_queue = new TaskQueue(_initTcs.Task);
|
||||
_channel = channel;
|
||||
_lastTask = _initTcs.Task;
|
||||
var ignore = StartSending();
|
||||
}
|
||||
|
||||
private Task Post(Func<object, Task> work, object state)
|
||||
{
|
||||
if (_completed)
|
||||
{
|
||||
return _lastTask;
|
||||
}
|
||||
|
||||
lock (_lockObj)
|
||||
{
|
||||
_lastTask = _lastTask.ContinueWith((t, s1) => work(s1), state).Unwrap();
|
||||
}
|
||||
|
||||
return _lastTask;
|
||||
}
|
||||
|
||||
public async Task ProcessRequest(HttpContext context)
|
||||
{
|
||||
context.Response.ContentType = "text/event-stream";
|
||||
context.Response.Headers["Cache-Control"] = "no-cache";
|
||||
|
||||
// End the connection if the client goes away
|
||||
context.RequestAborted.Register(state => OnConnectionAborted(state), this);
|
||||
_context = context;
|
||||
|
||||
// Set the initial TCS when everything is setup
|
||||
_initTcs.TrySetResult(null);
|
||||
|
||||
await _lifetime.Task;
|
||||
|
||||
_completed = true;
|
||||
}
|
||||
|
||||
private static void OnConnectionAborted(object state)
|
||||
public async void Abort()
|
||||
{
|
||||
((ServerSentEvents)state).OnConnectedAborted();
|
||||
}
|
||||
// Drain the queue so no new work can enter
|
||||
await _queue.Drain();
|
||||
|
||||
private void OnConnectedAborted()
|
||||
{
|
||||
Post(state =>
|
||||
{
|
||||
((TaskCompletionSource<object>)state).TrySetResult(null);
|
||||
return Task.CompletedTask;
|
||||
},
|
||||
_lifetime);
|
||||
// Complete the lifetime task
|
||||
_lifetime.TrySetResult(null);
|
||||
}
|
||||
|
||||
private async Task StartSending()
|
||||
|
|
@ -92,7 +66,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
private Task Send(ReadableBuffer value)
|
||||
{
|
||||
return Post(async state =>
|
||||
return _queue.Enqueue(state =>
|
||||
{
|
||||
var data = (ReadableBuffer)state;
|
||||
// TODO: Pooled buffers
|
||||
|
|
@ -109,7 +83,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
at += data.Length;
|
||||
buffer[at++] = (byte)'\n';
|
||||
buffer[at++] = (byte)'\n';
|
||||
await _context.Response.Body.WriteAsync(buffer, 0, at);
|
||||
return _context.Response.Body.WriteAsync(buffer, 0, at);
|
||||
},
|
||||
value);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
using System;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
// Allows serial queuing of Task instances
|
||||
// The tasks are not called on the current synchronization context
|
||||
|
||||
internal sealed class TaskQueue
|
||||
{
|
||||
private readonly object _lockObj = new object();
|
||||
private Task _lastQueuedTask;
|
||||
private volatile bool _drained;
|
||||
private long _size;
|
||||
|
||||
public TaskQueue()
|
||||
: this(Task.CompletedTask)
|
||||
{
|
||||
}
|
||||
|
||||
public TaskQueue(Task initialTask)
|
||||
{
|
||||
_lastQueuedTask = initialTask;
|
||||
}
|
||||
|
||||
public bool IsDrained
|
||||
{
|
||||
get
|
||||
{
|
||||
return _drained;
|
||||
}
|
||||
}
|
||||
|
||||
public Task Enqueue(Func<object, Task> taskFunc, object state)
|
||||
{
|
||||
// Lock the object for as short amount of time as possible
|
||||
lock (_lockObj)
|
||||
{
|
||||
if (_drained)
|
||||
{
|
||||
return _lastQueuedTask;
|
||||
}
|
||||
|
||||
var newTask = _lastQueuedTask.ContinueWith((t, s1) => taskFunc(s1), state).Unwrap();
|
||||
_lastQueuedTask = newTask;
|
||||
return newTask;
|
||||
}
|
||||
}
|
||||
|
||||
public Task Drain()
|
||||
{
|
||||
lock (_lockObj)
|
||||
{
|
||||
_drained = true;
|
||||
|
||||
return _lastQueuedTask;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -7,11 +7,11 @@ using Microsoft.AspNetCore.Http;
|
|||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class WebSockets
|
||||
public class WebSockets : IHttpTransport
|
||||
{
|
||||
private readonly HttpChannel _channel;
|
||||
private readonly TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
|
||||
private WebSocket _ws;
|
||||
private HttpChannel _channel;
|
||||
private TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
|
||||
|
||||
public WebSockets(HttpChannel channel)
|
||||
{
|
||||
|
|
@ -19,34 +19,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var ignore = StartSending();
|
||||
}
|
||||
|
||||
private async Task StartSending()
|
||||
{
|
||||
await _tcs.Task;
|
||||
|
||||
while (true)
|
||||
{
|
||||
var buffer = await _channel.Output.ReadAsync();
|
||||
|
||||
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
foreach (var memory in buffer)
|
||||
{
|
||||
ArraySegment<byte> data;
|
||||
if (memory.TryGetArray(out data))
|
||||
{
|
||||
await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
|
||||
}
|
||||
}
|
||||
|
||||
_channel.Output.Advance(buffer.End);
|
||||
}
|
||||
|
||||
_channel.Output.CompleteReader();
|
||||
}
|
||||
|
||||
public async Task ProcessRequest(HttpContext context)
|
||||
{
|
||||
if (!context.WebSockets.IsWebSocketRequest)
|
||||
|
|
@ -83,5 +55,40 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async void Abort()
|
||||
{
|
||||
await _tcs.Task;
|
||||
|
||||
await _ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
|
||||
}
|
||||
|
||||
private async Task StartSending()
|
||||
{
|
||||
await _tcs.Task;
|
||||
|
||||
while (true)
|
||||
{
|
||||
var buffer = await _channel.Output.ReadAsync();
|
||||
|
||||
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
foreach (var memory in buffer)
|
||||
{
|
||||
ArraySegment<byte> data;
|
||||
if (memory.TryGetArray(out data))
|
||||
{
|
||||
await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
|
||||
}
|
||||
}
|
||||
|
||||
_channel.Output.Advance(buffer.End);
|
||||
}
|
||||
|
||||
_channel.Output.CompleteReader();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue