Code cleanup (somewhat)

- Transports have be drastically simplified due to channels
- Make sure tasks don't leak
This commit is contained in:
David Fowler 2016-10-03 01:52:18 -07:00
parent e018fe70f7
commit 8e66d63577
9 changed files with 128 additions and 242 deletions

View File

@ -20,11 +20,17 @@ namespace Microsoft.AspNetCore.Sockets
public IDisposable Subscribe(string key, Func<Message, Task> observer)
{
var connections = _subscriptions.GetOrAdd(key, _ => new List<Func<Message, Task>>());
connections.Add(observer);
lock (connections)
{
connections.Add(observer);
}
return new DisposableAction(() =>
{
connections.Remove(observer);
lock (connections)
{
connections.Remove(observer);
}
});
}
@ -33,10 +39,17 @@ namespace Microsoft.AspNetCore.Sockets
List<Func<Message, Task>> connections;
if (_subscriptions.TryGetValue(key, out connections))
{
foreach (var c in connections)
Task[] tasks = null;
lock (connections)
{
await c(message);
tasks = new Task[connections.Count];
for (int i = 0; i < connections.Count; i++)
{
tasks[i] = connections[i](message);
}
}
await Task.WhenAll(tasks);
}
}

View File

@ -47,8 +47,6 @@ namespace SocketsSample
{
Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})")
});
connection.Channel.Input.Complete();
}
private async Task OnMessage(Message message, Connection connection)

View File

@ -54,16 +54,6 @@ namespace Microsoft.AspNetCore.Sockets
return state;
}
public void MarkConnectionInactive(string id)
{
ConnectionState state;
if (_connections.TryGetValue(id, out state))
{
// Mark the connection as active so the background thread can look at it
state.Active = false;
}
}
public void RemoveConnection(string id)
{
ConnectionState state;
@ -88,7 +78,7 @@ namespace Microsoft.AspNetCore.Sockets
// Scan the registered connections looking for ones that have timed out
foreach (var c in _connections)
{
if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeen).TotalSeconds > 30)
if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeen).TotalSeconds > 5)
{
ConnectionState s;
if (_connections.TryRemove(c.Key, out s))

View File

@ -56,7 +56,7 @@ namespace Microsoft.AspNetCore.Sockets
var sse = new ServerSentEvents(connectionState.Connection);
// Register this transport for disconnect
RegisterDisconnect(context, sse);
RegisterDisconnect(context, connectionState.Connection);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnected(connectionState.Connection);
@ -80,7 +80,7 @@ namespace Microsoft.AspNetCore.Sockets
var ws = new WebSockets(connectionState.Connection);
// Register this transport for disconnect
RegisterDisconnect(context, ws);
RegisterDisconnect(context, connectionState.Connection);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnected(connectionState.Connection);
@ -108,8 +108,6 @@ namespace Microsoft.AspNetCore.Sockets
if (connectionState.Connection.Channel == null)
{
// REVIEW: The connection manager should encapsulate this...
connectionState.Active = true;
connectionState.LastSeen = DateTimeOffset.UtcNow;
connectionState.Connection.Channel = new HttpChannel(_channelFactory);
isNewConnection = true;
}
@ -121,6 +119,9 @@ namespace Microsoft.AspNetCore.Sockets
isNewConnection = true;
}
// Mark the connection as active
connectionState.Active = true;
// Raise OnConnected for new connections only since polls happen all the time
if (isNewConnection)
{
@ -130,22 +131,30 @@ namespace Microsoft.AspNetCore.Sockets
var ignore = endpoint.OnConnected(connectionState.Connection);
}
var longPolling = new LongPolling(connectionState.Connection);
RegisterLongPollingDisconnect(context, connectionState.Connection);
// Register this transport for disconnect
RegisterDisconnect(context, longPolling);
var longPolling = new LongPolling(connectionState.Connection);
// Start the transport
await longPolling.ProcessRequest(context);
_manager.MarkConnectionInactive(connectionState.Connection.ConnectionId);
// Mark the connection as inactive
connectionState.LastSeen = DateTimeOffset.UtcNow;
connectionState.Active = false;
}
}
}
private static void RegisterDisconnect(HttpContext context, IHttpTransport transport)
private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)
{
context.RequestAborted.Register(state => ((IHttpTransport)state).CloseAsync(), transport);
// For long polling, we need to end the transport but not the overall connection so we write 0 bytes
context.RequestAborted.Register(state => ((HttpChannel)state).Output.WriteAsync(Span<byte>.Empty), connection.Channel);
}
private static void RegisterDisconnect(HttpContext context, Connection connection)
{
// We just kill the output writing as a signal to the transport that it is done
context.RequestAborted.Register(state => ((HttpChannel)state).Output.CompleteWriter(), connection.Channel);
}
private Task ProcessGetId(HttpContext context)

View File

@ -14,11 +14,5 @@ namespace Microsoft.AspNetCore.Sockets
/// <param name="context"></param>
/// <returns>A <see cref="Task"/> that completes when the transport has finished processing</returns>
Task ProcessRequest(HttpContext context);
/// <summary>
/// Completes the Task returned from ProcessRequest if not already complete
/// </summary>
/// <returns></returns>
Task CloseAsync();
}
}

View File

@ -7,88 +7,37 @@ namespace Microsoft.AspNetCore.Sockets
{
public class LongPolling : IHttpTransport
{
private readonly TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
private readonly TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
private readonly HttpChannel _channel;
private readonly Connection _connection;
private readonly TaskQueue _queue;
private HttpContext _context;
public LongPolling(Connection connection)
{
_queue = new TaskQueue(_initTcs.Task);
_connection = connection;
_channel = (HttpChannel)connection.Channel;
}
public async Task ProcessRequest(HttpContext context)
{
_context = context;
_initTcs.TrySetResult(null);
// Send queue messages to the connection
var ignore = ProcessMessages(context);
await _lifetime.Task;
}
private async Task ProcessMessages(HttpContext context)
{
var buffer = await _channel.Output.ReadAsync();
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
{
await CloseAsync();
// REVIEW: Set the status code here so the client doesn't reconnect
return;
}
try
if (!buffer.IsEmpty)
{
await Send(buffer);
try
{
context.Response.ContentLength = buffer.Length;
await buffer.CopyToAsync(context.Response.Body);
}
finally
{
_channel.Output.Advance(buffer.End);
}
}
finally
{
_channel.Output.Advance(buffer.End);
}
await EndRequest();
}
public async Task CloseAsync()
{
await _queue.Enqueue(state =>
{
var context = (HttpContext)state;
// REVIEW: What happens if header was already?
context.Response.Headers["X-ASPNET-SOCKET-DISCONNECT"] = "1";
return Task.CompletedTask;
},
_context);
await EndRequest();
}
private async Task EndRequest()
{
// Drain the queue and don't let any new work enter
await _queue.Drain();
// Complete the lifetime task
_lifetime.TrySetResult(null);
}
private Task Send(ReadableBuffer value)
{
// REVIEW: Can we avoid the closure here?
return _queue.Enqueue(state =>
{
var data = (ReadableBuffer)state;
_context.Response.ContentLength = data.Length;
return data.CopyToAsync(_context.Response.Body);
},
value);
}
}
}

View File

@ -7,20 +7,13 @@ namespace Microsoft.AspNetCore.Sockets
{
public class ServerSentEvents : IHttpTransport
{
private readonly TaskCompletionSource<object> _initTcs = new TaskCompletionSource<object>();
private readonly TaskCompletionSource<object> _lifetime = new TaskCompletionSource<object>();
private readonly HttpChannel _channel;
private readonly Connection _connection;
private readonly TaskQueue _queue;
private HttpContext _context;
public ServerSentEvents(Connection connection)
{
_queue = new TaskQueue(_initTcs.Task);
_connection = connection;
_channel = (HttpChannel)connection.Channel;
var ignore = StartSending();
}
public async Task ProcessRequest(HttpContext context)
@ -28,27 +21,6 @@ namespace Microsoft.AspNetCore.Sockets
context.Response.ContentType = "text/event-stream";
context.Response.Headers["Cache-Control"] = "no-cache";
_context = context;
// Set the initial TCS when everything is setup
_initTcs.TrySetResult(null);
await _lifetime.Task;
}
public async Task CloseAsync()
{
// Drain the queue so no new work can enter
await _queue.Drain();
// Complete the lifetime task
_lifetime.TrySetResult(null);
}
private async Task StartSending()
{
await _initTcs.Task;
while (true)
{
var buffer = await _channel.Output.ReadAsync();
@ -58,7 +30,7 @@ namespace Microsoft.AspNetCore.Sockets
break;
}
await Send(buffer);
await Send(context, buffer);
_channel.Output.Advance(buffer.End);
}
@ -66,29 +38,24 @@ namespace Microsoft.AspNetCore.Sockets
_channel.Output.CompleteReader();
}
private Task Send(ReadableBuffer value)
private async Task Send(HttpContext context, ReadableBuffer data)
{
return _queue.Enqueue(async state =>
{
var data = (ReadableBuffer)state;
// TODO: Pooled buffers
// 8 = 6(data: ) + 2 (\n\n)
var buffer = new byte[8 + data.Length];
var at = 0;
buffer[at++] = (byte)'d';
buffer[at++] = (byte)'a';
buffer[at++] = (byte)'t';
buffer[at++] = (byte)'a';
buffer[at++] = (byte)':';
buffer[at++] = (byte)' ';
data.CopyTo(new Span<byte>(buffer, at, data.Length));
at += data.Length;
buffer[at++] = (byte)'\n';
buffer[at++] = (byte)'\n';
await _context.Response.Body.WriteAsync(buffer, 0, at);
await _context.Response.Body.FlushAsync();
},
value);
// TODO: Pooled buffers
// 8 = 6(data: ) + 2 (\n\n)
var buffer = new byte[8 + data.Length];
var at = 0;
buffer[at++] = (byte)'d';
buffer[at++] = (byte)'a';
buffer[at++] = (byte)'t';
buffer[at++] = (byte)'a';
buffer[at++] = (byte)':';
buffer[at++] = (byte)' ';
data.CopyTo(new Span<byte>(buffer, at, data.Length));
at += data.Length;
buffer[at++] = (byte)'\n';
buffer[at++] = (byte)'\n';
await context.Response.Body.WriteAsync(buffer, 0, at);
await context.Response.Body.FlushAsync();
}
}
}

View File

@ -1,68 +0,0 @@
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
// Allows serial queuing of Task instances
// The tasks are not called on the current synchronization context
internal sealed class TaskQueue
{
private readonly object _lockObj = new object();
private Task _lastQueuedTask;
private volatile bool _drained;
private long _size;
public TaskQueue()
: this(Task.CompletedTask)
{
}
public TaskQueue(Task initialTask)
{
_lastQueuedTask = initialTask;
}
public bool IsDrained
{
get
{
return _drained;
}
}
public Task Enqueue(Func<object, Task> taskFunc, object state)
{
// Lock the object for as short amount of time as possible
lock (_lockObj)
{
if (_drained)
{
return _lastQueuedTask;
}
var newTask = _lastQueuedTask.ContinueWith((t, s1) =>
{
if (t.IsFaulted || t.IsCanceled)
{
return t;
}
return taskFunc(s1);
},
state).Unwrap();
_lastQueuedTask = newTask;
return newTask;
}
}
public Task Drain()
{
lock (_lockObj)
{
_drained = true;
return _lastQueuedTask;
}
}
}
}

View File

@ -11,14 +11,11 @@ namespace Microsoft.AspNetCore.Sockets
{
private readonly HttpChannel _channel;
private readonly Connection _connection;
private readonly TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
private WebSocket _ws;
public WebSockets(Connection connection)
{
_connection = connection;
_channel = (HttpChannel)connection.Channel;
var ignore = StartSending();
}
public async Task ProcessRequest(HttpContext context)
@ -31,9 +28,11 @@ namespace Microsoft.AspNetCore.Sockets
var ws = await context.WebSockets.AcceptWebSocketAsync();
_ws = ws;
_tcs.TrySetResult(null);
// REVIEW: Should we track this task? Leaving things like this alive usually causes memory leaks :)
// The reason we don't await this is because the channel is disposed after this loop returns
// and the sending loop is waiting for the channel to end before doing anything
// We could do a 2 stage shutdown but that could complicate the code...
var sending = StartSending(ws);
var outputBuffer = _channel.Input.Alloc();
@ -66,46 +65,81 @@ namespace Microsoft.AspNetCore.Sockets
}
else
{
await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
break;
}
}
await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
}
public async Task CloseAsync()
private async Task StartSending(WebSocket ws)
{
await _tcs.Task;
// REVIEW: Close output vs Close?
await _ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
}
private async Task StartSending()
{
await _tcs.Task;
while (true)
{
var buffer = await _channel.Output.ReadAsync();
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
try
{
if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted)
{
break;
}
foreach (var memory in buffer)
{
ArraySegment<byte> data;
if (memory.TryGetArray(out data))
{
if (IsClosedOrClosedSent(ws))
{
break;
}
await ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
}
}
}
catch (Exception)
{
// Error writing, probably closed
break;
}
foreach (var memory in buffer)
finally
{
ArraySegment<byte> data;
if (memory.TryGetArray(out data))
{
await _ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
}
_channel.Output.Advance(buffer.End);
}
_channel.Output.Advance(buffer.End);
}
_channel.Output.CompleteReader();
// REVIEW: Should this ever happen?
if (!IsClosedOrClosedSent(ws))
{
// Close the output
await ws.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
}
}
private static bool IsClosedOrClosedSent(WebSocket webSocket)
{
var webSocketState = GetWebSocketState(webSocket);
return webSocketState == WebSocketState.Closed ||
webSocketState == WebSocketState.CloseSent ||
webSocketState == WebSocketState.Aborted;
}
private static WebSocketState GetWebSocketState(WebSocket webSocket)
{
try
{
return webSocket.State;
}
catch (ObjectDisposedException)
{
return WebSocketState.Closed;
}
}
}
}