diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index eca117dc30..3b7b33645c 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal.Protocol; namespace Microsoft.AspNetCore.SignalR diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs index 3bdce8877e..d29718fa31 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs @@ -8,5 +8,7 @@ namespace Microsoft.AspNetCore.Sockets.Features public interface IConnectionTransportFeature { Channel Transport { get; set; } + + TransferMode TransportCapabilities { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs index 7637468399..78fd8bd382 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs @@ -1,10 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; using System.Security.Claims; -using System.Text; namespace Microsoft.AspNetCore.Sockets.Features { diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index 3930884104..fbd6cf468a 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -89,6 +89,9 @@ namespace Microsoft.AspNetCore.Sockets _logger.EstablishedConnection(connection.ConnectionId, context.TraceIdentifier); + // ServerSentEvents is a text protocol only + connection.TransportCapabilities = TransferMode.Text; + // We only need to provide the Input channel since writing to the application is handled through /send. var sse = new ServerSentEventsTransport(connection.Application.In, connection.ConnectionId, _loggerFactory); @@ -112,7 +115,7 @@ namespace Microsoft.AspNetCore.Sockets _logger.EstablishedConnection(connection.ConnectionId, context.TraceIdentifier); - var ws = new WebSocketsTransport(options.WebSockets, connection.Application, connection.ConnectionId, _loggerFactory); + var ws = new WebSocketsTransport(options.WebSockets, connection.Application, connection, _loggerFactory); await DoPersistentConnection(socketDelegate, ws, context, connection); } @@ -330,7 +333,7 @@ namespace Microsoft.AspNetCore.Sockets // Establish the connection var connection = _manager.CreateConnection(); - + // Set the Connection ID on the logging scope so that logs from now on will have the // Connection ID metadata set. logScope.ConnectionId = connection.ConnectionId; @@ -433,6 +436,9 @@ namespace Microsoft.AspNetCore.Sockets connection.User = context.User; connection.SetHttpContext(context); + // this is the default setting which should be overwritten by transports that have different capabilities (e.g. SSE) + connection.TransportCapabilities = TransferMode.Binary | TransferMode.Text; + // Set the Connection ID on the logging scope so that logs from now on will have the // Connection ID metadata set. logScope.ConnectionId = connection.ConnectionId; diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs index 5811dc6fd4..4d93111c3e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs @@ -18,9 +18,9 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports private readonly WebSocketOptions _options; private readonly ILogger _logger; private readonly Channel _application; - private readonly string _connectionId; + private readonly DefaultConnectionContext _connection; - public WebSocketsTransport(WebSocketOptions options, Channel application, string connectionId, ILoggerFactory loggerFactory) + public WebSocketsTransport(WebSocketOptions options, Channel application, DefaultConnectionContext connection, ILoggerFactory loggerFactory) { if (options == null) { @@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports _options = options; _application = application; - _connectionId = connectionId; + _connection = connection; _logger = loggerFactory.CreateLogger(); } @@ -49,11 +49,11 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports using (var ws = await context.WebSockets.AcceptWebSocketAsync()) { - _logger.SocketOpened(_connectionId); + _logger.SocketOpened(_connection.ConnectionId); await ProcessSocketAsync(ws); } - _logger.SocketClosed(_connectionId); + _logger.SocketClosed(_connection.ConnectionId); } public async Task ProcessSocketAsync(WebSocket socket) @@ -72,12 +72,12 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports if (trigger == receiving) { task = sending; - _logger.WaitingForSend(_connectionId); + _logger.WaitingForSend(_connection.ConnectionId); } else { task = receiving; - _logger.WaitingForClose(_connectionId); + _logger.WaitingForClose(_connection.ConnectionId); } // We're done writing @@ -89,7 +89,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports if (resultTask != task) { - _logger.CloseTimedOut(_connectionId); + _logger.CloseTimedOut(_connection.ConnectionId); socket.Abort(); } else @@ -123,7 +123,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports return receiveResult; } - _logger.MessageReceived(_connectionId, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); + _logger.MessageReceived(_connection.ConnectionId, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); var truncBuffer = new ArraySegment(buffer.Array, 0, receiveResult.Count); incomingMessage.Add(truncBuffer); @@ -153,7 +153,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, messageBuffer, 0, incomingMessage[0].Count); } - _logger.MessageToApplication(_connectionId, messageBuffer.Length); + _logger.MessageToApplication(_connection.ConnectionId, messageBuffer.Length); while (await _application.Out.WaitToWriteAsync()) { if (_application.Out.TryWrite(messageBuffer)) @@ -176,22 +176,26 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports { try { - _logger.SendPayload(_connectionId, buffer.Length); + _logger.SendPayload(_connection.ConnectionId, buffer.Length); + + var webSocketMessageType = (_connection.TransferMode == TransferMode.Binary + ? WebSocketMessageType.Binary + : WebSocketMessageType.Text); if (WebSocketCanSend(ws)) { - await ws.SendAsync(new ArraySegment(buffer), _options.WebSocketMessageType, endOfMessage: true, cancellationToken: CancellationToken.None); + await ws.SendAsync(new ArraySegment(buffer), webSocketMessageType, endOfMessage: true, cancellationToken: CancellationToken.None); } } catch (WebSocketException socketException) when (!WebSocketCanSend(ws)) { // this can happen when we send the CloseFrame to the client and try to write afterwards - _logger.SendFailed(_connectionId, socketException); + _logger.SendFailed(_connection.ConnectionId, socketException); break; } catch (Exception ex) { - _logger.ErrorWritingFrame(_connectionId, ex); + _logger.ErrorWritingFrame(_connection.ConnectionId, ex); break; } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/WebSocketOptions.cs b/src/Microsoft.AspNetCore.Sockets.Http/WebSocketOptions.cs index 7211167cbb..0157392996 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/WebSocketOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/WebSocketOptions.cs @@ -2,14 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Net.WebSockets; namespace Microsoft.AspNetCore.Sockets { public class WebSocketOptions { public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5); - - public WebSocketMessageType WebSocketMessageType { get; set; } = WebSocketMessageType.Text; } } diff --git a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs index 4072a9e9ec..3acb6bb0c1 100644 --- a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs @@ -16,7 +16,8 @@ namespace Microsoft.AspNetCore.Sockets IConnectionIdFeature, IConnectionMetadataFeature, IConnectionTransportFeature, - IConnectionUserFeature + IConnectionUserFeature, + ITransferModeFeature { // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task @@ -35,6 +36,7 @@ namespace Microsoft.AspNetCore.Sockets Features.Set(this); Features.Set(this); Features.Set(this); + Features.Set(this); } public CancellationTokenSource Cancellation { get; set; } @@ -61,6 +63,10 @@ namespace Microsoft.AspNetCore.Sockets public override Channel Transport { get; set; } + public TransferMode TransportCapabilities { get; set; } + + public TransferMode TransferMode { get; set; } + public async Task DisposeAsync() { Task disposeTask = Task.CompletedTask; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj index 71be053e5d..b8570b06bc 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/Microsoft.AspNetCore.SignalR.Client.Tests.csproj @@ -14,7 +14,7 @@ - + diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 0828eee09f..5b7c44bfea 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -118,7 +118,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } - [Fact] public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvidedForPost() { @@ -584,6 +583,33 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal("Hello, World", GetContentAsString(context2.Response.Body)); } + [Theory] + [InlineData(TransportType.LongPolling, TransferMode.Binary | TransferMode.Text)] + [InlineData(TransportType.ServerSentEvents, TransferMode.Text)] + [InlineData(TransportType.WebSockets, TransferMode.Binary | TransferMode.Text)] + public async Task TransportCapabilitiesSet(TransportType transportType, TransferMode expectedTransportCapabilities) + { + var manager = CreateConnectionManager(); + var connection = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/foo", connection); + SetTransport(context, transportType); + + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + + var options = new HttpSocketOptions(); + options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(0); + await dispatcher.ExecuteAsync(context, options, app); + + Assert.Equal(expectedTransportCapabilities, connection.TransportCapabilities); + } + [Fact] public async Task UnauthorizedConnectionFailsToStartEndPoint() { @@ -599,7 +625,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests { o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); }); - services.AddAuthenticationCore(o => + services.AddAuthenticationCore(o => { o.DefaultScheme = "Default"; o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); @@ -641,7 +667,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests { o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); }); - services.AddAuthenticationCore(o => + services.AddAuthenticationCore(o => { o.DefaultScheme = "Default"; o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); @@ -690,7 +716,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests }); }); services.AddLogging(); - services.AddAuthenticationCore(o => + services.AddAuthenticationCore(o => { o.DefaultScheme = "Default"; o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); @@ -747,7 +773,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests }); }); services.AddLogging(); - services.AddAuthenticationCore(o => + services.AddAuthenticationCore(o => { o.DefaultScheme = "Default"; o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); @@ -822,7 +848,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests }); services.AddAuthorizationPolicyEvaluator(); services.AddLogging(); - services.AddAuthenticationCore(o => + services.AddAuthenticationCore(o => { o.DefaultScheme = "Default"; o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)); @@ -875,7 +901,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests }); services.AddAuthorizationPolicyEvaluator(); services.AddLogging(); - services.AddAuthenticationCore(o => + services.AddAuthenticationCore(o => { o.DefaultScheme = "Default"; o.AddScheme("Default", a => a.HandlerType = typeof(RejectHandler)); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index 26809f5665..ba696ff46e 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -29,7 +29,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -61,9 +62,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests } [Theory] - [InlineData(WebSocketMessageType.Text)] - [InlineData(WebSocketMessageType.Binary)] - public async Task DataWrittenToOutputPipelineAreSentAsFrames(WebSocketMessageType webSocketMessageType) + [InlineData(TransferMode.Text, WebSocketMessageType.Text)] + [InlineData(TransferMode.Binary, WebSocketMessageType.Binary)] + public async Task WebSocketTransportSetsMessageTypeBasedOnTransferModeFeature(TransferMode transferMode, WebSocketMessageType expectedMessageType) { var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); @@ -72,7 +73,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { - var ws = new WebSocketsTransport(new WebSocketOptions() { WebSocketMessageType = webSocketMessageType }, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null) { TransferMode = transferMode }; + var ws = new WebSocketsTransport(new WebSocketOptions(), + transportSide, connectionContext, loggerFactory: new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -91,7 +94,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(1, clientSummary.Received.Count); Assert.True(clientSummary.Received[0].EndOfMessage); - Assert.Equal(webSocketMessageType, clientSummary.Received[0].MessageType); + Assert.Equal(expectedMessageType, clientSummary.Received[0].MessageType); Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Buffer)); } } @@ -115,7 +118,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests applicationSide.Out.TryComplete(); } - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -148,7 +152,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -184,7 +189,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests CloseTimeout = TimeSpan.FromSeconds(1) }; - var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory()); var serverSocket = await feature.AcceptAsync(); // Give the server socket to the transport and run it @@ -212,12 +218,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { - var options = new WebSocketOptions() + var options = new WebSocketOptions { CloseTimeout = TimeSpan.FromSeconds(1) }; - var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory()); var serverSocket = await feature.AcceptAsync(); // Give the server socket to the transport and run it @@ -245,12 +252,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { - var options = new WebSocketOptions() + var options = new WebSocketOptions { // We want to verify behavior without timeout affecting it CloseTimeout = TimeSpan.FromSeconds(20) }; - var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory()); var serverSocket = await feature.AcceptAsync(); // Give the server socket to the transport and run it @@ -282,12 +291,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { - var options = new WebSocketOptions() + var options = new WebSocketOptions { // We want to verify behavior without timeout affecting it CloseTimeout = TimeSpan.FromSeconds(20) }; - var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var connectionContext = new DefaultConnectionContext(string.Empty, null, null); + var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory()); var serverSocket = await feature.AcceptAsync(); // Give the server socket to the transport and run it