Merge DefaultConnectionContext with ConnectionState (#514)

* Merge DefaultConnectionContext with ConnectionState
- Removed ConnectionState as a result
This commit is contained in:
David Fowler 2017-06-03 12:14:01 -10:00 committed by GitHub
parent cad9f2f671
commit b77f50a99a
9 changed files with 320 additions and 359 deletions

View File

@ -20,7 +20,7 @@ namespace ChatSample.Hubs
{
if (!Context.User.Identity.IsAuthenticated)
{
Context.Connection.Dispose();
Context.Connection.Transport.Dispose();
return;
}

View File

@ -10,7 +10,7 @@ using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.Sockets
{
public abstract class ConnectionContext : IDisposable
public abstract class ConnectionContext
{
public abstract string ConnectionId { get; }
@ -23,11 +23,5 @@ namespace Microsoft.AspNetCore.Sockets
// TEMPORARY
public abstract IChannelConnection<Message> Transport { get; set; }
// TEMPORARY
public void Dispose()
{
Transport?.Dispose();
}
}
}

View File

@ -3,18 +3,40 @@
using System;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.Sockets
{
public class DefaultConnectionContext : ConnectionContext
{
public DefaultConnectionContext(string id, IChannelConnection<Message> transport)
// This tcs exists so that multiple calls to DisposeAsync all wait asynchronously
// on the same task
private TaskCompletionSource<object> _disposeTcs = new TaskCompletionSource<object>();
public DefaultConnectionContext(string id, IChannelConnection<Message> transport, IChannelConnection<Message> application)
{
Transport = transport;
Application = application;
ConnectionId = id;
}
public CancellationTokenSource Cancellation { get; set; }
public SemaphoreSlim Lock { get; } = new SemaphoreSlim(1, 1);
// REVIEW: This should only be on the Http implementation
public string RequestId { get; set; }
public Task TransportTask { get; set; }
public Task ApplicationTask { get; set; }
public DateTime LastSeenUtc { get; set; }
public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive;
public override string ConnectionId { get; }
public override IFeatureCollection Features { get; } = new FeatureCollection();
@ -23,6 +45,86 @@ namespace Microsoft.AspNetCore.Sockets
public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
public IChannelConnection<Message> Application { get; }
public override IChannelConnection<Message> Transport { get; set; }
public async Task DisposeAsync()
{
Task disposeTask = Task.CompletedTask;
try
{
await Lock.WaitAsync();
if (Status == ConnectionStatus.Disposed)
{
disposeTask = _disposeTcs.Task;
}
else
{
Status = ConnectionStatus.Disposed;
RequestId = null;
// If the application task is faulted, propagate the error to the transport
if (ApplicationTask?.IsFaulted == true)
{
Transport.Output.TryComplete(ApplicationTask.Exception.InnerException);
}
// If the transport task is faulted, propagate the error to the application
if (TransportTask?.IsFaulted == true)
{
Application.Output.TryComplete(TransportTask.Exception.InnerException);
}
Transport.Dispose();
Application.Dispose();
var applicationTask = ApplicationTask ?? Task.CompletedTask;
var transportTask = TransportTask ?? Task.CompletedTask;
disposeTask = WaitOnTasks(applicationTask, transportTask);
}
}
finally
{
Lock.Release();
}
await disposeTask;
}
private async Task WaitOnTasks(Task applicationTask, Task transportTask)
{
try
{
await Task.WhenAll(applicationTask, transportTask);
// Notify all waiters that we're done disposing
_disposeTcs.TrySetResult(null);
}
catch (OperationCanceledException)
{
_disposeTcs.TrySetCanceled();
throw;
}
catch (Exception ex)
{
_disposeTcs.TrySetException(ex);
throw;
}
}
public enum ConnectionStatus
{
Inactive,
Active,
Disposed
}
}
}

View File

@ -70,57 +70,57 @@ namespace Microsoft.AspNetCore.Sockets
if (headers.Accept?.Contains(new Net.Http.Headers.MediaTypeHeaderValue("text/event-stream")) == true)
{
// Connection must already exist
var state = await GetConnectionAsync(context);
if (state == null)
var connection = await GetConnectionAsync(context);
if (connection == null)
{
// No such connection, GetConnection already set the response status code
return;
}
if (!await EnsureConnectionStateAsync(state, context, TransportType.ServerSentEvents, supportedTransports))
if (!await EnsureConnectionStateAsync(connection, context, TransportType.ServerSentEvents, supportedTransports))
{
// Bad connection state. It's already set the response status code.
return;
}
// We only need to provide the Input channel since writing to the application is handled through /send.
var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory);
var sse = new ServerSentEventsTransport(connection.Application.Input, _loggerFactory);
await DoPersistentConnection(socketDelegate, sse, context, state);
await DoPersistentConnection(socketDelegate, sse, context, connection);
}
else if (context.WebSockets.IsWebSocketRequest)
{
// Connection can be established lazily
var state = await GetOrCreateConnectionAsync(context);
if (state == null)
var connection = await GetOrCreateConnectionAsync(context);
if (connection == null)
{
// No such connection, GetOrCreateConnection already set the response status code
return;
}
if (!await EnsureConnectionStateAsync(state, context, TransportType.WebSockets, supportedTransports))
if (!await EnsureConnectionStateAsync(connection, context, TransportType.WebSockets, supportedTransports))
{
// Bad connection state. It's already set the response status code.
return;
}
var ws = new WebSocketsTransport(options.WebSockets, state.Application, _loggerFactory);
var ws = new WebSocketsTransport(options.WebSockets, connection.Application, _loggerFactory);
await DoPersistentConnection(socketDelegate, ws, context, state);
await DoPersistentConnection(socketDelegate, ws, context, connection);
}
else
{
// GET /{path} maps to long polling
// Connection must already exist
var state = await GetConnectionAsync(context);
if (state == null)
var connection = await GetConnectionAsync(context);
if (connection == null)
{
// No such connection, GetConnection already set the response status code
return;
}
if (!await EnsureConnectionStateAsync(state, context, TransportType.LongPolling, supportedTransports))
if (!await EnsureConnectionStateAsync(connection, context, TransportType.LongPolling, supportedTransports))
{
// Bad connection state. It's already set the response status code.
return;
@ -128,94 +128,94 @@ namespace Microsoft.AspNetCore.Sockets
try
{
await state.Lock.WaitAsync();
await connection.Lock.WaitAsync();
if (state.Status == ConnectionState.ConnectionStatus.Disposed)
if (connection.Status == DefaultConnectionContext.ConnectionStatus.Disposed)
{
_logger.LogDebug("Connection {connectionId} was disposed,", state.Connection.ConnectionId);
_logger.LogDebug("Connection {connectionId} was disposed,", connection.ConnectionId);
// The connection was disposed
context.Response.StatusCode = StatusCodes.Status404NotFound;
return;
}
if (state.Status == ConnectionState.ConnectionStatus.Active)
if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active)
{
_logger.LogDebug("Connection {connectionId} is already active via {requestId}. Cancelling previous request.", state.Connection.ConnectionId, state.RequestId);
_logger.LogDebug("Connection {connectionId} is already active via {requestId}. Cancelling previous request.", connection.ConnectionId, connection.RequestId);
using (state.Cancellation)
using (connection.Cancellation)
{
// Cancel the previous request
state.Cancellation.Cancel();
connection.Cancellation.Cancel();
try
{
// Wait for the previous request to drain
await state.TransportTask;
await connection.TransportTask;
}
catch (OperationCanceledException)
{
// Should be a cancelled task
}
_logger.LogDebug("Previous poll cancelled for {connectionId} on {requestId}.", state.Connection.ConnectionId, state.RequestId);
_logger.LogDebug("Previous poll cancelled for {connectionId} on {requestId}.", connection.ConnectionId, connection.RequestId);
}
}
// Mark the request identifier
state.RequestId = context.TraceIdentifier;
connection.RequestId = context.TraceIdentifier;
// Mark the connection as active
state.Status = ConnectionState.ConnectionStatus.Active;
connection.Status = DefaultConnectionContext.ConnectionStatus.Active;
// Raise OnConnected for new connections only since polls happen all the time
if (state.ApplicationTask == null)
if (connection.ApplicationTask == null)
{
_logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
_logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", connection.ConnectionId, connection.RequestId);
state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling;
connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling;
state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection);
connection.ApplicationTask = ExecuteApplication(socketDelegate, connection);
}
else
{
_logger.LogDebug("Resuming existing connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
_logger.LogDebug("Resuming existing connection: {connectionId} on {requestId}", connection.ConnectionId, connection.RequestId);
}
var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory);
var longPolling = new LongPollingTransport(connection.Application.Input, _loggerFactory);
state.Cancellation = new CancellationTokenSource();
connection.Cancellation = new CancellationTokenSource();
// REVIEW: Performance of this isn't great as this does a bunch of per request allocations
var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(state.Cancellation.Token, context.RequestAborted);
var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted);
// Start the transport
state.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token);
connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token);
}
finally
{
state.Lock.Release();
connection.Lock.Release();
}
var resultTask = await Task.WhenAny(state.ApplicationTask, state.TransportTask);
var resultTask = await Task.WhenAny(connection.ApplicationTask, connection.TransportTask);
var pollAgain = true;
// If the application ended before the transport task then we need to potentially need to end the
// connection
if (resultTask == state.ApplicationTask)
if (resultTask == connection.ApplicationTask)
{
// Complete the transport (notifying it of the application error if there is one)
state.Connection.Transport.Output.TryComplete(state.ApplicationTask.Exception);
connection.Transport.Output.TryComplete(connection.ApplicationTask.Exception);
// Wait for the transport to run
await state.TransportTask;
await connection.TransportTask;
// If the status code is a 204 it means we didn't write anything
if (context.Response.StatusCode == StatusCodes.Status204NoContent)
{
// We should be able to safely dispose because there's no more data being written
await _manager.DisposeAndRemoveAsync(state);
await _manager.DisposeAndRemoveAsync(connection);
// Don't poll again if we've removed the connection completely
pollAgain = false;
@ -232,53 +232,53 @@ namespace Microsoft.AspNetCore.Sockets
// Otherwise, we update the state to inactive again and wait for the next poll
try
{
await state.Lock.WaitAsync();
await connection.Lock.WaitAsync();
if (state.Status == ConnectionState.ConnectionStatus.Active)
if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active)
{
// Mark the connection as inactive
state.LastSeenUtc = DateTime.UtcNow;
connection.LastSeenUtc = DateTime.UtcNow;
state.Status = ConnectionState.ConnectionStatus.Inactive;
connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive;
state.RequestId = null;
connection.RequestId = null;
// Dispose the cancellation token
state.Cancellation.Dispose();
connection.Cancellation.Dispose();
state.Cancellation = null;
connection.Cancellation = null;
}
}
finally
{
state.Lock.Release();
connection.Lock.Release();
}
}
}
}
private ConnectionState CreateConnection(HttpContext context)
private DefaultConnectionContext CreateConnection(HttpContext context)
{
var state = _manager.CreateConnection();
var connection = _manager.CreateConnection();
var format = (string)context.Request.Query[ConnectionMetadataNames.Format];
state.Connection.User = context.User;
state.Connection.Metadata[ConnectionMetadataNames.HttpContext] = context;
state.Connection.Metadata[ConnectionMetadataNames.Format] = string.IsNullOrEmpty(format) ? "json" : format;
return state;
connection.User = context.User;
connection.Metadata[ConnectionMetadataNames.HttpContext] = context;
connection.Metadata[ConnectionMetadataNames.Format] = string.IsNullOrEmpty(format) ? "json" : format;
return connection;
}
private async Task DoPersistentConnection(SocketDelegate socketDelegate,
IHttpTransport transport,
HttpContext context,
ConnectionState state)
DefaultConnectionContext connection)
{
try
{
await state.Lock.WaitAsync();
await connection.Lock.WaitAsync();
if (state.Status == ConnectionState.ConnectionStatus.Disposed)
if (connection.Status == DefaultConnectionContext.ConnectionStatus.Disposed)
{
_logger.LogDebug("Connection {connectionId} was disposed,", state.Connection.ConnectionId);
_logger.LogDebug("Connection {connectionId} was disposed,", connection.ConnectionId);
// Connection was disposed
context.Response.StatusCode = StatusCodes.Status404NotFound;
@ -286,9 +286,9 @@ namespace Microsoft.AspNetCore.Sockets
}
// There's already an active request
if (state.Status == ConnectionState.ConnectionStatus.Active)
if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active)
{
_logger.LogDebug("Connection {connectionId} is already active via {requestId}.", state.Connection.ConnectionId, state.RequestId);
_logger.LogDebug("Connection {connectionId} is already active via {requestId}.", connection.ConnectionId, connection.RequestId);
// Reject the request with a 409 conflict
context.Response.StatusCode = StatusCodes.Status409Conflict;
@ -296,26 +296,26 @@ namespace Microsoft.AspNetCore.Sockets
}
// Mark the connection as active
state.Status = ConnectionState.ConnectionStatus.Active;
connection.Status = DefaultConnectionContext.ConnectionStatus.Active;
// Store the request identifier
state.RequestId = context.TraceIdentifier;
connection.RequestId = context.TraceIdentifier;
// Call into the end point passing the connection
state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection);
connection.ApplicationTask = ExecuteApplication(socketDelegate, connection);
// Start the transport
state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted);
connection.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted);
}
finally
{
state.Lock.Release();
connection.Lock.Release();
}
// Wait for any of them to end
await Task.WhenAny(state.ApplicationTask, state.TransportTask);
await Task.WhenAny(connection.ApplicationTask, connection.TransportTask);
await _manager.DisposeAndRemoveAsync(state);
await _manager.DisposeAndRemoveAsync(connection);
}
private async Task ExecuteApplication(SocketDelegate socketDelegate, ConnectionContext connection)
@ -336,10 +336,10 @@ namespace Microsoft.AspNetCore.Sockets
context.Response.ContentType = "text/plain";
// Establish the connection
var state = CreateConnection(context);
var connection = CreateConnection(context);
// Get the bytes for the connection id
var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId);
var connectionIdBuffer = Encoding.UTF8.GetBytes(connection.ConnectionId);
// Write it out to the response with the right content length
context.Response.ContentLength = connectionIdBuffer.Length;
@ -348,8 +348,8 @@ namespace Microsoft.AspNetCore.Sockets
private async Task ProcessSend(HttpContext context)
{
var state = await GetConnectionAsync(context);
if (state == null)
var connection = await GetConnectionAsync(context);
if (connection == null)
{
// No such connection, GetConnection already set the response status code
return;
@ -388,9 +388,9 @@ namespace Microsoft.AspNetCore.Sockets
_logger.LogDebug("Received batch of {count} message(s)", messages.Count);
foreach (var message in messages)
{
while (!state.Application.Output.TryWrite(message))
while (!connection.Application.Output.TryWrite(message))
{
if (!await state.Application.Output.WaitToWriteAsync())
if (!await connection.Application.Output.WaitToWriteAsync())
{
return;
}
@ -398,7 +398,7 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private async Task<bool> EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, TransportType transportType, TransportType supportedTransports)
private async Task<bool> EnsureConnectionStateAsync(DefaultConnectionContext connection, HttpContext context, TransportType transportType, TransportType supportedTransports)
{
if ((supportedTransports & transportType) == 0)
{
@ -407,13 +407,13 @@ namespace Microsoft.AspNetCore.Sockets
return false;
}
connectionState.Connection.User = context.User;
connection.User = context.User;
var transport = connectionState.Connection.Metadata.Get<TransportType?>(ConnectionMetadataNames.Transport);
var transport = connection.Metadata.Get<TransportType?>(ConnectionMetadataNames.Transport);
if (transport == null)
{
connectionState.Connection.Metadata[ConnectionMetadataNames.Transport] = transportType;
connection.Metadata[ConnectionMetadataNames.Transport] = transportType;
}
else if (transport != transportType)
{
@ -424,10 +424,9 @@ namespace Microsoft.AspNetCore.Sockets
return true;
}
private async Task<ConnectionState> GetConnectionAsync(HttpContext context)
private async Task<DefaultConnectionContext> GetConnectionAsync(HttpContext context)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
if (StringValues.IsNullOrEmpty(connectionId))
{
@ -437,7 +436,7 @@ namespace Microsoft.AspNetCore.Sockets
return null;
}
if (!_manager.TryGetConnection(connectionId, out connectionState))
if (!_manager.TryGetConnection(connectionId, out var connection))
{
// No connection with that ID: Not Found
context.Response.StatusCode = StatusCodes.Status404NotFound;
@ -445,20 +444,20 @@ namespace Microsoft.AspNetCore.Sockets
return null;
}
return connectionState;
return connection;
}
private async Task<ConnectionState> GetOrCreateConnectionAsync(HttpContext context)
private async Task<DefaultConnectionContext> GetOrCreateConnectionAsync(HttpContext context)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
DefaultConnectionContext connection;
// There's no connection id so this is a brand new connection
if (StringValues.IsNullOrEmpty(connectionId))
{
connectionState = CreateConnection(context);
connection = CreateConnection(context);
}
else if (!_manager.TryGetConnection(connectionId, out connectionState))
else if (!_manager.TryGetConnection(connectionId, out connection))
{
// No connection with that ID: Not Found
context.Response.StatusCode = StatusCodes.Status404NotFound;
@ -466,7 +465,7 @@ namespace Microsoft.AspNetCore.Sockets
return null;
}
return connectionState;
return connection;
}
private List<Message> ParseSendBatch(ref BytesReader payload, MessageFormat messageFormat)

View File

@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionManager
{
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new ConcurrentDictionary<string, ConnectionState>();
private readonly ConcurrentDictionary<string, DefaultConnectionContext> _connections = new ConcurrentDictionary<string, DefaultConnectionContext>();
private Timer _timer;
private readonly ILogger<ConnectionManager> _logger;
private object _executionLock = new object();
@ -41,12 +41,12 @@ namespace Microsoft.AspNetCore.Sockets
}
}
public bool TryGetConnection(string id, out ConnectionState state)
public bool TryGetConnection(string id, out DefaultConnectionContext connection)
{
return _connections.TryGetValue(id, out state);
return _connections.TryGetValue(id, out connection);
}
public ConnectionState CreateConnection()
public DefaultConnectionContext CreateConnection()
{
var id = MakeNewConnectionId();
@ -56,12 +56,10 @@ namespace Microsoft.AspNetCore.Sockets
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var state = new ConnectionState(
new DefaultConnectionContext(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
var connection = new DefaultConnectionContext(id, applicationSide, transportSide);
_connections.TryAdd(id, connection);
return connection;
}
public void RemoveConnection(string id)
@ -108,7 +106,7 @@ namespace Microsoft.AspNetCore.Sockets
// Scan the registered connections looking for ones that have timed out
foreach (var c in _connections)
{
var status = ConnectionState.ConnectionStatus.Inactive;
var status = DefaultConnectionContext.ConnectionStatus.Inactive;
var lastSeenUtc = DateTimeOffset.UtcNow;
try
@ -126,7 +124,7 @@ namespace Microsoft.AspNetCore.Sockets
}
// Once the decision has been made to to dispose we don't check the status again
if (status == ConnectionState.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5)
if (status == DefaultConnectionContext.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5)
{
var ignore = DisposeAndRemoveAsync(c.Value);
}
@ -167,21 +165,21 @@ namespace Microsoft.AspNetCore.Sockets
}
}
public async Task DisposeAndRemoveAsync(ConnectionState state)
public async Task DisposeAndRemoveAsync(DefaultConnectionContext connection)
{
try
{
await state.DisposeAsync();
await connection.DisposeAsync();
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed disposing connection {connectionId}", state.Connection.ConnectionId);
_logger.LogError(0, ex, "Failed disposing connection {connectionId}", connection.ConnectionId);
}
finally
{
// Remove it from the list after disposal so that's it's easy to see
// connections that might be in a hung state via the connections list
RemoveConnection(state.Connection.ConnectionId);
RemoveConnection(connection.ConnectionId);
}
}
}

