Handle misbehaving user code (#159)

* Handle misbehaving user code
- Execute EndPoint logic on a threadpool thread
- Turn synchronous exceptions into async ones to unify the
error handling
- Added tests
This commit is contained in:
David Fowler 2017-01-25 23:45:43 +00:00 committed by GitHub
parent 934f6a70d1
commit 162cd1fc06
4 changed files with 155 additions and 5 deletions

View File

@ -56,9 +56,6 @@ namespace Microsoft.AspNetCore.SignalR
public override async Task OnConnectedAsync(Connection connection)
{
// TODO: Dispatch from the caller
await Task.Yield();
try
{
await _lifetimeManager.OnConnectedAsync(connection);

View File

@ -157,7 +157,7 @@ namespace Microsoft.AspNetCore.Sockets
state.Connection.Metadata["transport"] = LongPollingTransport.Name;
state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection);
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
}
else
{
@ -268,7 +268,7 @@ namespace Microsoft.AspNetCore.Sockets
state.RequestId = context.TraceIdentifier;
// Call into the end point passing the connection
state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection);
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
// Start the transport
state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted);
@ -284,6 +284,16 @@ namespace Microsoft.AspNetCore.Sockets
await _manager.DisposeAndRemoveAsync(state);
}
private async Task ExecuteApplication(EndPoint endpoint, Connection connection)
{
// Jump onto the thread pool thread so blocking user code doesn't block the setup of the
// connection and transport
await AwaitableThreadPool.Yield();
// Running this in an async method turns sync exceptions into async ones
await endpoint.OnConnectedAsync(connection);
}
private Task ProcessNegotiate(HttpContext context)
{
// Establish the connection

View File

@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public static class AwaitableThreadPool
{
public static Awaitable Yield()
{
return new Awaitable();
}
public struct Awaitable : ICriticalNotifyCompletion
{
public void GetResult()
{
}
public Awaitable GetAwaiter() => this;
public bool IsCompleted => false;
public void OnCompleted(Action continuation)
{
Task.Run(continuation);
}
public void UnsafeOnCompleted(Action continuation)
{
OnCompleted(continuation);
}
}
}
}

View File

@ -116,6 +116,44 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.False(exists);
}
[Fact]
public async Task SynchronusExceptionEndsConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<SynchronusExceptionEndPoint>("/sse", state);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("", context);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
Assert.False(exists);
}
[Fact]
public async Task SynchronusExceptionEndsLongPollingConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<SynchronusExceptionEndPoint>("/poll", state);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("", context);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
Assert.False(exists);
}
[Fact]
public async Task CompletedEndPointEndsLongPollingConnection()
{
@ -226,6 +264,56 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
[Fact]
public async Task BlockingConnectionWorksWithStreamingConnections()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<BlockingEndPoint>("/sse", state);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
var buffer = ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve();
// Write to the application
await state.Application.Output.WriteAsync(new Message(buffer, Format.Text, endOfMessage: true));
await task;
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
Assert.False(exists);
}
[Fact]
public async Task BlockingConnectionWorksWithLongPollingConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<BlockingEndPoint>("/poll", state);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
var buffer = ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve();
// Write to the application
await state.Application.Output.WriteAsync(new Message(buffer, Format.Text, endOfMessage: true));
await task;
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
Assert.False(exists);
}
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state) where TEndPoint : EndPoint
{
var context = new DefaultHttpContext();
@ -246,6 +334,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
public class BlockingEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)
{
connection.Transport.Input.WaitToReadAsync().Wait();
return Task.CompletedTask;
}
}
public class SynchronusExceptionEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)
{
throw new InvalidOperationException();
}
}
public class ImmediatelyCompleteEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)