Various fixes in HttpConnectionDispatcher (#151)

- The connection state object is manipulated by multiple parties in a non thread safe way. This change introduces a semaphore that should be used by anyone updating or reading the connection state. 
- Handle cases where there's an active request for a connection id and another incoming request for the same connection id, sse and websockets 409 and long polling kicks out the previous connection (https://github.com/aspnet/SignalR/issues/27 and https://github.com/aspnet/SignalR/issues/4)
- Handle requests being processed for disposed connections. There was a race where the background thread could remove and clean up the connection while it was about to be processed.
- Synchronize between the background scanning thread and the request threads when updating the connection state.
- Added `DisposeAndRemoveAsync` to the connection manager that handles`DisposeAsync` throwing and properly removes connections from connection tracking.
- Added Start to ConnectionManager so that testing is easier (background timer doesn't kick in unless start is called).
- Added RequestId to connection state for easier debugging and correlation (can easily see which request is currently processing the logical connection).
- Added tests
This commit is contained in:
David Fowler 2017-01-25 22:27:55 +00:00 committed by GitHub
parent acd1dc5e24
commit 934f6a70d1
12 changed files with 462 additions and 102 deletions

View File

@ -8,17 +8,27 @@ using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionManager
{
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new ConcurrentDictionary<string, ConnectionState>();
private readonly Timer _timer;
private Timer _timer;
private readonly ILogger<ConnectionManager> _logger;
public ConnectionManager()
public ConnectionManager(ILogger<ConnectionManager> logger)
{
_timer = new Timer(Scan, this, 0, 1000);
_logger = logger;
}
public void Start()
{
if (_timer == null)
{
_timer = new Timer(Scan, this, TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1));
}
}
public bool TryGetConnection(string id, out ConnectionState state)
@ -47,9 +57,11 @@ namespace Microsoft.AspNetCore.Sockets
public void RemoveConnection(string id)
{
ConnectionState state;
_connections.TryRemove(id, out state);
// Remove the connection completely
if (_connections.TryRemove(id, out state))
{
// Remove the connection completely
_logger.LogDebug("Removing {connectionId} from the list of connections", id);
}
}
private static string MakeNewConnectionId()
@ -65,38 +77,76 @@ namespace Microsoft.AspNetCore.Sockets
private void Scan()
{
// Scan the registered connections looking for ones that have timed out
foreach (var c in _connections)
// Pause the timer while we're running
_timer.Change(Timeout.Infinite, Timeout.Infinite);
try
{
if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeenUtc).TotalSeconds > 5)
// Scan the registered connections looking for ones that have timed out
foreach (var c in _connections)
{
ConnectionState s;
if (_connections.TryRemove(c.Key, out s))
var status = ConnectionState.ConnectionStatus.Inactive;
var lastSeenUtc = DateTimeOffset.UtcNow;
try
{
// REVIEW: Should we keep firing and forgetting this?
var ignore = s.DisposeAsync();
c.Value.Lock.Wait();
// Capture the connection state
status = c.Value.Status;
lastSeenUtc = c.Value.LastSeenUtc;
}
finally
{
c.Value.Lock.Release();
}
// Once the decision has been made to to dispose we don't check the status again
if (status == ConnectionState.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5)
{
var ignore = DisposeAndRemoveAsync(c.Value);
}
}
}
finally
{
// Resume once we finished processing all connections
_timer.Change(TimeSpan.FromSeconds(1), TimeSpan.FromSeconds(1));
}
}
public void CloseConnections()
{
// Stop firing the timer
_timer.Dispose();
_timer?.Dispose();
var tasks = new List<Task>();
foreach (var c in _connections)
{
ConnectionState s;
if (_connections.TryRemove(c.Key, out s))
{
tasks.Add(s.DisposeAsync());
}
tasks.Add(DisposeAndRemoveAsync(c.Value));
}
Task.WaitAll(tasks.ToArray(), TimeSpan.FromSeconds(5));
}
public async Task DisposeAndRemoveAsync(ConnectionState state)
{
try
{
await state.DisposeAsync();
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed disposing connection {connectionId}", state.Connection.ConnectionId);
}
finally
{
// Remove it from the list after disposal so that's it's easy to see
// connections that might be in a hung state via the connections list
RemoveConnection(state.Connection.ConnectionId);
}
}
}
}

