Code cleanup (somewhat)
- Transports have be drastically simplified due to channels - Make sure tasks don't leak
This commit is contained in:
parent
e018fe70f7
commit
8e66d63577
|
|
@ -20,11 +20,17 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
public IDisposable Subscribe(string key, Func<Message, Task> observer)
|
||||
{
|
||||
var connections = _subscriptions.GetOrAdd(key, _ => new List<Func<Message, Task>>());
|
||||
connections.Add(observer);
|
||||
lock (connections)
|
||||
{
|
||||
connections.Add(observer);
|
||||
}
|
||||
|
||||
return new DisposableAction(() =>
|
||||
{
|
||||
connections.Remove(observer);
|
||||
lock (connections)
|
||||
{
|
||||
connections.Remove(observer);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -33,10 +39,17 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
List<Func<Message, Task>> connections;
|
||||
if (_subscriptions.TryGetValue(key, out connections))
|
||||
{
|
||||
foreach (var c in connections)
|
||||
Task[] tasks = null;
|
||||
lock (connections)
|
||||
{
|
||||
await c(message);
|
||||
tasks = new Task[connections.Count];
|
||||
for (int i = 0; i < connections.Count; i++)
|
||||
{
|
||||
tasks[i] = connections[i](message);
|
||||
}
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -47,8 +47,6 @@ namespace SocketsSample
|
|||
{
|
||||
Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})")
|
||||
});
|
||||
|
||||
connection.Channel.Input.Complete();
|
||||
}
|
||||
|
||||
private async Task OnMessage(Message message, Connection connection)
|
||||
|
|
|
|||
|
|
@ -54,16 +54,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
return state;
|
||||
}
|
||||
|
||||
public void MarkConnectionInactive(string id)
|
||||
{
|
||||
ConnectionState state;
|
||||
if (_connections.TryGetValue(id, out state))
|
||||
{
|
||||
// Mark the connection as active so the background thread can look at it
|
||||
state.Active = false;
|
||||
}
|
||||
}
|
||||
|
||||
public void RemoveConnection(string id)
|
||||
{
|
||||
ConnectionState state;
|
||||
|
|
@ -88,7 +78,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
// Scan the registered connections looking for ones that have timed out
|
||||
foreach (var c in _connections)
|
||||
{
|
||||
if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeen).TotalSeconds > 30)
|
||||
if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeen).TotalSeconds > 5)
|
||||
{
|
||||
ConnectionState s;
|
||||
if (_connections.TryRemove(c.Key, out s))
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var sse = new ServerSentEvents(connectionState.Connection);
|
||||
|
||||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, sse);
|
||||
RegisterDisconnect(context, connectionState.Connection);
|
||||
|
||||
// Call into the end point passing the connection
|
||||
var endpointTask = endpoint.OnConnected(connectionState.Connection);
|
||||
|
|
@ -80,7 +80,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var ws = new WebSockets(connectionState.Connection);
|
||||
|
||||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, ws);
|
||||
RegisterDisconnect(context, connectionState.Connection);
|
||||
|
||||
// Call into the end point passing the connection
|
||||
var endpointTask = endpoint.OnConnected(connectionState.Connection);
|
||||
|
|
@ -108,8 +108,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
if (connectionState.Connection.Channel == null)
|
||||
{
|
||||
// REVIEW: The connection manager should encapsulate this...
|
||||
connectionState.Active = true;
|
||||
connectionState.LastSeen = DateTimeOffset.UtcNow;
|
||||
connectionState.Connection.Channel = new HttpChannel(_channelFactory);
|
||||
isNewConnection = true;
|
||||
}
|
||||
|
|
@ -121,6 +119,9 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
isNewConnection = true;
|
||||
}
|
||||
|
||||
// Mark the connection as active
|
||||
connectionState.Active = true;
|
||||
|
||||
// Raise OnConnected for new connections only since polls happen all the time
|
||||
if (isNewConnection)
|
||||
{
|
||||
|
|
@ -130,22 +131,30 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var ignore = endpoint.OnConnected(connectionState.Connection);
|
||||
}
|
||||
|
||||
var longPolling = new LongPolling(connectionState.Connection);
|
||||
RegisterLongPollingDisconnect(context, connectionState.Connection);
|
||||
|
||||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, longPolling);
|
||||
var longPolling = new LongPolling(connectionState.Connection);
|
||||
|
||||
// Start the transport
|
||||
await longPolling.ProcessRequest(context);
|
||||
|
||||
_manager.MarkConnectionInactive(connectionState.Connection.ConnectionId);
|
||||
// Mark the connection as inactive
|
||||
connectionState.LastSeen = DateTimeOffset.UtcNow;
|
||||
connectionState.Active = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void RegisterDisconnect(HttpContext context, IHttpTransport transport)
|
||||
private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)
|
||||
{
|
||||
context.RequestAborted.Register(state => ((IHttpTransport)state).CloseAsync(), transport);
|
||||
// For long polling, we need to end the transport but not the overall connection so we write 0 bytes
|
||||
context.RequestAborted.Register(state => ((HttpChannel)state).Output.WriteAsync(Span<byte>.Empty), connection.Channel);
|
||||
}
|
||||
|
||||
private static void RegisterDisconnect(HttpContext context, Connection connection)
|
||||
{
|
||||
// We just kill the output writing as a signal to the transport that it is done
|
||||
context.RequestAborted.Register(state => ((HttpChannel)state).Output.CompleteWriter(), connection.Channel);
|
||||
}
|
||||
|
||||
private Task ProcessGetId(HttpContext context)
|
||||
|
|
|
|||
|
|
@ -14,11 +14,5 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
/// <param name="context"></param>
|
||||
/// <returns>A <see cref="Task"/> that completes when the transport has finished processing</returns>
|
||||
Task ProcessRequest(HttpContext context);
|
||||
|
||||
/// <summary>
|
||||
/// Completes the Task returned from ProcessRequest if not already complete
|
||||
/// </summary>
|
||||
/// <returns></returns>
|
||||
Task CloseAsync();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,88 +7,37 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
public class LongPolling : IHttpTransport
|
||||
{
|
||||
private readonly TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
|
||||
private readonly TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
|
||||
private readonly HttpChannel _channel;
|
||||
private readonly Connection _connection;
|
||||
private readonly TaskQueue _queue;
|
||||
|
||||
private HttpContext _context;
|
||||
|
||||
public LongPolling(Connection connection)
|
||||
{
|
||||
_queue = new TaskQueue(_initTcs.Task);
|
||||
_connection = connection;
|
||||
_channel = (HttpChannel)connection.Channel;
|
||||
}
|
||||
|
||||
public async Task ProcessRequest(HttpContext context)
|
||||
{
|
||||
_context = context;
|
||||
|
||||
_initTcs.TrySetResult(null);
|
||||
|
||||
// Send queue messages to the connection
|
||||
var ignore = ProcessMessages(context);
|
||||
|
||||
await _lifetime.Task;
|
||||
}
|
||||
|
||||
private async Task ProcessMessages(HttpContext context)
|
||||
{
|
||||
var buffer = await _channel.Output.ReadAsync();
|
||||
|
||||
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
|
||||
{
|
||||
await CloseAsync();
|
||||
// REVIEW: Set the status code here so the client doesn't reconnect
|
||||
return;
|
||||
}
|
||||
|
||||
try
|
||||
if (!buffer.IsEmpty)
|
||||
{
|
||||
await Send(buffer);
|
||||
try
|
||||
{
|
||||
context.Response.ContentLength = buffer.Length;
|
||||
await buffer.CopyToAsync(context.Response.Body);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_channel.Output.Advance(buffer.End);
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
_channel.Output.Advance(buffer.End);
|
||||
}
|
||||
|
||||
await EndRequest();
|
||||
}
|
||||
|
||||
public async Task CloseAsync()
|
||||
{
|
||||
await _queue.Enqueue(state =>
|
||||
{
|
||||
var context = (HttpContext)state;
|
||||
// REVIEW: What happens if header was already?
|
||||
context.Response.Headers["X-ASPNET-SOCKET-DISCONNECT"] = "1";
|
||||
return Task.CompletedTask;
|
||||
},
|
||||
_context);
|
||||
|
||||
await EndRequest();
|
||||
}
|
||||
|
||||
private async Task EndRequest()
|
||||
{
|
||||
// Drain the queue and don't let any new work enter
|
||||
await _queue.Drain();
|
||||
|
||||
// Complete the lifetime task
|
||||
_lifetime.TrySetResult(null);
|
||||
}
|
||||
|
||||
private Task Send(ReadableBuffer value)
|
||||
{
|
||||
// REVIEW: Can we avoid the closure here?
|
||||
return _queue.Enqueue(state =>
|
||||
{
|
||||
var data = (ReadableBuffer)state;
|
||||
_context.Response.ContentLength = data.Length;
|
||||
return data.CopyToAsync(_context.Response.Body);
|
||||
},
|
||||
value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,20 +7,13 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
public class ServerSentEvents : IHttpTransport
|
||||
{
|
||||
private readonly TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
|
||||
private readonly TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
|
||||
private readonly HttpChannel _channel;
|
||||
private readonly Connection _connection;
|
||||
private readonly TaskQueue _queue;
|
||||
|
||||
private HttpContext _context;
|
||||
|
||||
public ServerSentEvents(Connection connection)
|
||||
{
|
||||
_queue = new TaskQueue(_initTcs.Task);
|
||||
_connection = connection;
|
||||
_channel = (HttpChannel)connection.Channel;
|
||||
var ignore = StartSending();
|
||||
}
|
||||
|
||||
public async Task ProcessRequest(HttpContext context)
|
||||
|
|
@ -28,27 +21,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
context.Response.ContentType = "text/event-stream";
|
||||
context.Response.Headers["Cache-Control"] = "no-cache";
|
||||
|
||||
_context = context;
|
||||
|
||||
// Set the initial TCS when everything is setup
|
||||
_initTcs.TrySetResult(null);
|
||||
|
||||
await _lifetime.Task;
|
||||
}
|
||||
|
||||
public async Task CloseAsync()
|
||||
{
|
||||
// Drain the queue so no new work can enter
|
||||
await _queue.Drain();
|
||||
|
||||
// Complete the lifetime task
|
||||
_lifetime.TrySetResult(null);
|
||||
}
|
||||
|
||||
private async Task StartSending()
|
||||
{
|
||||
await _initTcs.Task;
|
||||
|
||||
while (true)
|
||||
{
|
||||
var buffer = await _channel.Output.ReadAsync();
|
||||
|
|
@ -58,7 +30,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
break;
|
||||
}
|
||||
|
||||
await Send(buffer);
|
||||
await Send(context, buffer);
|
||||
|
||||
_channel.Output.Advance(buffer.End);
|
||||
}
|
||||
|
|
@ -66,29 +38,24 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
_channel.Output.CompleteReader();
|
||||
}
|
||||
|
||||
private Task Send(ReadableBuffer value)
|
||||
private async Task Send(HttpContext context, ReadableBuffer data)
|
||||
{
|
||||
return _queue.Enqueue(async state =>
|
||||
{
|
||||
var data = (ReadableBuffer)state;
|
||||
// TODO: Pooled buffers
|
||||
// 8 = 6(data: ) + 2 (\n\n)
|
||||
var buffer = new byte[8 + data.Length];
|
||||
var at = 0;
|
||||
buffer[at++] = (byte)'d';
|
||||
buffer[at++] = (byte)'a';
|
||||
buffer[at++] = (byte)'t';
|
||||
buffer[at++] = (byte)'a';
|
||||
buffer[at++] = (byte)':';
|
||||
buffer[at++] = (byte)' ';
|
||||
data.CopyTo(new Span<byte>(buffer, at, data.Length));
|
||||
at += data.Length;
|
||||
buffer[at++] = (byte)'\n';
|
||||
buffer[at++] = (byte)'\n';
|
||||
await _context.Response.Body.WriteAsync(buffer, 0, at);
|
||||
await _context.Response.Body.FlushAsync();
|
||||
},
|
||||
value);
|
||||
// TODO: Pooled buffers
|
||||
// 8 = 6(data: ) + 2 (\n\n)
|
||||
var buffer = new byte[8 + data.Length];
|
||||
var at = 0;
|
||||
buffer[at++] = (byte)'d';
|
||||
buffer[at++] = (byte)'a';
|
||||
buffer[at++] = (byte)'t';
|
||||
buffer[at++] = (byte)'a';
|
||||
buffer[at++] = (byte)':';
|
||||
buffer[at++] = (byte)' ';
|
||||
data.CopyTo(new Span<byte>(buffer, at, data.Length));
|
||||
at += data.Length;
|
||||
buffer[at++] = (byte)'\n';
|
||||
buffer[at++] = (byte)'\n';
|
||||
await context.Response.Body.WriteAsync(buffer, 0, at);
|
||||
await context.Response.Body.FlushAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,68 +0,0 @@
|
|||
using System;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
// Allows serial queuing of Task instances
|
||||
// The tasks are not called on the current synchronization context
|
||||
|
||||
internal sealed class TaskQueue
|
||||
{
|
||||
private readonly object _lockObj = new object();
|
||||
private Task _lastQueuedTask;
|
||||
private volatile bool _drained;
|
||||
private long _size;
|
||||
|
||||
public TaskQueue()
|
||||
: this(Task.CompletedTask)
|
||||
{
|
||||
}
|
||||
|
||||
public TaskQueue(Task initialTask)
|
||||
{
|
||||
_lastQueuedTask = initialTask;
|
||||
}
|
||||
|
||||
public bool IsDrained
|
||||
{
|
||||
get
|
||||
{
|
||||
return _drained;
|
||||
}
|
||||
}
|
||||
|
||||
public Task Enqueue(Func<object, Task> taskFunc, object state)
|
||||
{
|
||||
// Lock the object for as short amount of time as possible
|
||||
lock (_lockObj)
|
||||
{
|
||||
if (_drained)
|
||||
{
|
||||
return _lastQueuedTask;
|
||||
}
|
||||
|
||||
var newTask = _lastQueuedTask.ContinueWith((t, s1) =>
|
||||
{
|
||||
if (t.IsFaulted || t.IsCanceled)
|
||||
{
|
||||
return t;
|
||||
}
|
||||
return taskFunc(s1);
|
||||
},
|
||||
state).Unwrap();
|
||||
_lastQueuedTask = newTask;
|
||||
return newTask;
|
||||
}
|
||||
}
|
||||
|
||||
public Task Drain()
|
||||
{
|
||||
lock (_lockObj)
|
||||
{
|
||||
_drained = true;
|
||||
|
||||
return _lastQueuedTask;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -11,14 +11,11 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
private readonly HttpChannel _channel;
|
||||
private readonly Connection _connection;
|
||||
private readonly TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
|
||||
private WebSocket _ws;
|
||||
|
||||
public WebSockets(Connection connection)
|
||||
{
|
||||
_connection = connection;
|
||||
_channel = (HttpChannel)connection.Channel;
|
||||
var ignore = StartSending();
|
||||
}
|
||||
|
||||
public async Task ProcessRequest(HttpContext context)
|
||||
|
|
@ -31,9 +28,11 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
var ws = await context.WebSockets.AcceptWebSocketAsync();
|
||||
|
||||
_ws = ws;
|
||||
|
||||
_tcs.TrySetResult(null);
|
||||
// REVIEW: Should we track this task? Leaving things like this alive usually causes memory leaks :)
|
||||
// The reason we don't await this is because the channel is disposed after this loop returns
|
||||
// and the sending loop is waiting for the channel to end before doing anything
|
||||
// We could do a 2 stage shutdown but that could complicate the code...
|
||||
var sending = StartSending(ws);
|
||||
|
||||
var outputBuffer = _channel.Input.Alloc();
|
||||
|
||||
|
|
@ -66,46 +65,81 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
else
|
||||
{
|
||||
await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
|
||||
}
|
||||
|
||||
public async Task CloseAsync()
|
||||
private async Task StartSending(WebSocket ws)
|
||||
{
|
||||
await _tcs.Task;
|
||||
|
||||
// REVIEW: Close output vs Close?
|
||||
await _ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
|
||||
}
|
||||
|
||||
private async Task StartSending()
|
||||
{
|
||||
await _tcs.Task;
|
||||
|
||||
while (true)
|
||||
{
|
||||
var buffer = await _channel.Output.ReadAsync();
|
||||
|
||||
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
|
||||
try
|
||||
{
|
||||
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
foreach (var memory in buffer)
|
||||
{
|
||||
ArraySegment<byte> data;
|
||||
if (memory.TryGetArray(out data))
|
||||
{
|
||||
if (IsClosedOrClosedSent(ws))
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
await ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
catch (Exception)
|
||||
{
|
||||
// Error writing, probably closed
|
||||
break;
|
||||
}
|
||||
|
||||
foreach (var memory in buffer)
|
||||
finally
|
||||
{
|
||||
ArraySegment<byte> data;
|
||||
if (memory.TryGetArray(out data))
|
||||
{
|
||||
await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
|
||||
}
|
||||
_channel.Output.Advance(buffer.End);
|
||||
}
|
||||
|
||||
_channel.Output.Advance(buffer.End);
|
||||
}
|
||||
|
||||
_channel.Output.CompleteReader();
|
||||
|
||||
// REVIEW: Should this ever happen?
|
||||
if (!IsClosedOrClosedSent(ws))
|
||||
{
|
||||
// Close the output
|
||||
await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
|
||||
}
|
||||
}
|
||||
|
||||
private static bool IsClosedOrClosedSent(WebSocket webSocket)
|
||||
{
|
||||
var webSocketState = GetWebSocketState(webSocket);
|
||||
|
||||
return webSocketState == WebSocketState.Closed ||
|
||||
webSocketState == WebSocketState.CloseSent ||
|
||||
webSocketState == WebSocketState.Aborted;
|
||||
}
|
||||
|
||||
private static WebSocketState GetWebSocketState(WebSocket webSocket)
|
||||
{
|
||||
try
|
||||
{
|
||||
return webSocket.State;
|
||||
}
|
||||
catch (ObjectDisposedException)
|
||||
{
|
||||
return WebSocketState.Closed;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue