diff --git a/src/Microsoft.AspNetCore.Http.Connections/ConnectionsAppBuilderExtensions.cs b/src/Microsoft.AspNetCore.Http.Connections/ConnectionsAppBuilderExtensions.cs index f18713b37e..a2219f0649 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/ConnectionsAppBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/ConnectionsAppBuilderExtensions.cs @@ -3,6 +3,7 @@ using System; using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; diff --git a/src/Microsoft.AspNetCore.Http.Connections/ConnectionsDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Http.Connections/ConnectionsDependencyInjectionExtensions.cs index 6c57058cf1..3bf514ed99 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/ConnectionsDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/ConnectionsDependencyInjectionExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.Extensions.DependencyInjection.Extensions; namespace Microsoft.Extensions.DependencyInjection diff --git a/src/Microsoft.AspNetCore.Http.Connections/ConnectionsRouteBuilder.cs b/src/Microsoft.AspNetCore.Http.Connections/ConnectionsRouteBuilder.cs index f2f15e4c09..2836835f3c 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/ConnectionsRouteBuilder.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/ConnectionsRouteBuilder.cs @@ -6,6 +6,7 @@ using System.Reflection; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.Routing; namespace Microsoft.AspNetCore.Http.Connections @@ -21,13 +22,10 @@ namespace Microsoft.AspNetCore.Http.Connections _dispatcher = dispatcher; } - public void MapConnections(string path, Action configure) => - MapConnections(new PathString(path), new HttpConnectionOptions(), configure); - public void MapConnections(PathString path, Action configure) => - MapConnections(path, new HttpConnectionOptions(), configure); + MapConnections(path, new HttpConnectionDispatcherOptions(), configure); - public void MapConnections(PathString path, HttpConnectionOptions options, Action configure) + public void MapConnections(PathString path, HttpConnectionDispatcherOptions options, Action configure) { var connectionBuilder = new ConnectionBuilder(_routes.ServiceProvider); configure(connectionBuilder); @@ -36,20 +34,15 @@ namespace Microsoft.AspNetCore.Http.Connections _routes.MapRoute(path + "/negotiate", c => _dispatcher.ExecuteNegotiateAsync(c, options)); } - public void MapConnectionHandler(string path) where TConnectionHandler : ConnectionHandler - { - MapConnectionHandler(new PathString(path), configureOptions: null); - } - public void MapConnectionHandler(PathString path) where TConnectionHandler : ConnectionHandler { MapConnectionHandler(path, configureOptions: null); } - public void MapConnectionHandler(PathString path, Action configureOptions) where TConnectionHandler : ConnectionHandler + public void MapConnectionHandler(PathString path, Action configureOptions) where TConnectionHandler : ConnectionHandler { var authorizeAttributes = typeof(TConnectionHandler).GetCustomAttributes(inherit: true); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); foreach (var attribute in authorizeAttributes) { options.AuthorizationData.Add(attribute); diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionOptions.cs b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcherOptions.cs similarity index 92% rename from src/Microsoft.AspNetCore.Http.Connections/HttpConnectionOptions.cs rename to src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcherOptions.cs index 62a88cd6ca..e73cff273f 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionOptions.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcherOptions.cs @@ -7,13 +7,13 @@ using Microsoft.AspNetCore.Http.Connections.Internal; namespace Microsoft.AspNetCore.Http.Connections { - public class HttpConnectionOptions + public class HttpConnectionDispatcherOptions { // Selected because this is the default value of PipeWriter.PauseWriterThreshold. // There maybe the opportunity for performance gains by tuning this default. private const int DefaultPipeBufferSize = 32768; - public HttpConnectionOptions() + public HttpConnectionDispatcherOptions() { AuthorizationData = new List(); Transports = HttpTransports.All; diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionContext.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionContext.cs similarity index 97% rename from src/Microsoft.AspNetCore.Http.Connections/HttpConnectionContext.cs rename to src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionContext.cs index c9ea9cc850..86d5292e0d 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionContext.cs @@ -8,14 +8,13 @@ using System.IO.Pipelines; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Features; -using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Http.Connections +namespace Microsoft.AspNetCore.Http.Connections.Internal { public class HttpConnectionContext : ConnectionContext, IConnectionIdFeature, @@ -88,7 +87,7 @@ namespace Microsoft.AspNetCore.Http.Connections public DateTime LastSeenUtc { get; set; } - public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive; + public HttpConnectionStatus Status { get; set; } = HttpConnectionStatus.Inactive; public override string ConnectionId { get; set; } @@ -161,13 +160,13 @@ namespace Microsoft.AspNetCore.Http.Connections { await Lock.WaitAsync(); - if (Status == ConnectionStatus.Disposed) + if (Status == HttpConnectionStatus.Disposed) { disposeTask = _disposeTcs.Task; } else { - Status = ConnectionStatus.Disposed; + Status = HttpConnectionStatus.Disposed; Log.DisposingConnection(_logger, ConnectionId); @@ -282,13 +281,6 @@ namespace Microsoft.AspNetCore.Http.Connections } } - public enum ConnectionStatus - { - Inactive, - Active, - Disposed - } - private static class Log { private static readonly Action _disposingConnection = diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.Log.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.Log.cs similarity index 99% rename from src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.Log.cs rename to src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.Log.cs index 00d78bee4e..6b16d8533a 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.Log.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.Log.cs @@ -4,7 +4,7 @@ using System; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Http.Connections +namespace Microsoft.AspNetCore.Http.Connections.Internal { public partial class HttpConnectionDispatcher { diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs similarity index 95% rename from src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs rename to src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs index 3327fd925b..0d3717ee41 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs @@ -12,14 +12,13 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Features; -using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.Http.Connections.Internal.Transports; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; -namespace Microsoft.AspNetCore.Http.Connections +namespace Microsoft.AspNetCore.Http.Connections.Internal { public partial class HttpConnectionDispatcher { @@ -55,7 +54,7 @@ namespace Microsoft.AspNetCore.Http.Connections _logger = _loggerFactory.CreateLogger(); } - public async Task ExecuteAsync(HttpContext context, HttpConnectionOptions options, ConnectionDelegate connectionDelegate) + public async Task ExecuteAsync(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionDelegate connectionDelegate) { // Create the log scope and attempt to pass the Connection ID to it so as many logs as possible contain // the Connection ID metadata. If this is the negotiate request then the Connection ID for the scope will @@ -91,7 +90,7 @@ namespace Microsoft.AspNetCore.Http.Connections } } - public async Task ExecuteNegotiateAsync(HttpContext context, HttpConnectionOptions options) + public async Task ExecuteNegotiateAsync(HttpContext context, HttpConnectionDispatcherOptions options) { // Create the log scope and the scope connectionId param will be set when the connection is created. var logScope = new ConnectionLogScope(connectionId: string.Empty); @@ -115,7 +114,7 @@ namespace Microsoft.AspNetCore.Http.Connections } } - private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connectionDelegate, HttpConnectionOptions options, ConnectionLogScope logScope) + private async Task ExecuteAsync(HttpContext context, ConnectionDelegate connectionDelegate, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope) { var supportedTransports = options.Transports; @@ -193,7 +192,7 @@ namespace Microsoft.AspNetCore.Http.Connections { await connection.Lock.WaitAsync(); - if (connection.Status == HttpConnectionContext.ConnectionStatus.Disposed) + if (connection.Status == HttpConnectionStatus.Disposed) { Log.ConnectionDisposed(_logger, connection.ConnectionId); @@ -203,7 +202,7 @@ namespace Microsoft.AspNetCore.Http.Connections return; } - if (connection.Status == HttpConnectionContext.ConnectionStatus.Active) + if (connection.Status == HttpConnectionStatus.Active) { var existing = connection.GetHttpContext(); Log.ConnectionAlreadyActive(_logger, connection.ConnectionId, existing.TraceIdentifier); @@ -221,7 +220,7 @@ namespace Microsoft.AspNetCore.Http.Connections } // Mark the connection as active - connection.Status = HttpConnectionContext.ConnectionStatus.Active; + connection.Status = HttpConnectionStatus.Active; // Raise OnConnected for new connections only since polls happen all the time if (connection.ApplicationTask == null) @@ -295,12 +294,12 @@ namespace Microsoft.AspNetCore.Http.Connections { await connection.Lock.WaitAsync(); - if (connection.Status == HttpConnectionContext.ConnectionStatus.Active) + if (connection.Status == HttpConnectionStatus.Active) { // Mark the connection as inactive connection.LastSeenUtc = DateTime.UtcNow; - connection.Status = HttpConnectionContext.ConnectionStatus.Inactive; + connection.Status = HttpConnectionStatus.Inactive; // Dispose the cancellation token connection.Cancellation.Dispose(); @@ -325,7 +324,7 @@ namespace Microsoft.AspNetCore.Http.Connections { await connection.Lock.WaitAsync(); - if (connection.Status == HttpConnectionContext.ConnectionStatus.Disposed) + if (connection.Status == HttpConnectionStatus.Disposed) { Log.ConnectionDisposed(_logger, connection.ConnectionId); @@ -335,7 +334,7 @@ namespace Microsoft.AspNetCore.Http.Connections } // There's already an active request - if (connection.Status == HttpConnectionContext.ConnectionStatus.Active) + if (connection.Status == HttpConnectionStatus.Active) { Log.ConnectionAlreadyActive(_logger, connection.ConnectionId, connection.GetHttpContext().TraceIdentifier); @@ -345,7 +344,7 @@ namespace Microsoft.AspNetCore.Http.Connections } // Mark the connection as active - connection.Status = HttpConnectionContext.ConnectionStatus.Active; + connection.Status = HttpConnectionStatus.Active; // Call into the end point passing the connection connection.ApplicationTask = ExecuteApplication(connectionDelegate, connection); @@ -380,7 +379,7 @@ namespace Microsoft.AspNetCore.Http.Connections await connectionDelegate(connection); } - private async Task ProcessNegotiate(HttpContext context, HttpConnectionOptions options, ConnectionLogScope logScope) + private async Task ProcessNegotiate(HttpContext context, HttpConnectionDispatcherOptions options, ConnectionLogScope logScope) { context.Response.ContentType = "application/json"; @@ -411,7 +410,7 @@ namespace Microsoft.AspNetCore.Http.Connections } } - private static void WriteNegotiatePayload(IBufferWriter writer, string connectionId, HttpContext context, HttpConnectionOptions options) + private static void WriteNegotiatePayload(IBufferWriter writer, string connectionId, HttpContext context, HttpConnectionDispatcherOptions options) { var response = new NegotiationResponse(); response.ConnectionId = connectionId; @@ -442,7 +441,7 @@ namespace Microsoft.AspNetCore.Http.Connections private static string GetConnectionId(HttpContext context) => context.Request.Query["id"]; - private async Task ProcessSend(HttpContext context, HttpConnectionOptions options) + private async Task ProcessSend(HttpContext context, HttpConnectionDispatcherOptions options) { var connection = await GetConnectionAsync(context); if (connection == null) @@ -469,7 +468,7 @@ namespace Microsoft.AspNetCore.Http.Connections try { - if (connection.Status == HttpConnectionContext.ConnectionStatus.Disposed) + if (connection.Status == HttpConnectionStatus.Disposed) { Log.ConnectionDisposed(_logger, connection.ConnectionId); @@ -523,7 +522,7 @@ namespace Microsoft.AspNetCore.Http.Connections context.Response.ContentType = "text/plain"; } - private async Task EnsureConnectionStateAsync(HttpConnectionContext connection, HttpContext context, HttpTransportType transportType, HttpTransportType supportedTransports, ConnectionLogScope logScope, HttpConnectionOptions options) + private async Task EnsureConnectionStateAsync(HttpConnectionContext connection, HttpContext context, HttpTransportType transportType, HttpTransportType supportedTransports, ConnectionLogScope logScope, HttpConnectionDispatcherOptions options) { if ((supportedTransports & transportType) == 0) { @@ -672,7 +671,7 @@ namespace Microsoft.AspNetCore.Http.Connections } // This is only used for WebSockets connections, which can connect directly without negotiating - private async Task GetOrCreateConnectionAsync(HttpContext context, HttpConnectionOptions options) + private async Task GetOrCreateConnectionAsync(HttpContext context, HttpConnectionDispatcherOptions options) { var connectionId = GetConnectionId(context); HttpConnectionContext connection; @@ -693,7 +692,7 @@ namespace Microsoft.AspNetCore.Http.Connections return connection; } - private HttpConnectionContext CreateConnection(HttpConnectionOptions options) + private HttpConnectionContext CreateConnection(HttpConnectionDispatcherOptions options) { var transportPipeOptions = new PipeOptions(pauseWriterThreshold: options.TransportMaxBufferSize, resumeWriterThreshold: options.TransportMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); var appPipeOptions = new PipeOptions(pauseWriterThreshold: options.ApplicationMaxBufferSize, resumeWriterThreshold: options.ApplicationMaxBufferSize / 2, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.Log.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionManager.Log.cs similarity index 98% rename from src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.Log.cs rename to src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionManager.Log.cs index 0e65679e19..7f542b920e 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.Log.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionManager.Log.cs @@ -4,7 +4,7 @@ using System; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Http.Connections +namespace Microsoft.AspNetCore.Http.Connections.Internal { public partial class HttpConnectionManager { diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionManager.cs similarity index 97% rename from src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs rename to src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionManager.cs index 2f228cda50..79b849120d 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionManager.cs @@ -13,11 +13,10 @@ using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Http.Connections +namespace Microsoft.AspNetCore.Http.Connections.Internal { public partial class HttpConnectionManager { @@ -139,7 +138,7 @@ namespace Microsoft.AspNetCore.Http.Connections // Scan the registered connections looking for ones that have timed out foreach (var c in _connections) { - HttpConnectionContext.ConnectionStatus status; + HttpConnectionStatus status; DateTimeOffset lastSeenUtc; var connection = c.Value.Connection; @@ -159,7 +158,7 @@ namespace Microsoft.AspNetCore.Http.Connections // Once the decision has been made to dispose we don't check the status again // But don't clean up connections while the debugger is attached. - if (!Debugger.IsAttached && status == HttpConnectionContext.ConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5) + if (!Debugger.IsAttached && status == HttpConnectionStatus.Inactive && (DateTimeOffset.UtcNow - lastSeenUtc).TotalSeconds > 5) { Log.ConnectionTimedOut(_logger, connection.ConnectionId); HttpConnectionsEventSource.Log.ConnectionTimedOut(connection.ConnectionId); diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionStatus.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionStatus.cs new file mode 100644 index 0000000000..a167bba04b --- /dev/null +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionStatus.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Http.Connections.Internal +{ + public enum HttpConnectionStatus + { + Inactive, + Active, + Disposed + } +} diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.Log.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.Log.cs index b7cb21ab19..13df2672c6 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.Log.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.Log.cs @@ -11,8 +11,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports { private static class Log { - private static readonly Action _socketOpened = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, "SocketOpened"), "Socket opened."); + private static readonly Action _socketOpened = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "SocketOpened"), "Socket opened using Sub-Protocol: '{SubProtocol}'."); private static readonly Action _socketClosed = LoggerMessage.Define(LogLevel.Debug, new EventId(2, "SocketClosed"), "Socket closed."); @@ -50,9 +50,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports private static readonly Action _sendFailed = LoggerMessage.Define(LogLevel.Error, new EventId(13, "SendFailed"), "Socket failed to send."); - public static void SocketOpened(ILogger logger) + public static void SocketOpened(ILogger logger, string subProtocol) { - _socketOpened(logger, null); + _socketOpened(logger, subProtocol, null); } public static void SocketClosed(ILogger logger) diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.cs index e8f51781af..82e091b5b5 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.cs @@ -49,9 +49,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports { Debug.Assert(context.WebSockets.IsWebSocketRequest, "Not a websocket request"); - using (var ws = await context.WebSockets.AcceptWebSocketAsync(_options.SubProtocol)) + var subProtocol = _options.SubProtocolSelector?.Invoke(context.WebSockets.WebSocketRequestedProtocols); + + using (var ws = await context.WebSockets.AcceptWebSocketAsync(subProtocol)) { - Log.SocketOpened(_logger); + Log.SocketOpened(_logger, subProtocol); try { diff --git a/src/Microsoft.AspNetCore.Http.Connections/WebSocketOptions.cs b/src/Microsoft.AspNetCore.Http.Connections/WebSocketOptions.cs index 82f94f2213..3fe27b35c4 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/WebSocketOptions.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/WebSocketOptions.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; namespace Microsoft.AspNetCore.Http.Connections { @@ -9,6 +10,17 @@ namespace Microsoft.AspNetCore.Http.Connections { public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5); - public string SubProtocol { get; set; } + /// + /// Gets or sets a delegate that will be called when a new WebSocket is established to select the value + /// for the 'Sec-WebSocket-Protocol' response header. The delegate will be called with a list of the protocols provided + /// by the client in the 'Sec-WebSocket-Protocol' request header. + /// + /// + /// See RFC 6455 section 1.3 for more details on the WebSocket handshake: https://tools.ietf.org/html/rfc6455#section-1.3 + /// + // WebSocketManager's list of sub protocols is an IList: + // https://github.com/aspnet/HttpAbstractions/blob/a6bdb9b1ec6ed99978a508e71a7f131be7e4d9fb/src/Microsoft.AspNetCore.Http.Abstractions/WebSocketManager.cs#L23 + // Unfortunately, IList does not implement IReadOnlyList :( + public Func, string> SubProtocolSelector { get; set; } } } diff --git a/src/Microsoft.AspNetCore.SignalR/HubRouteBuilder.cs b/src/Microsoft.AspNetCore.SignalR/HubRouteBuilder.cs index 8e5d641f06..4d9a798659 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubRouteBuilder.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubRouteBuilder.cs @@ -28,11 +28,11 @@ namespace Microsoft.AspNetCore.SignalR MapHub(path, configureOptions: null); } - public void MapHub(PathString path, Action configureOptions) where THub : Hub + public void MapHub(PathString path, Action configureOptions) where THub : Hub { // find auth attributes var authorizeAttributes = typeof(THub).GetCustomAttributes(inherit: true); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); foreach (var attribute in authorizeAttributes) { options.AuthorizationData.Add(attribute); diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs index b9d743df97..4ecfb1d433 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs @@ -17,6 +17,7 @@ using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Internal; using Microsoft.Extensions.DependencyInjection; @@ -51,7 +52,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; - await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionOptions()); + await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions()); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); var connectionId = negotiateResponse.Value("connectionId"); Assert.True(manager.TryGetConnection(connectionId, out var connectionContext)); @@ -74,7 +75,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; - var options = new HttpConnectionOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 }; + var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 }; await dispatcher.ExecuteNegotiateAsync(context, options); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); var connectionId = negotiateResponse.Value("connectionId"); @@ -134,7 +135,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var app = builder.Build(); // This task should complete immediately but it exceeds the writer threshold - var executeTask = dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + var executeTask = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.False(executeTask.IsCompleted); await connection.Transport.Input.ConsumeAsync(10); await executeTask.OrTimeout(); @@ -166,7 +167,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; - await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionOptions { Transports = transports }); + await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { Transports = transports }); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); var availableTransports = HttpTransportType.None; @@ -211,7 +212,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); await strm.FlushAsync(); @@ -246,7 +247,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); await strm.FlushAsync(); @@ -283,7 +284,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status405MethodNotAllowed, context.Response.StatusCode); await strm.FlushAsync(); @@ -321,7 +322,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); } @@ -371,7 +372,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests }); var app = builder.Build(); - var task = dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); // Pretend the transport closed because the client disconnected if (context.WebSockets.IsWebSocketRequest) @@ -432,7 +433,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests }); var app = builder.Build(); - var task = dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); // Pretend the transport closed because the client disconnected cts.Cancel(); @@ -491,7 +492,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(0, connection.ApplicationStream.Length); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.True(connection.Transport.Input.TryRead(out var result)); Assert.Equal("Hello World", Encoding.UTF8.GetString(result.Buffer.ToArray())); @@ -551,7 +552,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); } } } @@ -633,7 +634,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var app = builder.Build(); // Start a poll - var task = dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); // Send to the application var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -704,7 +705,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); await strm.FlushAsync(); @@ -733,7 +734,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); await strm.FlushAsync(); @@ -807,7 +808,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); @@ -834,7 +835,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); @@ -861,7 +862,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); @@ -888,7 +889,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.LongPolling.PollTimeout = TimeSpan.FromSeconds(2); await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); @@ -915,7 +916,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1); var task = dispatcher.ExecuteAsync(context, options, app); @@ -948,7 +949,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); var request1 = dispatcher.ExecuteAsync(context1, options, app); await dispatcher.ExecuteAsync(context2, options, app); @@ -988,14 +989,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); var request1 = dispatcher.ExecuteAsync(context1, options, app); var request2 = dispatcher.ExecuteAsync(context2, options, app); await request1; Assert.Equal(StatusCodes.Status204NoContent, context1.Response.StatusCode); - Assert.Equal(HttpConnectionContext.ConnectionStatus.Active, connection.Status); + Assert.Equal(HttpConnectionStatus.Active, connection.Status); Assert.False(request2.IsCompleted); @@ -1015,7 +1016,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); connection.TransportType = transportType; - connection.Status = HttpConnectionContext.ConnectionStatus.Disposed; + connection.Status = HttpConnectionStatus.Disposed; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -1027,7 +1028,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); await dispatcher.ExecuteAsync(context, options, app); @@ -1053,7 +1054,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -1063,7 +1064,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests await task; - Assert.Equal(HttpConnectionContext.ConnectionStatus.Inactive, connection.Status); + Assert.Equal(HttpConnectionStatus.Inactive, connection.Status); Assert.NotNull(connection.GetHttpContext()); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); @@ -1089,7 +1090,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -1123,7 +1124,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -1155,7 +1156,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); var context1 = MakeRequest("/foo", connection); var task1 = dispatcher.ExecuteAsync(context1, options, app); @@ -1201,7 +1202,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(0); await dispatcher.ExecuteAsync(context, options, app); @@ -1249,7 +1250,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(sp); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.AuthorizationData.Add(new AuthorizeAttribute("test")); // would get stuck if EndPoint was running @@ -1295,7 +1296,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(sp); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.AuthorizationData.Add(new AuthorizeAttribute("test")); context.User = new ClaimsPrincipal(new ClaimsIdentity("authenticated")); @@ -1348,7 +1349,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(sp); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.AuthorizationData.Add(new AuthorizeAttribute("test")); // "authorize" user @@ -1409,7 +1410,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(sp); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.AuthorizationData.Add(new AuthorizeAttribute("test")); options.AuthorizationData.Add(new AuthorizeAttribute("secondPolicy")); @@ -1488,7 +1489,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(sp); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.AuthorizationData.Add(new AuthorizeAttribute("test")); // "authorize" user @@ -1545,7 +1546,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(sp); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.AuthorizationData.Add(new AuthorizeAttribute("test")); // "authorize" user @@ -1576,7 +1577,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.LongPolling.PollTimeout = TimeSpan.FromMilliseconds(1); // We don't care about the poll itself Assert.Null(connection.Features.Get()); @@ -1610,7 +1611,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); _ = dispatcher.ExecuteAsync(context, options, app).OrTimeout(); @@ -1649,7 +1650,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); var pollTask = dispatcher.ExecuteAsync(context, options, app); @@ -1696,7 +1697,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.LongPolling.PollTimeout = TimeSpan.FromMilliseconds(1); await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); @@ -1742,7 +1743,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Response.Body = ms; - await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionOptions { Transports = HttpTransportType.WebSockets }); + await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { Transports = HttpTransportType.WebSockets }); var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); var availableTransports = (JArray)negotiateResponse["availableTransports"]; @@ -1821,7 +1822,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); options.Transports = supportedTransports; await dispatcher.ExecuteAsync(context, options, app); diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionManagerTests.cs index daa3a6cf78..820d28749d 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionManagerTests.cs @@ -5,6 +5,7 @@ using System; using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.Extensions.Logging; using Xunit; @@ -19,7 +20,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connection = connectionManager.CreateConnection(); Assert.NotNull(connection.ConnectionId); - Assert.Equal(HttpConnectionContext.ConnectionStatus.Inactive, connection.Status); + Assert.Equal(HttpConnectionStatus.Inactive, connection.Status); Assert.Null(connection.ApplicationTask); Assert.Null(connection.TransportTask); Assert.Null(connection.Cancellation); @@ -265,7 +266,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.NotNull(connection.Transport); await connection.DisposeAsync(); - Assert.Equal(HttpConnectionContext.ConnectionStatus.Disposed, connection.Status); + Assert.Equal(HttpConnectionStatus.Disposed, connection.Status); } [Fact] @@ -279,7 +280,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.NotNull(connection.Application); await connection.DisposeAsync(); - Assert.Equal(HttpConnectionContext.ConnectionStatus.Disposed, connection.Status); + Assert.Equal(HttpConnectionStatus.Disposed, connection.Status); } [Fact] diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/LongPollingTests.cs index 86c74a4752..ed6cdc4e60 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/LongPollingTests.cs @@ -100,7 +100,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests [Fact] public void CheckLongPollingTimeoutValue() { - var options = new HttpConnectionOptions(); + var options = new HttpConnectionDispatcherOptions(); Assert.Equal(options.LongPolling.PollTimeout, TimeSpan.FromSeconds(90)); } } diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/MapConnectionHandlerTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/MapConnectionHandlerTests.cs index 0d24cf0e61..d62c4f953d 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/MapConnectionHandlerTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/MapConnectionHandlerTests.cs @@ -73,7 +73,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public async Task MapConnectionHandlerWithWebSocketSubProtocolSetsProtocol() { var host = BuildWebHost("/socket", - options => options.WebSockets.SubProtocol = "protocol1"); + options => options.WebSockets.SubProtocolSelector = subprotocols => + { + Assert.Equal(new [] { "protocol1", "protocol2" }, subprotocols.ToArray()); + return "protocol1"; + }); await host.StartAsync(); @@ -131,7 +135,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } - private IWebHost BuildWebHost(string path, Action configureOptions) where TConnectionHandler : ConnectionHandler + private IWebHost BuildWebHost(string path, Action configureOptions) where TConnectionHandler : ConnectionHandler { return new WebHostBuilder() .UseUrls("http://127.0.0.1:0") diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/TestWebSocketConnectionFeature.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/TestWebSocketConnectionFeature.cs index 9f6e1a5092..c0711f4ace 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/TestWebSocketConnectionFeature.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/TestWebSocketConnectionFeature.cs @@ -10,10 +10,16 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { internal class TestWebSocketConnectionFeature : IHttpWebSocketFeature, IDisposable { + private readonly TaskCompletionSource _accepted = new TaskCompletionSource(); + public bool IsWebSocketRequest => true; public WebSocketChannel Client { get; private set; } + public string SubProtocol { get; private set; } + + public Task Accepted => _accepted.Task; + public Task AcceptAsync() => AcceptAsync(new WebSocketAcceptContext()); public Task AcceptAsync(WebSocketAcceptContext context) @@ -25,6 +31,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var serverSocket = new WebSocketChannel(clientToServer.Reader, serverToClient.Writer); Client = clientSocket; + SubProtocol = context.SubProtocol; + + _accepted.TrySetResult(null); return Task.FromResult(serverSocket); } diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs index c5baa83e33..4bc577be48 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs @@ -4,14 +4,18 @@ using System; using System.Buffers; using System.IO.Pipelines; +using System.Linq; using System.Net.WebSockets; using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.Http.Connections.Internal.Transports; +using Microsoft.AspNetCore.Http.Features; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; +using Microsoft.Net.Http.Headers; using Xunit; using Xunit.Abstractions; @@ -346,5 +350,57 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } } + + [Fact] + public async Task SubProtocolSelectorIsUsedToSelectSubProtocol() + { + const string ExpectedSubProtocol = "expected"; + var providedSubProtocols = new[] {"provided1", "provided2"}; + + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var connection = new HttpConnectionContext("foo", pair.Transport, pair.Application); + + using (var feature = new TestWebSocketConnectionFeature()) + { + var options = new WebSocketOptions + { + // We want to verify behavior without timeout affecting it + CloseTimeout = TimeSpan.FromSeconds(20), + SubProtocolSelector = protocols => { + Assert.Equal(providedSubProtocols, protocols.ToArray()); + return ExpectedSubProtocol; + }, + }; + + var connectionContext = new HttpConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory); + + // Create an HttpContext + var context = new DefaultHttpContext(); + context.Request.Headers.Add(HeaderNames.WebSocketSubProtocols, providedSubProtocols.ToArray()); + context.Features.Set(feature); + var transport = ws.ProcessRequestAsync(context, CancellationToken.None); + + await feature.Accepted.OrThrowIfOtherFails(transport); + + // Assert the feature got the right subprotocol + Assert.Equal(ExpectedSubProtocol, feature.SubProtocol); + + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); + + await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout(); + + // close the client to server channel + connection.Transport.Output.Complete(); + + _ = await client.OrTimeout(); + + await transport.OrTimeout(); + } + } + } } }