View File

@ -5,6 +5,7 @@ using System;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Internal;
@ -70,8 +71,6 @@ namespace Microsoft.AspNetCore.Sockets
var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory);
await DoPersistentConnection(endpoint, sse, context, state);
_manager.RemoveConnection(state.Connection.ConnectionId);
}
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
{
@ -92,8 +91,6 @@ namespace Microsoft.AspNetCore.Sockets
var ws = new WebSocketsTransport(state.Application, _loggerFactory);
await DoPersistentConnection(endpoint, ws, context, state);
_manager.RemoveConnection(state.Connection.ConnectionId);
}
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
{
@ -111,39 +108,112 @@ namespace Microsoft.AspNetCore.Sockets
return;
}
// Mark the connection as active
state.Active = true;
// Raise OnConnected for new connections only since polls happen all the time
if (state.ApplicationTask == null)
try
{
_logger.LogDebug("Establishing new Long Polling connection: {0}", state.Connection.ConnectionId);
await state.Lock.WaitAsync();
// This will re-initialize formatType metadata, but meh...
state.Connection.Metadata["transport"] = LongPollingTransport.Name;
if (state.Status == ConnectionState.ConnectionStatus.Disposed)
{
_logger.LogDebug("Connection {connectionId} was disposed,", state.Connection.ConnectionId);
state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection);
// The connection was disposed
context.Response.StatusCode = StatusCodes.Status404NotFound;
return;
}
if (state.Status == ConnectionState.ConnectionStatus.Active)
{
_logger.LogDebug("Connection {connectionId} is already active via {requestId}. Cancelling previous request.", state.Connection.ConnectionId, state.RequestId);
using (state.Cancellation)
{
// Cancel the previous request
state.Cancellation.Cancel();
try
{
// Wait for the previous request to drain
await state.TransportTask;
}
catch (OperationCanceledException)
{
// Should be a cancelled task
}
_logger.LogDebug("Previous poll cancelled for {connectionId} on {requestId}.", state.Connection.ConnectionId, state.RequestId);
}
}
// Mark the request identifier
state.RequestId = context.TraceIdentifier;
// Mark the connection as active
state.Status = ConnectionState.ConnectionStatus.Active;
// Raise OnConnected for new connections only since polls happen all the time
if (state.ApplicationTask == null)
{
_logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
state.Connection.Metadata["transport"] = LongPollingTransport.Name;
state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection);
}
else
{
_logger.LogDebug("Resuming existing connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
}
var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory);
state.Cancellation = new CancellationTokenSource();
// REVIEW: Performance of this isn't great as this does a bunch of per request allocations
var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(state.Cancellation.Token, context.RequestAborted);
// Start the transport
state.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token);
}
else
finally
{
_logger.LogDebug("Resuming existing Long Polling connection: {0}", state.Connection.ConnectionId);
state.Lock.Release();
}
var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory);
// Start the transport
state.TransportTask = longPolling.ProcessRequestAsync(context);
var resultTask = await Task.WhenAny(state.ApplicationTask, state.TransportTask);
// If the application ended before the transport task then we need to end the connection completely
// so there is no future polling
if (resultTask == state.ApplicationTask)
{
await state.DisposeAsync();
await _manager.DisposeAndRemoveAsync(state);
}
else if (!resultTask.IsCanceled)
{
// Otherwise, we update the state to inactive again and wait for the next poll
try
{
await state.Lock.WaitAsync();
// Mark the connection as inactive
state.LastSeenUtc = DateTime.UtcNow;
state.Active = false;
if (state.Status == ConnectionState.ConnectionStatus.Active)
{
// Mark the connection as inactive
state.LastSeenUtc = DateTime.UtcNow;
state.Status = ConnectionState.ConnectionStatus.Inactive;
state.RequestId = null;
// Dispose the cancellation token
state.Cancellation.Dispose();
state.Cancellation = null;
}
}
finally
{
state.Lock.Release();
}
}
}
}
@ -163,22 +233,55 @@ namespace Microsoft.AspNetCore.Sockets
return state;
}
private static async Task DoPersistentConnection(EndPoint endpoint,
IHttpTransport transport,
HttpContext context,
ConnectionState state)
private async Task DoPersistentConnection(EndPoint endpoint,
IHttpTransport transport,
HttpContext context,
ConnectionState state)
{
// Call into the end point passing the connection
state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection);
try
{
await state.Lock.WaitAsync();
// Start the transport
state.TransportTask = transport.ProcessRequestAsync(context);
if (state.Status == ConnectionState.ConnectionStatus.Disposed)
{
_logger.LogDebug("Connection {connectionId} was disposed,", state.Connection.ConnectionId);
// Connection was disposed
context.Response.StatusCode = StatusCodes.Status404NotFound;
return;
}
// There's already an active request
if (state.Status == ConnectionState.ConnectionStatus.Active)
{
_logger.LogDebug("Connection {connectionId} is already active via {requestId}.", state.Connection.ConnectionId, state.RequestId);
// Reject the request with a 409 conflict
context.Response.StatusCode = StatusCodes.Status409Conflict;
return;
}
// Mark the connection as active
state.Status = ConnectionState.ConnectionStatus.Active;
// Store the request identifier
state.RequestId = context.TraceIdentifier;
// Call into the end point passing the connection
state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection);
// Start the transport
state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted);
}
finally
{
state.Lock.Release();
}
// Wait for any of them to end
await Task.WhenAny(state.ApplicationTask, state.TransportTask);
// Kill the channel
await state.DisposeAsync();
await _manager.DisposeAndRemoveAsync(state);
}
private Task ProcessNegotiate(HttpContext context)
@ -243,7 +346,7 @@ namespace Microsoft.AspNetCore.Sockets
}
else if (!string.Equals(transport, transportName, StringComparison.Ordinal))
{
context.Response.StatusCode = 400;
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.WriteAsync("Cannot change transports mid-connection");
return false;
}

