Handle misbehaving user code (#159)
* Handle misbehaving user code - Execute EndPoint logic on a threadpool thread - Turn synchronous exceptions into async ones to unify the error handling - Added tests
This commit is contained in:
parent
934f6a70d1
commit
162cd1fc06
|
|
@ -56,9 +56,6 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public override async Task OnConnectedAsync(Connection connection)
|
||||
{
|
||||
// TODO: Dispatch from the caller
|
||||
await Task.Yield();
|
||||
|
||||
try
|
||||
{
|
||||
await _lifetimeManager.OnConnectedAsync(connection);
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
state.Connection.Metadata["transport"] = LongPollingTransport.Name;
|
||||
|
||||
state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection);
|
||||
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
@ -268,7 +268,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
state.RequestId = context.TraceIdentifier;
|
||||
|
||||
// Call into the end point passing the connection
|
||||
state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection);
|
||||
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
|
||||
|
||||
// Start the transport
|
||||
state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted);
|
||||
|
|
@ -284,6 +284,16 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
await _manager.DisposeAndRemoveAsync(state);
|
||||
}
|
||||
|
||||
private async Task ExecuteApplication(EndPoint endpoint, Connection connection)
|
||||
{
|
||||
// Jump onto the thread pool thread so blocking user code doesn't block the setup of the
|
||||
// connection and transport
|
||||
await AwaitableThreadPool.Yield();
|
||||
|
||||
// Running this in an async method turns sync exceptions into async ones
|
||||
await endpoint.OnConnectedAsync(connection);
|
||||
}
|
||||
|
||||
private Task ProcessNegotiate(HttpContext context)
|
||||
{
|
||||
// Establish the connection
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Internal
|
||||
{
|
||||
public static class AwaitableThreadPool
|
||||
{
|
||||
public static Awaitable Yield()
|
||||
{
|
||||
return new Awaitable();
|
||||
}
|
||||
|
||||
public struct Awaitable : ICriticalNotifyCompletion
|
||||
{
|
||||
public void GetResult()
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
public Awaitable GetAwaiter() => this;
|
||||
|
||||
public bool IsCompleted => false;
|
||||
|
||||
public void OnCompleted(Action continuation)
|
||||
{
|
||||
Task.Run(continuation);
|
||||
}
|
||||
|
||||
public void UnsafeOnCompleted(Action continuation)
|
||||
{
|
||||
OnCompleted(continuation);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -116,6 +116,44 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
Assert.False(exists);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SynchronusExceptionEndsConnection()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<SynchronusExceptionEndPoint>("/sse", state);
|
||||
|
||||
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("", context);
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
|
||||
ConnectionState removed;
|
||||
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
|
||||
Assert.False(exists);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SynchronusExceptionEndsLongPollingConnection()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<SynchronusExceptionEndPoint>("/poll", state);
|
||||
|
||||
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("", context);
|
||||
|
||||
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
|
||||
|
||||
ConnectionState removed;
|
||||
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
|
||||
Assert.False(exists);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompletedEndPointEndsLongPollingConnection()
|
||||
{
|
||||
|
|
@ -226,6 +264,56 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task BlockingConnectionWorksWithStreamingConnections()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<BlockingEndPoint>("/sse", state);
|
||||
|
||||
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
|
||||
|
||||
var buffer = ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve();
|
||||
|
||||
// Write to the application
|
||||
await state.Application.Output.WriteAsync(new Message(buffer, Format.Text, endOfMessage: true));
|
||||
|
||||
await task;
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
ConnectionState removed;
|
||||
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
|
||||
Assert.False(exists);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task BlockingConnectionWorksWithLongPollingConnection()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<BlockingEndPoint>("/poll", state);
|
||||
|
||||
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
|
||||
|
||||
var buffer = ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve();
|
||||
|
||||
// Write to the application
|
||||
await state.Application.Output.WriteAsync(new Message(buffer, Format.Text, endOfMessage: true));
|
||||
|
||||
await task;
|
||||
|
||||
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
|
||||
ConnectionState removed;
|
||||
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
|
||||
Assert.False(exists);
|
||||
}
|
||||
|
||||
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state) where TEndPoint : EndPoint
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
|
|
@ -246,6 +334,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
}
|
||||
}
|
||||
|
||||
public class BlockingEndPoint : EndPoint
|
||||
{
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
{
|
||||
connection.Transport.Input.WaitToReadAsync().Wait();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
||||
public class SynchronusExceptionEndPoint : EndPoint
|
||||
{
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
{
|
||||
throw new InvalidOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
public class ImmediatelyCompleteEndPoint : EndPoint
|
||||
{
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
|
|
|
|||
Loading…
Reference in New Issue