From 162cd1fc0646624d9b9df583209ee27eb5bf1f85 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Wed, 25 Jan 2017 23:45:43 +0000 Subject: [PATCH] Handle misbehaving user code (#159) * Handle misbehaving user code - Execute EndPoint logic on a threadpool thread - Turn synchronous exceptions into async ones to unify the error handling - Added tests --- .../HubEndPoint.cs | 3 - .../HttpConnectionDispatcher.cs | 14 ++- .../Internal/AwaitableThreadPool.cs | 38 +++++++ .../HttpConnectionDispatcherTests.cs | 105 ++++++++++++++++++ 4 files changed, 155 insertions(+), 5 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Sockets/Internal/AwaitableThreadPool.cs diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index aeed94c55d..ae4a357290 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -56,9 +56,6 @@ namespace Microsoft.AspNetCore.SignalR public override async Task OnConnectedAsync(Connection connection) { - // TODO: Dispatch from the caller - await Task.Yield(); - try { await _lifetimeManager.OnConnectedAsync(connection); diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 163c4d2584..ea2da48e0d 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -157,7 +157,7 @@ namespace Microsoft.AspNetCore.Sockets state.Connection.Metadata["transport"] = LongPollingTransport.Name; - state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection); + state.ApplicationTask = ExecuteApplication(endpoint, state.Connection); } else { @@ -268,7 +268,7 @@ namespace Microsoft.AspNetCore.Sockets state.RequestId = context.TraceIdentifier; // Call into the end point passing the connection - state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection); + state.ApplicationTask = ExecuteApplication(endpoint, state.Connection); // Start the transport state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); @@ -284,6 +284,16 @@ namespace Microsoft.AspNetCore.Sockets await _manager.DisposeAndRemoveAsync(state); } + private async Task ExecuteApplication(EndPoint endpoint, Connection connection) + { + // Jump onto the thread pool thread so blocking user code doesn't block the setup of the + // connection and transport + await AwaitableThreadPool.Yield(); + + // Running this in an async method turns sync exceptions into async ones + await endpoint.OnConnectedAsync(connection); + } + private Task ProcessNegotiate(HttpContext context) { // Establish the connection diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/AwaitableThreadPool.cs b/src/Microsoft.AspNetCore.Sockets/Internal/AwaitableThreadPool.cs new file mode 100644 index 0000000000..9cae8fff28 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/Internal/AwaitableThreadPool.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets.Internal +{ + public static class AwaitableThreadPool + { + public static Awaitable Yield() + { + return new Awaitable(); + } + + public struct Awaitable : ICriticalNotifyCompletion + { + public void GetResult() + { + + } + + public Awaitable GetAwaiter() => this; + + public bool IsCompleted => false; + + public void OnCompleted(Action continuation) + { + Task.Run(continuation); + } + + public void UnsafeOnCompleted(Action continuation) + { + OnCompleted(continuation); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index a503ac053d..059e03d304 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -116,6 +116,44 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.False(exists); } + [Fact] + public async Task SynchronusExceptionEndsConnection() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/sse", state); + + await dispatcher.ExecuteAsync("", context); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + ConnectionState removed; + bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + Assert.False(exists); + } + + [Fact] + public async Task SynchronusExceptionEndsLongPollingConnection() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/poll", state); + + await dispatcher.ExecuteAsync("", context); + + Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); + + ConnectionState removed; + bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + Assert.False(exists); + } + [Fact] public async Task CompletedEndPointEndsLongPollingConnection() { @@ -226,6 +264,56 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); } + [Fact] + public async Task BlockingConnectionWorksWithStreamingConnections() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/sse", state); + + var task = dispatcher.ExecuteAsync("", context); + + var buffer = ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(); + + // Write to the application + await state.Application.Output.WriteAsync(new Message(buffer, Format.Text, endOfMessage: true)); + + await task; + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + ConnectionState removed; + bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + Assert.False(exists); + } + + [Fact] + public async Task BlockingConnectionWorksWithLongPollingConnection() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/poll", state); + + var task = dispatcher.ExecuteAsync("", context); + + var buffer = ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(); + + // Write to the application + await state.Application.Output.WriteAsync(new Message(buffer, Format.Text, endOfMessage: true)); + + await task; + + Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); + ConnectionState removed; + bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + Assert.False(exists); + } + private static DefaultHttpContext MakeRequest(string path, ConnectionState state) where TEndPoint : EndPoint { var context = new DefaultHttpContext(); @@ -246,6 +334,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } + public class BlockingEndPoint : EndPoint + { + public override Task OnConnectedAsync(Connection connection) + { + connection.Transport.Input.WaitToReadAsync().Wait(); + return Task.CompletedTask; + } + } + + public class SynchronusExceptionEndPoint : EndPoint + { + public override Task OnConnectedAsync(Connection connection) + { + throw new InvalidOperationException(); + } + } + public class ImmediatelyCompleteEndPoint : EndPoint { public override Task OnConnectedAsync(Connection connection)