View File

@ -2,7 +2,9 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.Sockets.Internal
{
@ -11,11 +13,17 @@ namespace Microsoft.AspNetCore.Sockets.Internal
public Connection Connection { get; set; }
public IChannelConnection<Message> Application { get; }
public CancellationTokenSource Cancellation { get; set; }
public SemaphoreSlim Lock { get; } = new SemaphoreSlim(1, 1);
public string RequestId { get; set; }
public Task TransportTask { get; set; }
public Task ApplicationTask { get; set; }
public DateTime LastSeenUtc { get; set; }
public bool Active { get; set; } = true;
public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive;
public ConnectionState(Connection connection, IChannelConnection<Message> application)
{
@ -26,23 +34,54 @@ namespace Microsoft.AspNetCore.Sockets.Internal
public async Task DisposeAsync()
{
// If the application task is faulted, propagate the error to the transport
if (ApplicationTask.IsFaulted)
{
Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException);
}
Task applicationTask = TaskCache.CompletedTask;
Task transportTask = TaskCache.CompletedTask;
// If the transport task is faulted, propagate the error to the application
if (TransportTask.IsFaulted)
try
{
Application.Output.TryComplete(TransportTask.Exception.InnerException);
}
await Lock.WaitAsync();
Connection.Dispose();
Application.Dispose();
if (Status == ConnectionStatus.Disposed)
{
return;
}
Status = ConnectionStatus.Disposed;
RequestId = null;
// If the application task is faulted, propagate the error to the transport
if (ApplicationTask.IsFaulted)
{
Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException);
}
// If the transport task is faulted, propagate the error to the application
if (TransportTask.IsFaulted)
{
Application.Output.TryComplete(TransportTask.Exception.InnerException);
}
Connection.Dispose();
Application.Dispose();
applicationTask = ApplicationTask;
transportTask = TransportTask;
}
finally
{
Lock.Release();
}
// REVIEW: Add a timeout so we don't wait forever
await Task.WhenAll(ApplicationTask, TransportTask);
await Task.WhenAll(applicationTask, transportTask);
}
public enum ConnectionStatus
{
Inactive,
Active,
Disposed
}
}
}

View File

@ -1,10 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Hosting;
namespace Microsoft.AspNetCore.Sockets
@ -20,6 +16,7 @@ namespace Microsoft.AspNetCore.Sockets
public void Start()
{
_connectionManager.Start();
}
public void Stop()

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
@ -12,7 +13,8 @@ namespace Microsoft.AspNetCore.Sockets.Transports
/// Executes the transport
/// </summary>
/// <param name="context"></param>
/// <param name="token"></param>
/// <returns>A <see cref="Task"/> that completes when the transport has finished processing</returns>
Task ProcessRequestAsync(HttpContext context);
Task ProcessRequestAsync(HttpContext context, CancellationToken token);
}
}