View File

@ -1,116 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class ConnectionState
{
// This tcs exists so that multiple calls to DisposeAsync all wait asynchronously
// on the same task
private TaskCompletionSource<object> _disposeTcs = new TaskCompletionSource<object>();
public ConnectionContext Connection { get; set; }
public IChannelConnection<Message> Application { get; }
public CancellationTokenSource Cancellation { get; set; }
public SemaphoreSlim Lock { get; } = new SemaphoreSlim(1, 1);
public string RequestId { get; set; }
public Task TransportTask { get; set; }
public Task ApplicationTask { get; set; }
public DateTime LastSeenUtc { get; set; }
public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive;
public ConnectionState(ConnectionContext connection, IChannelConnection<Message> application)
{
Connection = connection;
Application = application;
LastSeenUtc = DateTime.UtcNow;
}
public async Task DisposeAsync()
{
Task disposeTask = Task.CompletedTask;
try
{
await Lock.WaitAsync();
if (Status == ConnectionStatus.Disposed)
{
disposeTask = _disposeTcs.Task;
}
else
{
Status = ConnectionStatus.Disposed;
RequestId = null;
// If the application task is faulted, propagate the error to the transport
if (ApplicationTask?.IsFaulted == true)
{
Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException);
}
// If the transport task is faulted, propagate the error to the application
if (TransportTask?.IsFaulted == true)
{
Application.Output.TryComplete(TransportTask.Exception.InnerException);
}
Connection.Dispose();
Application.Dispose();
var applicationTask = ApplicationTask ?? Task.CompletedTask;
var transportTask = TransportTask ?? Task.CompletedTask;
disposeTask = WaitOnTasks(applicationTask, transportTask);
}
}
finally
{
Lock.Release();
}
await disposeTask;
}
private async Task WaitOnTasks(Task applicationTask, Task transportTask)
{
try
{
await Task.WhenAll(applicationTask, transportTask);
// Notify all waiters that we're done disposing
_disposeTcs.TrySetResult(null);
}
catch (OperationCanceledException)
{
_disposeTcs.TrySetCanceled();
throw;
}
catch (Exception ex)
{
_disposeTcs.TrySetException(ex);
throw;
}
}
public enum ConnectionStatus
{
Inactive,
Active,
Disposed
}
}
}

View File

@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
private IHubProtocol _protocol;
private CancellationTokenSource _cts;
public ConnectionContext Connection { get; }
public DefaultConnectionContext Connection { get; }
public IChannelConnection<Message> Application { get; }
public Task Connected => Connection.Metadata.Get<TaskCompletionSource<bool>>("ConnectedTask").Task;
@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Application = ChannelConnection.Create<Message>(input: applicationToTransport, output: transportToApplication);
var transport = ChannelConnection.Create<Message>(input: transportToApplication, output: applicationToTransport);
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport);
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport, Application);
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();
@ -147,7 +147,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public void Dispose()
{
_cts.Cancel();
Connection.Dispose();
Connection.Transport.Dispose();
}
private static string GetInvocationId()

View File

@ -4,7 +4,6 @@
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
using Xunit;
@ -16,104 +15,97 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public void NewConnectionsHaveConnectionId()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.Equal(ConnectionState.ConnectionStatus.Inactive, state.Status);
Assert.Null(state.ApplicationTask);
Assert.Null(state.TransportTask);
Assert.Null(state.Cancellation);
Assert.Null(state.RequestId);
Assert.NotNull(state.Connection.Transport);
Assert.NotNull(connection.ConnectionId);
Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status);
Assert.Null(connection.ApplicationTask);
Assert.Null(connection.TransportTask);
Assert.Null(connection.Cancellation);
Assert.Null(connection.RequestId);
Assert.NotNull(connection.Transport);
}
[Fact]
public void NewConnectionsCanBeRetrieved()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(connection.ConnectionId);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection));
Assert.Same(newConnection, connection);
}
[Fact]
public void AddNewConnection()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
var transport = state.Connection.Transport;
var transport = connection.Transport;
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(connection.ConnectionId);
Assert.NotNull(transport);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.Same(transport, newState.Connection.Transport);
Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection));
Assert.Same(newConnection, connection);
Assert.Same(transport, newConnection.Transport);
}
[Fact]
public void RemoveConnection()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
var transport = state.Connection.Transport;
var transport = connection.Transport;
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(connection.ConnectionId);
Assert.NotNull(transport);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.Same(transport, newState.Connection.Transport);
Assert.True(connectionManager.TryGetConnection(connection.ConnectionId, out var newConnection));
Assert.Same(newConnection, connection);
Assert.Same(transport, newConnection.Transport);
connectionManager.RemoveConnection(state.Connection.ConnectionId);
Assert.False(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
connectionManager.RemoveConnection(connection.ConnectionId);
Assert.False(connectionManager.TryGetConnection(connection.ConnectionId, out newConnection));
}
[Fact]
public async Task CloseConnectionsEndsAllPendingConnections()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
state.ApplicationTask = Task.Run(async () =>
connection.ApplicationTask = Task.Run(async () =>
{
Assert.False(await state.Connection.Transport.Input.WaitToReadAsync());
Assert.False(await connection.Transport.Input.WaitToReadAsync());
});
state.TransportTask = Task.Run(async () =>
connection.TransportTask = Task.Run(async () =>
{
Assert.False(await state.Application.Input.WaitToReadAsync());
Assert.False(await connection.Application.Input.WaitToReadAsync());
});
connectionManager.CloseConnections();
await state.DisposeAsync();
await connection.DisposeAsync();
}
[Fact]
public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
state.ApplicationTask = tcs.Task;
state.TransportTask = tcs.Task;
connection.ApplicationTask = tcs.Task;
connection.TransportTask = tcs.Task;
var firstTask = state.DisposeAsync();
var secondTask = state.DisposeAsync();
var firstTask = connection.DisposeAsync();
var secondTask = connection.DisposeAsync();
Assert.False(firstTask.IsCompleted);
Assert.False(secondTask.IsCompleted);
@ -126,14 +118,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task DisposingConnectionMultipleGetsExceptionFromTransportOrApp()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
state.ApplicationTask = tcs.Task;
state.TransportTask = tcs.Task;
connection.ApplicationTask = tcs.Task;
connection.TransportTask = tcs.Task;
var firstTask = state.DisposeAsync();
var secondTask = state.DisposeAsync();
var firstTask = connection.DisposeAsync();
var secondTask = connection.DisposeAsync();
Assert.False(firstTask.IsCompleted);
Assert.False(secondTask.IsCompleted);
@ -150,14 +142,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task DisposingConnectionMultipleGetsCancellation()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
state.ApplicationTask = tcs.Task;
state.TransportTask = tcs.Task;
connection.ApplicationTask = tcs.Task;
connection.TransportTask = tcs.Task;
var firstTask = state.DisposeAsync();
var secondTask = state.DisposeAsync();
var firstTask = connection.DisposeAsync();
var secondTask = connection.DisposeAsync();
Assert.False(firstTask.IsCompleted);
Assert.False(secondTask.IsCompleted);
@ -171,21 +163,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task DisposeInactiveConnection()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();;
var connection = connectionManager.CreateConnection();;
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(state.Connection.Transport);
Assert.NotNull(connection.ConnectionId);
Assert.NotNull(connection.Transport);
await state.DisposeAsync();
Assert.Equal(ConnectionState.ConnectionStatus.Disposed, state.Status);
await connection.DisposeAsync();
Assert.Equal(DefaultConnectionContext.ConnectionStatus.Disposed, connection.Status);
}
[Fact]
public void ScanAfterDisposeNoops()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var connection = connectionManager.CreateConnection();
connectionManager.CloseConnections();

