From b77f50a99a2b6e3a62455866e759afba122ca619 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Sat, 3 Jun 2017 12:14:01 -1000 Subject: [PATCH] Merge DefaultConnectionContext with ConnectionState (#514) * Merge DefaultConnectionContext with ConnectionState - Removed ConnectionState as a result --- samples/ChatSample/Hubs/Chat.cs | 2 +- .../ConnectionContext.cs | 8 +- .../DefaultConnectionContext.cs | 104 ++++++++++- .../HttpConnectionDispatcher.cs | 169 +++++++++--------- .../ConnectionManager.cs | 30 ++-- .../ConnectionState.cs | 116 ------------ .../TestClient.cs | 6 +- .../ConnectionManagerTests.cs | 115 ++++++------ .../HttpConnectionDispatcherTests.cs | 129 +++++++------ 9 files changed, 320 insertions(+), 359 deletions(-) delete mode 100644 src/Microsoft.AspNetCore.Sockets/ConnectionState.cs diff --git a/samples/ChatSample/Hubs/Chat.cs b/samples/ChatSample/Hubs/Chat.cs index 2b9919ac70..4a52886f66 100644 --- a/samples/ChatSample/Hubs/Chat.cs +++ b/samples/ChatSample/Hubs/Chat.cs @@ -20,7 +20,7 @@ namespace ChatSample.Hubs { if (!Context.User.Identity.IsAuthenticated) { - Context.Connection.Dispose(); + Context.Connection.Transport.Dispose(); return; } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs index 7300980331..2ec3553fef 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs @@ -10,7 +10,7 @@ using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Sockets { - public abstract class ConnectionContext : IDisposable + public abstract class ConnectionContext { public abstract string ConnectionId { get; } @@ -23,11 +23,5 @@ namespace Microsoft.AspNetCore.Sockets // TEMPORARY public abstract IChannelConnection Transport { get; set; } - - // TEMPORARY - public void Dispose() - { - Transport?.Dispose(); - } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs index b353216d37..6617daa92f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs @@ -3,18 +3,40 @@ using System; using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Sockets { public class DefaultConnectionContext : ConnectionContext { - public DefaultConnectionContext(string id, IChannelConnection transport) + // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously + // on the same task + private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); + + public DefaultConnectionContext(string id, IChannelConnection transport, IChannelConnection application) { Transport = transport; + Application = application; ConnectionId = id; } + public CancellationTokenSource Cancellation { get; set; } + + public SemaphoreSlim Lock { get; } = new SemaphoreSlim(1, 1); + + // REVIEW: This should only be on the Http implementation + public string RequestId { get; set; } + + public Task TransportTask { get; set; } + + public Task ApplicationTask { get; set; } + + public DateTime LastSeenUtc { get; set; } + + public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive; + public override string ConnectionId { get; } public override IFeatureCollection Features { get; } = new FeatureCollection(); @@ -23,6 +45,86 @@ namespace Microsoft.AspNetCore.Sockets public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); + public IChannelConnection Application { get; } + public override IChannelConnection Transport { get; set; } + + public async Task DisposeAsync() + { + Task disposeTask = Task.CompletedTask; + + try + { + await Lock.WaitAsync(); + + if (Status == ConnectionStatus.Disposed) + { + disposeTask = _disposeTcs.Task; + } + else + { + Status = ConnectionStatus.Disposed; + + RequestId = null; + + // If the application task is faulted, propagate the error to the transport + if (ApplicationTask?.IsFaulted == true) + { + Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); + } + + // If the transport task is faulted, propagate the error to the application + if (TransportTask?.IsFaulted == true) + { + Application.Output.TryComplete(TransportTask.Exception.InnerException); + } + + Transport.Dispose(); + Application.Dispose(); + + var applicationTask = ApplicationTask ?? Task.CompletedTask; + var transportTask = TransportTask ?? Task.CompletedTask; + + disposeTask = WaitOnTasks(applicationTask, transportTask); + } + } + finally + { + Lock.Release(); + } + + await disposeTask; + } + + private async Task WaitOnTasks(Task applicationTask, Task transportTask) + { + try + { + await Task.WhenAll(applicationTask, transportTask); + + // Notify all waiters that we're done disposing + _disposeTcs.TrySetResult(null); + } + catch (OperationCanceledException) + { + _disposeTcs.TrySetCanceled(); + + throw; + } + catch (Exception ex) + { + _disposeTcs.TrySetException(ex); + + throw; + } + } + + + public enum ConnectionStatus + { + Inactive, + Active, + Disposed + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index 09e8e2cd89..c09fb78656 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -70,57 +70,57 @@ namespace Microsoft.AspNetCore.Sockets if (headers.Accept?.Contains(new Net.Http.Headers.MediaTypeHeaderValue("text/event-stream")) == true) { // Connection must already exist - var state = await GetConnectionAsync(context); - if (state == null) + var connection = await GetConnectionAsync(context); + if (connection == null) { // No such connection, GetConnection already set the response status code return; } - if (!await EnsureConnectionStateAsync(state, context, TransportType.ServerSentEvents, supportedTransports)) + if (!await EnsureConnectionStateAsync(connection, context, TransportType.ServerSentEvents, supportedTransports)) { // Bad connection state. It's already set the response status code. return; } // We only need to provide the Input channel since writing to the application is handled through /send. - var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory); + var sse = new ServerSentEventsTransport(connection.Application.Input, _loggerFactory); - await DoPersistentConnection(socketDelegate, sse, context, state); + await DoPersistentConnection(socketDelegate, sse, context, connection); } else if (context.WebSockets.IsWebSocketRequest) { // Connection can be established lazily - var state = await GetOrCreateConnectionAsync(context); - if (state == null) + var connection = await GetOrCreateConnectionAsync(context); + if (connection == null) { // No such connection, GetOrCreateConnection already set the response status code return; } - if (!await EnsureConnectionStateAsync(state, context, TransportType.WebSockets, supportedTransports)) + if (!await EnsureConnectionStateAsync(connection, context, TransportType.WebSockets, supportedTransports)) { // Bad connection state. It's already set the response status code. return; } - var ws = new WebSocketsTransport(options.WebSockets, state.Application, _loggerFactory); + var ws = new WebSocketsTransport(options.WebSockets, connection.Application, _loggerFactory); - await DoPersistentConnection(socketDelegate, ws, context, state); + await DoPersistentConnection(socketDelegate, ws, context, connection); } else { // GET /{path} maps to long polling // Connection must already exist - var state = await GetConnectionAsync(context); - if (state == null) + var connection = await GetConnectionAsync(context); + if (connection == null) { // No such connection, GetConnection already set the response status code return; } - if (!await EnsureConnectionStateAsync(state, context, TransportType.LongPolling, supportedTransports)) + if (!await EnsureConnectionStateAsync(connection, context, TransportType.LongPolling, supportedTransports)) { // Bad connection state. It's already set the response status code. return; @@ -128,94 +128,94 @@ namespace Microsoft.AspNetCore.Sockets try { - await state.Lock.WaitAsync(); + await connection.Lock.WaitAsync(); - if (state.Status == ConnectionState.ConnectionStatus.Disposed) + if (connection.Status == DefaultConnectionContext.ConnectionStatus.Disposed) { - _logger.LogDebug("Connection {connectionId} was disposed,", state.Connection.ConnectionId); + _logger.LogDebug("Connection {connectionId} was disposed,", connection.ConnectionId); // The connection was disposed context.Response.StatusCode = StatusCodes.Status404NotFound; return; } - if (state.Status == ConnectionState.ConnectionStatus.Active) + if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active) { - _logger.LogDebug("Connection {connectionId} is already active via {requestId}. Cancelling previous request.", state.Connection.ConnectionId, state.RequestId); + _logger.LogDebug("Connection {connectionId} is already active via {requestId}. Cancelling previous request.", connection.ConnectionId, connection.RequestId); - using (state.Cancellation) + using (connection.Cancellation) { // Cancel the previous request - state.Cancellation.Cancel(); + connection.Cancellation.Cancel(); try { // Wait for the previous request to drain - await state.TransportTask; + await connection.TransportTask; } catch (OperationCanceledException) { // Should be a cancelled task } - _logger.LogDebug("Previous poll cancelled for {connectionId} on {requestId}.", state.Connection.ConnectionId, state.RequestId); + _logger.LogDebug("Previous poll cancelled for {connectionId} on {requestId}.", connection.ConnectionId, connection.RequestId); } } // Mark the request identifier - state.RequestId = context.TraceIdentifier; + connection.RequestId = context.TraceIdentifier; // Mark the connection as active - state.Status = ConnectionState.ConnectionStatus.Active; + connection.Status = DefaultConnectionContext.ConnectionStatus.Active; // Raise OnConnected for new connections only since polls happen all the time - if (state.ApplicationTask == null) + if (connection.ApplicationTask == null) { - _logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId); + _logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", connection.ConnectionId, connection.RequestId); - state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling; + connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling; - state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection); + connection.ApplicationTask = ExecuteApplication(socketDelegate, connection); } else { - _logger.LogDebug("Resuming existing connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId); + _logger.LogDebug("Resuming existing connection: {connectionId} on {requestId}", connection.ConnectionId, connection.RequestId); } - var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory); + var longPolling = new LongPollingTransport(connection.Application.Input, _loggerFactory); - state.Cancellation = new CancellationTokenSource(); + connection.Cancellation = new CancellationTokenSource(); // REVIEW: Performance of this isn't great as this does a bunch of per request allocations - var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(state.Cancellation.Token, context.RequestAborted); + var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted); // Start the transport - state.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); + connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); } finally { - state.Lock.Release(); + connection.Lock.Release(); } - var resultTask = await Task.WhenAny(state.ApplicationTask, state.TransportTask); + var resultTask = await Task.WhenAny(connection.ApplicationTask, connection.TransportTask); var pollAgain = true; // If the application ended before the transport task then we need to potentially need to end the // connection - if (resultTask == state.ApplicationTask) + if (resultTask == connection.ApplicationTask) { // Complete the transport (notifying it of the application error if there is one) - state.Connection.Transport.Output.TryComplete(state.ApplicationTask.Exception); + connection.Transport.Output.TryComplete(connection.ApplicationTask.Exception); // Wait for the transport to run - await state.TransportTask; + await connection.TransportTask; // If the status code is a 204 it means we didn't write anything if (context.Response.StatusCode == StatusCodes.Status204NoContent) { // We should be able to safely dispose because there's no more data being written - await _manager.DisposeAndRemoveAsync(state); + await _manager.DisposeAndRemoveAsync(connection); // Don't poll again if we've removed the connection completely pollAgain = false; @@ -232,53 +232,53 @@ namespace Microsoft.AspNetCore.Sockets // Otherwise, we update the state to inactive again and wait for the next poll try { - await state.Lock.WaitAsync(); + await connection.Lock.WaitAsync(); - if (state.Status == ConnectionState.ConnectionStatus.Active) + if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active) { // Mark the connection as inactive - state.LastSeenUtc = DateTime.UtcNow; + connection.LastSeenUtc = DateTime.UtcNow; - state.Status = ConnectionState.ConnectionStatus.Inactive; + connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive; - state.RequestId = null; + connection.RequestId = null; // Dispose the cancellation token - state.Cancellation.Dispose(); + connection.Cancellation.Dispose(); - state.Cancellation = null; + connection.Cancellation = null; } } finally { - state.Lock.Release(); + connection.Lock.Release(); } } } } - private ConnectionState CreateConnection(HttpContext context) + private DefaultConnectionContext CreateConnection(HttpContext context) { - var state = _manager.CreateConnection(); + var connection = _manager.CreateConnection(); var format = (string)context.Request.Query[ConnectionMetadataNames.Format]; - state.Connection.User = context.User; - state.Connection.Metadata[ConnectionMetadataNames.HttpContext] = context; - state.Connection.Metadata[ConnectionMetadataNames.Format] = string.IsNullOrEmpty(format) ? "json" : format; - return state; + connection.User = context.User; + connection.Metadata[ConnectionMetadataNames.HttpContext] = context; + connection.Metadata[ConnectionMetadataNames.Format] = string.IsNullOrEmpty(format) ? "json" : format; + return connection; } private async Task DoPersistentConnection(SocketDelegate socketDelegate, IHttpTransport transport, HttpContext context, - ConnectionState state) + DefaultConnectionContext connection) { try { - await state.Lock.WaitAsync(); + await connection.Lock.WaitAsync(); - if (state.Status == ConnectionState.ConnectionStatus.Disposed) + if (connection.Status == DefaultConnectionContext.ConnectionStatus.Disposed) { - _logger.LogDebug("Connection {connectionId} was disposed,", state.Connection.ConnectionId); + _logger.LogDebug("Connection {connectionId} was disposed,", connection.ConnectionId); // Connection was disposed context.Response.StatusCode = StatusCodes.Status404NotFound; @@ -286,9 +286,9 @@ namespace Microsoft.AspNetCore.Sockets } // There's already an active request - if (state.Status == ConnectionState.ConnectionStatus.Active) + if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active) { - _logger.LogDebug("Connection {connectionId} is already active via {requestId}.", state.Connection.ConnectionId, state.RequestId); + _logger.LogDebug("Connection {connectionId} is already active via {requestId}.", connection.ConnectionId, connection.RequestId); // Reject the request with a 409 conflict context.Response.StatusCode = StatusCodes.Status409Conflict; @@ -296,26 +296,26 @@ namespace Microsoft.AspNetCore.Sockets } // Mark the connection as active - state.Status = ConnectionState.ConnectionStatus.Active; + connection.Status = DefaultConnectionContext.ConnectionStatus.Active; // Store the request identifier - state.RequestId = context.TraceIdentifier; + connection.RequestId = context.TraceIdentifier; // Call into the end point passing the connection - state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection); + connection.ApplicationTask = ExecuteApplication(socketDelegate, connection); // Start the transport - state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); + connection.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); } finally { - state.Lock.Release(); + connection.Lock.Release(); } // Wait for any of them to end - await Task.WhenAny(state.ApplicationTask, state.TransportTask); + await Task.WhenAny(connection.ApplicationTask, connection.TransportTask); - await _manager.DisposeAndRemoveAsync(state); + await _manager.DisposeAndRemoveAsync(connection); } private async Task ExecuteApplication(SocketDelegate socketDelegate, ConnectionContext connection) @@ -336,10 +336,10 @@ namespace Microsoft.AspNetCore.Sockets context.Response.ContentType = "text/plain"; // Establish the connection - var state = CreateConnection(context); + var connection = CreateConnection(context); // Get the bytes for the connection id - var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId); + var connectionIdBuffer = Encoding.UTF8.GetBytes(connection.ConnectionId); // Write it out to the response with the right content length context.Response.ContentLength = connectionIdBuffer.Length; @@ -348,8 +348,8 @@ namespace Microsoft.AspNetCore.Sockets private async Task ProcessSend(HttpContext context) { - var state = await GetConnectionAsync(context); - if (state == null) + var connection = await GetConnectionAsync(context); + if (connection == null) { // No such connection, GetConnection already set the response status code return; @@ -388,9 +388,9 @@ namespace Microsoft.AspNetCore.Sockets _logger.LogDebug("Received batch of {count} message(s)", messages.Count); foreach (var message in messages) { - while (!state.Application.Output.TryWrite(message)) + while (!connection.Application.Output.TryWrite(message)) { - if (!await state.Application.Output.WaitToWriteAsync()) + if (!await connection.Application.Output.WaitToWriteAsync()) { return; } @@ -398,7 +398,7 @@ namespace Microsoft.AspNetCore.Sockets } } - private async Task EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, TransportType transportType, TransportType supportedTransports) + private async Task EnsureConnectionStateAsync(DefaultConnectionContext connection, HttpContext context, TransportType transportType, TransportType supportedTransports) { if ((supportedTransports & transportType) == 0) { @@ -407,13 +407,13 @@ namespace Microsoft.AspNetCore.Sockets return false; } - connectionState.Connection.User = context.User; + connection.User = context.User; - var transport = connectionState.Connection.Metadata.Get(ConnectionMetadataNames.Transport); + var transport = connection.Metadata.Get(ConnectionMetadataNames.Transport); if (transport == null) { - connectionState.Connection.Metadata[ConnectionMetadataNames.Transport] = transportType; + connection.Metadata[ConnectionMetadataNames.Transport] = transportType; } else if (transport != transportType) { @@ -424,10 +424,9 @@ namespace Microsoft.AspNetCore.Sockets return true; } - private async Task GetConnectionAsync(HttpContext context) + private async Task GetConnectionAsync(HttpContext context) { var connectionId = context.Request.Query["id"]; - ConnectionState connectionState; if (StringValues.IsNullOrEmpty(connectionId)) { @@ -437,7 +436,7 @@ namespace Microsoft.AspNetCore.Sockets return null; } - if (!_manager.TryGetConnection(connectionId, out connectionState)) + if (!_manager.TryGetConnection(connectionId, out var connection)) { // No connection with that ID: Not Found context.Response.StatusCode = StatusCodes.Status404NotFound; @@ -445,20 +444,20 @@ namespace Microsoft.AspNetCore.Sockets return null; } - return connectionState; + return connection; } - private async Task GetOrCreateConnectionAsync(HttpContext context) + private async Task GetOrCreateConnectionAsync(HttpContext context) { var connectionId = context.Request.Query["id"]; - ConnectionState connectionState; + DefaultConnectionContext connection; // There's no connection id so this is a brand new connection if (StringValues.IsNullOrEmpty(connectionId)) { - connectionState = CreateConnection(context); + connection = CreateConnection(context); } - else if (!_manager.TryGetConnection(connectionId, out connectionState)) + else if (!_manager.TryGetConnection(connectionId, out connection)) { // No connection with that ID: Not Found context.Response.StatusCode = StatusCodes.Status404NotFound; @@ -466,7 +465,7 @@ namespace Microsoft.AspNetCore.Sockets return null; } - return connectionState; + return connection; } private List ParseSendBatch(ref BytesReader payload, MessageFormat messageFormat) diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 201fb02c8c..90735e2ac9 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Sockets { public class ConnectionManager { - private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); private Timer _timer; private readonly ILogger _logger; private object _executionLock = new object(); @@ -41,12 +41,12 @@ namespace Microsoft.AspNetCore.Sockets } } - public bool TryGetConnection(string id, out ConnectionState state) + public bool TryGetConnection(string id, out DefaultConnectionContext connection) { - return _connections.TryGetValue(id, out state); + return _connections.TryGetValue(id, out connection); } - public ConnectionState CreateConnection() + public DefaultConnectionContext CreateConnection() { var id = MakeNewConnectionId(); @@ -56,12 +56,10 @@ namespace Microsoft.AspNetCore.Sockets var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - var state = new ConnectionState( - new DefaultConnectionContext(id, applicationSide), - transportSide); - - _connections.TryAdd(id, state); - return state; + var connection = new DefaultConnectionContext(id, applicationSide, transportSide); + + _connections.TryAdd(id, connection); + return connection; } public void RemoveConnection(string id) @@ -108,7 +106,7 @@ namespace Microsoft.AspNetCore.Sockets // Scan the registered connections looking for ones that have timed out foreach (var c in _connections) { - var status = ConnectionState.ConnectionStatus.Inactive; + var status = DefaultConnectionContext.ConnectionStatus.Inactive; var lastSeenUtc = DateTimeOffset.UtcNow; try @@ -126,7 +124,7 @@ namespace Microsoft.AspNetCore.Sockets } // Once the decision has been made to to dispose we don't check the status again - if (status == ConnectionState.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5) + if (status == DefaultConnectionContext.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5) { var ignore = DisposeAndRemoveAsync(c.Value); } @@ -167,21 +165,21 @@ namespace Microsoft.AspNetCore.Sockets } } - public async Task DisposeAndRemoveAsync(ConnectionState state) + public async Task DisposeAndRemoveAsync(DefaultConnectionContext connection) { try { - await state.DisposeAsync(); + await connection.DisposeAsync(); } catch (Exception ex) { - _logger.LogError(0, ex, "Failed disposing connection {connectionId}", state.Connection.ConnectionId); + _logger.LogError(0, ex, "Failed disposing connection {connectionId}", connection.ConnectionId); } finally { // Remove it from the list after disposal so that's it's easy to see // connections that might be in a hung state via the connections list - RemoveConnection(state.Connection.ConnectionId); + RemoveConnection(connection.ConnectionId); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionState.cs deleted file mode 100644 index 7c4db1a773..0000000000 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionState.cs +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Internal; - -namespace Microsoft.AspNetCore.Sockets.Internal -{ - public class ConnectionState - { - // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously - // on the same task - private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); - - public ConnectionContext Connection { get; set; } - public IChannelConnection Application { get; } - - public CancellationTokenSource Cancellation { get; set; } - - public SemaphoreSlim Lock { get; } = new SemaphoreSlim(1, 1); - - public string RequestId { get; set; } - - public Task TransportTask { get; set; } - public Task ApplicationTask { get; set; } - - public DateTime LastSeenUtc { get; set; } - public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive; - - public ConnectionState(ConnectionContext connection, IChannelConnection application) - { - Connection = connection; - Application = application; - LastSeenUtc = DateTime.UtcNow; - } - - public async Task DisposeAsync() - { - Task disposeTask = Task.CompletedTask; - - try - { - await Lock.WaitAsync(); - - if (Status == ConnectionStatus.Disposed) - { - disposeTask = _disposeTcs.Task; - } - else - { - Status = ConnectionStatus.Disposed; - - RequestId = null; - - // If the application task is faulted, propagate the error to the transport - if (ApplicationTask?.IsFaulted == true) - { - Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); - } - - // If the transport task is faulted, propagate the error to the application - if (TransportTask?.IsFaulted == true) - { - Application.Output.TryComplete(TransportTask.Exception.InnerException); - } - - Connection.Dispose(); - Application.Dispose(); - - var applicationTask = ApplicationTask ?? Task.CompletedTask; - var transportTask = TransportTask ?? Task.CompletedTask; - - disposeTask = WaitOnTasks(applicationTask, transportTask); - } - } - finally - { - Lock.Release(); - } - - await disposeTask; - } - - private async Task WaitOnTasks(Task applicationTask, Task transportTask) - { - try - { - await Task.WhenAll(applicationTask, transportTask); - - // Notify all waiters that we're done disposing - _disposeTcs.TrySetResult(null); - } - catch (OperationCanceledException) - { - _disposeTcs.TrySetCanceled(); - - throw; - } - catch (Exception ex) - { - _disposeTcs.TrySetException(ex); - - throw; - } - } - - public enum ConnectionStatus - { - Inactive, - Active, - Disposed - } - } -} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 1cd51f0180..475cd492ac 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests private IHubProtocol _protocol; private CancellationTokenSource _cts; - public ConnectionContext Connection { get; } + public DefaultConnectionContext Connection { get; } public IChannelConnection Application { get; } public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; @@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests Application = ChannelConnection.Create(input: applicationToTransport, output: transportToApplication); var transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); - Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport); + Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport, Application); Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) })); Connection.Metadata["ConnectedTask"] = new TaskCompletionSource(); @@ -147,7 +147,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public void Dispose() { _cts.Cancel(); - Connection.Dispose(); + Connection.Transport.Dispose(); } private static string GetInvocationId() diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 4b8f3d32d1..da376c8247 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -4,7 +4,6 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Tests.Common; -using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; using Xunit; @@ -16,104 +15,97 @@ namespace Microsoft.AspNetCore.Sockets.Tests public void NewConnectionsHaveConnectionId() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); - Assert.Equal(ConnectionState.ConnectionStatus.Inactive, state.Status); - Assert.Null(state.ApplicationTask); - Assert.Null(state.TransportTask); - Assert.Null(state.Cancellation); - Assert.Null(state.RequestId); - Assert.NotNull(state.Connection.Transport); + Assert.NotNull(connection.ConnectionId); + Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status); + Assert.Null(connection.ApplicationTask); + Assert.Null(connection.TransportTask); + Assert.Null(connection.Cancellation); + Assert.Null(connection.RequestId); + Assert.NotNull(connection.Transport); } [Fact] public void NewConnectionsCanBeRetrieved() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); + Assert.NotNull(connection.ConnectionId); - ConnectionState newState; - Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); - Assert.Same(newState, state); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.Same(newConnection, connection); } [Fact] public void AddNewConnection() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); - var transport = state.Connection.Transport; + var transport = connection.Transport; - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); + Assert.NotNull(connection.ConnectionId); Assert.NotNull(transport); - ConnectionState newState; - Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); - Assert.Same(newState, state); - Assert.Same(transport, newState.Connection.Transport); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.Same(newConnection, connection); + Assert.Same(transport, newConnection.Transport); } [Fact] public void RemoveConnection() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); - var transport = state.Connection.Transport; + var transport = connection.Transport; - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); + Assert.NotNull(connection.ConnectionId); Assert.NotNull(transport); - ConnectionState newState; - Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); - Assert.Same(newState, state); - Assert.Same(transport, newState.Connection.Transport); + Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection)); + Assert.Same(newConnection, connection); + Assert.Same(transport, newConnection.Transport); - connectionManager.RemoveConnection(state.Connection.ConnectionId); - Assert.False(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); + connectionManager.RemoveConnection(connection.ConnectionId); + Assert.False(connectionManager.TryGetConnection(connection.ConnectionId, out newConnection)); } [Fact] public async Task CloseConnectionsEndsAllPendingConnections() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); - state.ApplicationTask = Task.Run(async () => + connection.ApplicationTask = Task.Run(async () => { - Assert.False(await state.Connection.Transport.Input.WaitToReadAsync()); + Assert.False(await connection.Transport.Input.WaitToReadAsync()); }); - state.TransportTask = Task.Run(async () => + connection.TransportTask = Task.Run(async () => { - Assert.False(await state.Application.Input.WaitToReadAsync()); + Assert.False(await connection.Application.Input.WaitToReadAsync()); }); connectionManager.CloseConnections(); - await state.DisposeAsync(); + await connection.DisposeAsync(); } [Fact] public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - state.ApplicationTask = tcs.Task; - state.TransportTask = tcs.Task; + connection.ApplicationTask = tcs.Task; + connection.TransportTask = tcs.Task; - var firstTask = state.DisposeAsync(); - var secondTask = state.DisposeAsync(); + var firstTask = connection.DisposeAsync(); + var secondTask = connection.DisposeAsync(); Assert.False(firstTask.IsCompleted); Assert.False(secondTask.IsCompleted); @@ -126,14 +118,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task DisposingConnectionMultipleGetsExceptionFromTransportOrApp() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - state.ApplicationTask = tcs.Task; - state.TransportTask = tcs.Task; + connection.ApplicationTask = tcs.Task; + connection.TransportTask = tcs.Task; - var firstTask = state.DisposeAsync(); - var secondTask = state.DisposeAsync(); + var firstTask = connection.DisposeAsync(); + var secondTask = connection.DisposeAsync(); Assert.False(firstTask.IsCompleted); Assert.False(secondTask.IsCompleted); @@ -150,14 +142,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task DisposingConnectionMultipleGetsCancellation() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - state.ApplicationTask = tcs.Task; - state.TransportTask = tcs.Task; + connection.ApplicationTask = tcs.Task; + connection.TransportTask = tcs.Task; - var firstTask = state.DisposeAsync(); - var secondTask = state.DisposeAsync(); + var firstTask = connection.DisposeAsync(); + var secondTask = connection.DisposeAsync(); Assert.False(firstTask.IsCompleted); Assert.False(secondTask.IsCompleted); @@ -171,21 +163,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task DisposeInactiveConnection() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection();; + var connection = connectionManager.CreateConnection();; - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); - Assert.NotNull(state.Connection.Transport); + Assert.NotNull(connection.ConnectionId); + Assert.NotNull(connection.Transport); - await state.DisposeAsync(); - Assert.Equal(ConnectionState.ConnectionStatus.Disposed, state.Status); + await connection.DisposeAsync(); + Assert.Equal(DefaultConnectionContext.ConnectionStatus.Disposed, connection.Status); } [Fact] public void ScanAfterDisposeNoops() { var connectionManager = CreateConnectionManager(); - var state = connectionManager.CreateConnection(); + var connection = connectionManager.CreateConnection(); connectionManager.CloseConnections(); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 42d98a88ea..9f1a1279ee 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -14,7 +14,6 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Internal; using Microsoft.AspNetCore.SignalR.Tests.Common; -using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; @@ -48,9 +47,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests var id = Encoding.UTF8.GetString(ms.ToArray()); - ConnectionState state; - Assert.True(manager.TryGetConnection(id, out state)); - Assert.Equal(id, state.Connection.ConnectionId); + Assert.True(manager.TryGetConnection(id, out var connection)); + Assert.Equal(id, connection.ConnectionId); } [Theory] @@ -182,7 +180,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task SendRequestsWithInvalidContentTypeAreRejected() { var manager = CreateConnectionManager(); - var connectionState = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); using (var strm = new MemoryStream()) { @@ -192,7 +190,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests services.AddEndPoint(); context.Request.Path = "/foo"; context.Request.Method = "POST"; - context.Request.QueryString = new QueryString($"?id={connectionState.Connection.ConnectionId}"); + context.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); context.Request.ContentType = "text/plain"; context.Response.Body = strm; @@ -245,11 +243,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task CompletedEndPointEndsConnection() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", connection); SetTransport(context, TransportType.ServerSentEvents); var services = new ServiceCollection(); @@ -261,8 +259,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - ConnectionState removed; - bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + bool exists = manager.TryGetConnection(connection.ConnectionId, out _); Assert.False(exists); } @@ -270,10 +267,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task SynchronusExceptionEndsConnection() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", connection); SetTransport(context, TransportType.ServerSentEvents); var services = new ServiceCollection(); @@ -285,8 +282,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - ConnectionState removed; - bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + bool exists = manager.TryGetConnection(connection.ConnectionId, out _); Assert.False(exists); } @@ -294,11 +290,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task CompletedEndPointEndsLongPollingConnection() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", connection); var services = new ServiceCollection(); services.AddEndPoint(); @@ -309,8 +305,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); - ConnectionState removed; - bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + bool exists = manager.TryGetConnection(connection.ConnectionId, out _); Assert.False(exists); } @@ -318,11 +313,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", connection); SetTransport(context, TransportType.WebSockets); var services = new ServiceCollection(); @@ -344,12 +339,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task RequestToActiveConnectionId409ForStreamingTransports(TransportType transportType) { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context1 = MakeRequest("/foo", state); - var context2 = MakeRequest("/foo", state); + var context1 = MakeRequest("/foo", connection); + var context2 = MakeRequest("/foo", connection); SetTransport(context1, transportType); SetTransport(context2, transportType); @@ -383,12 +378,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task RequestToActiveConnectionIdKillsPreviousConnectionLongPolling() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context1 = MakeRequest("/foo", state); - var context2 = MakeRequest("/foo", state); + var context1 = MakeRequest("/foo", connection); + var context2 = MakeRequest("/foo", connection); var services = new ServiceCollection(); services.AddEndPoint(); @@ -402,7 +397,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests await request1; Assert.Equal(StatusCodes.Status204NoContent, context1.Response.StatusCode); - Assert.Equal(ConnectionState.ConnectionStatus.Active, state.Status); + Assert.Equal(DefaultConnectionContext.ConnectionStatus.Active, connection.Status); Assert.False(request2.IsCompleted); @@ -417,12 +412,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task RequestToDisposedConnectionIdReturns404(TransportType transportType) { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); - state.Status = ConnectionState.ConnectionStatus.Disposed; + var connection = manager.CreateConnection(); + connection.Status = DefaultConnectionContext.ConnectionStatus.Disposed; var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", connection); SetTransport(context, transportType); var services = new ServiceCollection(); @@ -441,11 +436,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task ConnectionStateSetToInactiveAfterPoll() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", connection); var services = new ServiceCollection(); services.AddEndPoint(); @@ -458,12 +453,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the transport so the poll yields - await state.Connection.Transport.Output.WriteAsync(new Message(buffer, MessageType.Text)); + await connection.Transport.Output.WriteAsync(new Message(buffer, MessageType.Text)); await task; - Assert.Equal(ConnectionState.ConnectionStatus.Inactive, state.Status); - Assert.Null(state.RequestId); + Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status); + Assert.Null(connection.RequestId); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); } @@ -472,11 +467,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task BlockingConnectionWorksWithStreamingConnections() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", connection); SetTransport(context, TransportType.ServerSentEvents); var services = new ServiceCollection(); @@ -490,13 +485,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the application - await state.Application.Output.WriteAsync(new Message(buffer, MessageType.Text)); + await connection.Application.Output.WriteAsync(new Message(buffer, MessageType.Text)); await task; Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); - ConnectionState removed; - bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + bool exists = manager.TryGetConnection(connection.ConnectionId, out _); Assert.False(exists); } @@ -504,11 +498,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task BlockingConnectionWorksWithLongPollingConnection() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", connection); var services = new ServiceCollection(); services.AddEndPoint(); @@ -521,13 +515,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests var buffer = Encoding.UTF8.GetBytes("Hello World"); // Write to the application - await state.Application.Output.WriteAsync(new Message(buffer, MessageType.Text)); + await connection.Application.Output.WriteAsync(new Message(buffer, MessageType.Text)); await task; Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); - ConnectionState removed; - bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed); + bool exists = manager.TryGetConnection(connection.ConnectionId, out _); Assert.False(exists); } @@ -535,7 +528,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task AttemptingToPollWhileAlreadyPollingReplacesTheCurrentPoll() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); @@ -546,16 +539,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests var app = builder.Build(); var options = new HttpSocketOptions(); - var context1 = MakeRequest("/foo", state); + var context1 = MakeRequest("/foo", connection); var task1 = dispatcher.ExecuteAsync(context1, options, app); - var context2 = MakeRequest("/foo", state); + var context2 = MakeRequest("/foo", connection); var task2 = dispatcher.ExecuteAsync(context2, options, app); // Task 1 should finish when request 2 arrives await task1.OrTimeout(); // Send a message from the app to complete Task 2 - await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)); + await connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)); await task2.OrTimeout(); @@ -606,7 +599,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task UnauthorizedConnectionFailsToStartEndPoint() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -624,7 +617,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Method = "GET"; context.RequestServices = sp; var values = new Dictionary(); - values["id"] = state.Connection.ConnectionId; + values["id"] = connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -644,7 +637,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task AuthenticatedUserWithoutPermissionCausesForbidden() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -662,7 +655,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Method = "GET"; context.RequestServices = sp; var values = new Dictionary(); - values["id"] = state.Connection.ConnectionId; + values["id"] = connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -684,7 +677,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task AuthorizedConnectionCanConnectToEndPoint() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -705,7 +698,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Method = "GET"; context.RequestServices = sp; var values = new Dictionary(); - values["id"] = state.Connection.ConnectionId; + values["id"] = connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; context.Response.Body = new MemoryStream(); @@ -720,7 +713,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); + await connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); await endPointTask.OrTimeout(); @@ -733,7 +726,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -755,7 +748,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Method = "GET"; context.RequestServices = sp; var values = new Dictionary(); - values["id"] = state.Connection.ConnectionId; + values["id"] = connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; context.Response.Body = new MemoryStream(); @@ -770,7 +763,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); + await connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); await endPointTask.OrTimeout(); @@ -782,7 +775,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public async Task AuthorizedConnectionWithRejectedSchemesFailsToConnectToEndPoint() { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -804,7 +797,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Method = "GET"; context.RequestServices = sp; var values = new Dictionary(); - values["id"] = state.Connection.ConnectionId; + values["id"] = connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; context.Response.Body = new MemoryStream(); @@ -881,7 +874,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests private static async Task CheckTransportSupported(TransportType supportedTransports, TransportType transportType, int status) { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); using (var strm = new MemoryStream()) { @@ -894,7 +887,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = state.Connection.ConnectionId; + values["id"] = connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; @@ -919,11 +912,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests private static async Task> RunSendTest(string contentType, string encoded, string format) { var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); + var connection = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state, format); + var context = MakeRequest("/foo", connection, format); context.Request.Method = "POST"; context.Request.ContentType = contentType; @@ -942,7 +935,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app).OrTimeout(); } - while (state.Connection.Transport.Input.TryRead(out var message)) + while (connection.Transport.Input.TryRead(out var message)) { messages.Add(message); } @@ -950,13 +943,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests return messages; } - private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null) + private static DefaultHttpContext MakeRequest(string path, DefaultConnectionContext connection, string format = null) { var context = new DefaultHttpContext(); context.Request.Path = path; context.Request.Method = "GET"; var values = new Dictionary(); - values["id"] = state.Connection.ConnectionId; + values["id"] = connection.ConnectionId; if (format != null) { values["format"] = format;