View File

@ -23,17 +23,14 @@ namespace Microsoft.AspNetCore.Sockets.Transports
_logger = loggerFactory.CreateLogger<LongPollingTransport>();
}
public async Task ProcessRequestAsync(HttpContext context)
public async Task ProcessRequestAsync(HttpContext context, CancellationToken token)
{
try
{
// TODO: We need the ability to yield the connection without completing the channel.
// This is to force ReadAsync to yield without data to end to poll but not the entire connection.
// This is for cases when the client reconnects see issue #27
if (!await _application.WaitToReadAsync(context.RequestAborted))
if (!await _application.WaitToReadAsync(token))
{
_logger.LogInformation("Terminating Long Polling connection by sending 204 response.");
context.Response.StatusCode = 204;
context.Response.StatusCode = StatusCodes.Status204NoContent;
return;
}
@ -50,14 +47,17 @@ namespace Microsoft.AspNetCore.Sockets.Transports
}
catch (OperationCanceledException)
{
// Suppress the exception
if (!context.RequestAborted.IsCancellationRequested)
{
_logger.LogInformation("Terminating Long Polling connection by sending 204 response.");
context.Response.StatusCode = StatusCodes.Status204NoContent;
throw;
}
// Don't count this as cancellation, this is normal as the poll can end due to the browesr closing.
// The background thread will eventually dispose this connection if it's inactive
_logger.LogDebug("Client disconnected from Long Polling endpoint.");
}
catch (Exception ex)
{
_logger.LogError("Error reading next message from Application: {0}", ex);
throw;
}
}
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
@ -21,7 +22,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
_logger = loggerFactory.CreateLogger<ServerSentEventsTransport>();
}
public async Task ProcessRequestAsync(HttpContext context)
public async Task ProcessRequestAsync(HttpContext context, CancellationToken token)
{
context.Response.ContentType = "text/event-stream";
context.Response.Headers["Cache-Control"] = "no-cache";
@ -30,7 +31,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
try
{
while (await _application.WaitToReadAsync(context.RequestAborted))
while (await _application.WaitToReadAsync(token))
{
Message message;
while (_application.TryRead(out message))

View File

@ -3,6 +3,7 @@
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.WebSockets.Internal;
@ -40,7 +41,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
_logger = loggerFactory.CreateLogger<WebSocketsTransport>();
}
public async Task ProcessRequestAsync(HttpContext context)
public async Task ProcessRequestAsync(HttpContext context, CancellationToken token)
{
var feature = context.Features.Get<IHttpWebSocketConnectionFeature>();
if (feature == null || !feature.IsWebSocketRequest)

View File

@ -4,6 +4,7 @@
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
@ -13,21 +14,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public void NewConnectionsHaveConnectionId()
{
var connectionManager = new ConnectionManager();
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.True(state.Active);
Assert.Equal(ConnectionState.ConnectionStatus.Inactive, state.Status);
Assert.Null(state.ApplicationTask);
Assert.Null(state.TransportTask);
Assert.Null(state.Cancellation);
Assert.Null(state.RequestId);
Assert.NotNull(state.Connection.Transport);
}
[Fact]
public void NewConnectionsCanBeRetrieved()
{
var connectionManager = new ConnectionManager();
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
Assert.NotNull(state.Connection);
@ -41,7 +44,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public void AddNewConnection()
{
var connectionManager = new ConnectionManager();
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var transport = state.Connection.Transport;
@ -59,7 +62,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public void RemoveConnection()
{
var connectionManager = new ConnectionManager();
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var transport = state.Connection.Transport;
@ -80,7 +83,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task CloseConnectionsEndsAllPendingConnections()
{
var connectionManager = new ConnectionManager();
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
state.ApplicationTask = Task.Run(async () =>
@ -97,5 +100,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
await state.DisposeAsync();
}
private static ConnectionManager CreateConnectionManager()
{
return new ConnectionManager(new Logger<ConnectionManager>(new LoggerFactory()));
}
}
}

View File

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
@ -21,7 +22,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task NegotiateReservesConnectionIdAndReturnsIt()
{
var manager = new ConnectionManager();
var manager = CreateConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
@ -46,7 +47,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[InlineData("/ws")]
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvided(string path)
{
var manager = new ConnectionManager();
var manager = CreateConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
using (var strm = new MemoryStream())
@ -77,7 +78,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[InlineData("/poll")]
public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided(string path)
{
var manager = new ConnectionManager();
var manager = CreateConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
using (var strm = new MemoryStream())
{
@ -95,13 +96,171 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray()));
}
}
[Fact]
public async Task CompletedEndPointEndsConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/sse", state);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("", context);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
Assert.False(exists);
}
[Fact]
public async Task CompletedEndPointEndsLongPollingConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/poll", state);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("", context);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
Assert.False(exists);
}
[Fact]
public async Task RequestToActiveConnectionId409ForStreamingTransports()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<TestEndPoint>("/sse", state);
var context2 = MakeRequest<TestEndPoint>("/sse", state);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>("", context1);
await dispatcher.ExecuteAsync<TestEndPoint>("", context2);
Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode);
manager.CloseConnections();
await request1;
}
[Fact]
public async Task RequestToActiveConnectionIdKillsPreviousConnectionLongPolling()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<TestEndPoint>("/poll", state);
var context2 = MakeRequest<TestEndPoint>("/poll", state);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>("", context1);
var request2 = dispatcher.ExecuteAsync<TestEndPoint>("", context2);
await request1;
Assert.Equal(StatusCodes.Status204NoContent, context1.Response.StatusCode);
Assert.Equal(ConnectionState.ConnectionStatus.Active, state.Status);
Assert.False(request2.IsCompleted);
manager.CloseConnections();
await request2;
}
[Theory]
[InlineData("/sse")]
[InlineData("/poll")]
public async Task RequestToDisposedConnectionIdReturns404(string path)
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
state.Status = ConnectionState.ConnectionStatus.Disposed;
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<TestEndPoint>(path, state);
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
}
[Fact]
public async Task ConnectionStateSetToInactiveAfterPoll()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<TestEndPoint>("/poll", state);
var task = dispatcher.ExecuteAsync<TestEndPoint>("", context);
var buffer = ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve();
// Write to the transport so the poll yields
await state.Connection.Transport.Output.WriteAsync(new Message(buffer, Format.Text, endOfMessage: true));
await task;
Assert.Equal(ConnectionState.ConnectionStatus.Inactive, state.Status);
Assert.Null(state.RequestId);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state) where TEndPoint : EndPoint
{
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
return context;
}
private static ConnectionManager CreateConnectionManager()
{
return new ConnectionManager(new Logger<ConnectionManager>(new LoggerFactory()));
}
}
public class ImmediatelyCompleteEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)
{
return Task.CompletedTask;
}
}
public class TestEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)
public override async Task OnConnectedAsync(Connection connection)
{
throw new NotImplementedException();
while (await connection.Transport.Input.WaitToReadAsync())
{
}
}
}
}

View File

@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.True(channel.Out.TryComplete());
await poll.ProcessRequestAsync(context);
await poll.ProcessRequestAsync(context, context.RequestAborted);
Assert.Equal(204, context.Response.StatusCode);
}
@ -48,7 +48,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.True(channel.Out.TryComplete());
await poll.ProcessRequestAsync(context);
await poll.ProcessRequestAsync(context, context.RequestAborted);
Assert.Equal(200, context.Response.StatusCode);
Assert.Equal("Hello World", Encoding.UTF8.GetString(ms.ToArray()));

View File

@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.True(channel.Out.TryComplete());
await sse.ProcessRequestAsync(context);
await sse.ProcessRequestAsync(context, context.RequestAborted);
Assert.Equal("text/event-stream", context.Response.ContentType);
Assert.Equal("no-cache", context.Response.Headers["Cache-Control"]);
@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.True(channel.Out.TryComplete());
await sse.ProcessRequestAsync(context);
await sse.ProcessRequestAsync(context, context.RequestAborted);
var expected = "data: Hello World\n\n";
Assert.Equal(expected, Encoding.UTF8.GetString(ms.ToArray()));