View File

@ -14,7 +14,6 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
@ -48,9 +47,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var id = Encoding.UTF8.GetString(ms.ToArray());
ConnectionState state;
Assert.True(manager.TryGetConnection(id, out state));
Assert.Equal(id, state.Connection.ConnectionId);
Assert.True(manager.TryGetConnection(id, out var connection));
Assert.Equal(id, connection.ConnectionId);
}
[Theory]
@ -182,7 +180,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task SendRequestsWithInvalidContentTypeAreRejected()
{
var manager = CreateConnectionManager();
var connectionState = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
using (var strm = new MemoryStream())
{
@ -192,7 +190,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddEndPoint<TestEndPoint>();
context.Request.Path = "/foo";
context.Request.Method = "POST";
context.Request.QueryString = new QueryString($"?id={connectionState.Connection.ConnectionId}");
context.Request.QueryString = new QueryString($"?id={connection.ConnectionId}");
context.Request.ContentType = "text/plain";
context.Response.Body = strm;
@ -245,11 +243,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task CompletedEndPointEndsConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest("/foo", connection);
SetTransport(context, TransportType.ServerSentEvents);
var services = new ServiceCollection();
@ -261,8 +259,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
bool exists = manager.TryGetConnection(connection.ConnectionId, out _);
Assert.False(exists);
}
@ -270,10 +267,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task SynchronusExceptionEndsConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest("/foo", connection);
SetTransport(context, TransportType.ServerSentEvents);
var services = new ServiceCollection();
@ -285,8 +282,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
bool exists = manager.TryGetConnection(connection.ConnectionId, out _);
Assert.False(exists);
}
@ -294,11 +290,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task CompletedEndPointEndsLongPollingConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
@ -309,8 +305,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
bool exists = manager.TryGetConnection(connection.ConnectionId, out _);
Assert.False(exists);
}
@ -318,11 +313,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest("/foo", connection);
SetTransport(context, TransportType.WebSockets);
var services = new ServiceCollection();
@ -344,12 +339,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task RequestToActiveConnectionId409ForStreamingTransports(TransportType transportType)
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest("/foo", state);
var context2 = MakeRequest("/foo", state);
var context1 = MakeRequest("/foo", connection);
var context2 = MakeRequest("/foo", connection);
SetTransport(context1, transportType);
SetTransport(context2, transportType);
@ -383,12 +378,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task RequestToActiveConnectionIdKillsPreviousConnectionLongPolling()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest("/foo", state);
var context2 = MakeRequest("/foo", state);
var context1 = MakeRequest("/foo", connection);
var context2 = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
@ -402,7 +397,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
await request1;
Assert.Equal(StatusCodes.Status204NoContent, context1.Response.StatusCode);
Assert.Equal(ConnectionState.ConnectionStatus.Active, state.Status);
Assert.Equal(DefaultConnectionContext.ConnectionStatus.Active, connection.Status);
Assert.False(request2.IsCompleted);
@ -417,12 +412,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task RequestToDisposedConnectionIdReturns404(TransportType transportType)
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
state.Status = ConnectionState.ConnectionStatus.Disposed;
var connection = manager.CreateConnection();
connection.Status = DefaultConnectionContext.ConnectionStatus.Disposed;
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest("/foo", connection);
SetTransport(context, transportType);
var services = new ServiceCollection();
@ -441,11 +436,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task ConnectionStateSetToInactiveAfterPoll()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
@ -458,12 +453,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var buffer = Encoding.UTF8.GetBytes("Hello World");
// Write to the transport so the poll yields
await state.Connection.Transport.Output.WriteAsync(new Message(buffer, MessageType.Text));
await connection.Transport.Output.WriteAsync(new Message(buffer, MessageType.Text));
await task;
Assert.Equal(ConnectionState.ConnectionStatus.Inactive, state.Status);
Assert.Null(state.RequestId);
Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status);
Assert.Null(connection.RequestId);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
@ -472,11 +467,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task BlockingConnectionWorksWithStreamingConnections()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest("/foo", connection);
SetTransport(context, TransportType.ServerSentEvents);
var services = new ServiceCollection();
@ -490,13 +485,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var buffer = Encoding.UTF8.GetBytes("Hello World");
// Write to the application
await state.Application.Output.WriteAsync(new Message(buffer, MessageType.Text));
await connection.Application.Output.WriteAsync(new Message(buffer, MessageType.Text));
await task;
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
bool exists = manager.TryGetConnection(connection.ConnectionId, out _);
Assert.False(exists);
}
@ -504,11 +498,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task BlockingConnectionWorksWithLongPollingConnection()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddEndPoint<BlockingEndPoint>();
@ -521,13 +515,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var buffer = Encoding.UTF8.GetBytes("Hello World");
// Write to the application
await state.Application.Output.WriteAsync(new Message(buffer, MessageType.Text));
await connection.Application.Output.WriteAsync(new Message(buffer, MessageType.Text));
await task;
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
ConnectionState removed;
bool exists = manager.TryGetConnection(state.Connection.ConnectionId, out removed);
bool exists = manager.TryGetConnection(connection.ConnectionId, out _);
Assert.False(exists);
}
@ -535,7 +528,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task AttemptingToPollWhileAlreadyPollingReplacesTheCurrentPoll()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
@ -546,16 +539,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var app = builder.Build();
var options = new HttpSocketOptions();
var context1 = MakeRequest("/foo", state);
var context1 = MakeRequest("/foo", connection);
var task1 = dispatcher.ExecuteAsync(context1, options, app);
var context2 = MakeRequest("/foo", state);
var context2 = MakeRequest("/foo", connection);
var task2 = dispatcher.ExecuteAsync(context2, options, app);
// Task 1 should finish when request 2 arrives
await task1.OrTimeout();
// Send a message from the app to complete Task 2
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text));
await connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text));
await task2.OrTimeout();
@ -606,7 +599,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task UnauthorizedConnectionFailsToStartEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
@ -624,7 +617,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
@ -644,7 +637,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task AuthenticatedUserWithoutPermissionCausesForbidden()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
@ -662,7 +655,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
@ -684,7 +677,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task AuthorizedConnectionCanConnectToEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
@ -705,7 +698,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
@ -720,7 +713,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -733,7 +726,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
@ -755,7 +748,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
@ -770,7 +763,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -782,7 +775,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task AuthorizedConnectionWithRejectedSchemesFailsToConnectToEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
@ -804,7 +797,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
@ -881,7 +874,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
private static async Task CheckTransportSupported(TransportType supportedTransports, TransportType transportType, int status)
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
using (var strm = new MemoryStream())
{
@ -894,7 +887,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
@ -919,11 +912,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
private static async Task<List<Message>> RunSendTest(string contentType, string encoded, string format)
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state, format);
var context = MakeRequest("/foo", connection, format);
context.Request.Method = "POST";
context.Request.ContentType = contentType;
@ -942,7 +935,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app).OrTimeout();
}
while (state.Connection.Transport.Input.TryRead(out var message))
while (connection.Transport.Input.TryRead(out var message))
{
messages.Add(message);
}
@ -950,13 +943,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
return messages;
}
private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null)
private static DefaultHttpContext MakeRequest(string path, DefaultConnectionContext connection, string format = null)
{
var context = new DefaultHttpContext();
context.Request.Path = path;
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
values["id"] = connection.ConnectionId;
if (format != null)
{
values["format"] = format;