diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index 8413ab5d12..91065bd568 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -3,6 +3,7 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using SocketsSample.EndPoints; @@ -30,6 +31,7 @@ namespace SocketsSample // .AddRedis(); services.AddEndPoint(); + services.AddSingleton(); } diff --git a/src/Microsoft.AspNetCore.Sockets.Common/TransportType.cs b/src/Microsoft.AspNetCore.Sockets.Common/TransportType.cs new file mode 100644 index 0000000000..7b95904de0 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/TransportType.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Sockets +{ + [Flags] + public enum TransportType + { + WebSockets = 1, + ServerSentEvents = 2, + LongPolling = 4, + All = WebSockets | ServerSentEvents | LongPolling + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs b/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs index 9183f744cb..6b1a4091a1 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs @@ -8,5 +8,7 @@ namespace Microsoft.AspNetCore.Sockets public class EndPointOptions where TEndPoint : EndPoint { public AuthorizationPolicy Policy { get; set; } + + public TransportType Transports { get; set; } = TransportType.All; } } diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 8e4806f1a0..1fa6bf0dbf 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -15,6 +15,7 @@ using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.AspNetCore.Sockets.Transports; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Sockets @@ -36,10 +37,11 @@ namespace Microsoft.AspNetCore.Sockets { // Get the end point mapped to this http connection var endpoint = (EndPoint)context.RequestServices.GetRequiredService(); + var options = context.RequestServices.GetRequiredService>>().Value; if (context.Request.Path.StartsWithSegments(path + "/negotiate")) { - await ProcessNegotiate(context); + await ProcessNegotiate(context, options); } else if (context.Request.Path.StartsWithSegments(path + "/send")) { @@ -47,12 +49,14 @@ namespace Microsoft.AspNetCore.Sockets } else { - await ExecuteEndpointAsync(path, context, endpoint); + await ExecuteEndpointAsync(path, context, endpoint, options); } } - private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoint endpoint) + private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoint endpoint, EndPointOptions options) where TEndPoint : EndPoint { + var supportedTransports = options.Transports; + // Server sent events transport if (context.Request.Path.StartsWithSegments(path + "/sse")) { @@ -64,7 +68,7 @@ namespace Microsoft.AspNetCore.Sockets return; } - if (!await EnsureConnectionStateAsync(state, context, ServerSentEventsTransport.Name)) + if (!await EnsureConnectionStateAsync(state, context, TransportType.ServerSentEvents, supportedTransports)) { // Bad connection state. It's already set the response status code. return; @@ -85,7 +89,7 @@ namespace Microsoft.AspNetCore.Sockets return; } - if (!await EnsureConnectionStateAsync(state, context, WebSocketsTransport.Name)) + if (!await EnsureConnectionStateAsync(state, context, TransportType.WebSockets, supportedTransports)) { // Bad connection state. It's already set the response status code. return; @@ -105,7 +109,7 @@ namespace Microsoft.AspNetCore.Sockets return; } - if (!await EnsureConnectionStateAsync(state, context, LongPollingTransport.Name)) + if (!await EnsureConnectionStateAsync(state, context, TransportType.LongPolling, supportedTransports)) { // Bad connection state. It's already set the response status code. return; @@ -158,7 +162,7 @@ namespace Microsoft.AspNetCore.Sockets { _logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId); - state.Connection.Metadata["transport"] = LongPollingTransport.Name; + state.Connection.Metadata["transport"] = TransportType.LongPolling; state.ApplicationTask = ExecuteApplication(endpoint, state.Connection); } @@ -297,7 +301,7 @@ namespace Microsoft.AspNetCore.Sockets await endpoint.OnConnectedAsync(connection); } - private Task ProcessNegotiate(HttpContext context) + private Task ProcessNegotiate(HttpContext context, EndPointOptions options) where TEndPoint : EndPoint { // Establish the connection var state = CreateConnection(context); @@ -369,16 +373,24 @@ namespace Microsoft.AspNetCore.Sockets } } - private async Task EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, string transportName) + private async Task EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, TransportType transportType, TransportType supportedTransports) { + if ((supportedTransports & transportType) == 0) + { + context.Response.StatusCode = StatusCodes.Status404NotFound; + await context.Response.WriteAsync($"{transportType} transport not supported by this end point type"); + return false; + } + connectionState.Connection.User = context.User; - var transport = connectionState.Connection.Metadata.Get("transport"); - if (string.IsNullOrEmpty(transport)) + var transport = connectionState.Connection.Metadata.Get("transport"); + + if (transport == null) { - connectionState.Connection.Metadata["transport"] = transportName; + connectionState.Connection.Metadata["transport"] = transportType; } - else if (!string.Equals(transport, transportName, StringComparison.Ordinal)) + else if (transport != transportType) { context.Response.StatusCode = StatusCodes.Status400BadRequest; await context.Response.WriteAsync("Cannot change transports mid-connection"); diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs index be13f16415..c3390319ae 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs @@ -17,7 +17,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports { public class LongPollingTransport : IHttpTransport { - public static readonly string Name = "longPolling"; private readonly ReadableChannel _application; private readonly ILogger _logger; diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs index 8c7abff349..2b4e6cbf32 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs @@ -16,7 +16,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports { public class ServerSentEventsTransport : IHttpTransport { - public static readonly string Name = "serverSentEvents"; private readonly ReadableChannel _application; private readonly ILogger _logger; diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs index 4ec6667d99..ad409fada1 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs @@ -16,8 +16,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports { public class WebSocketsTransport : IHttpTransport { - public static readonly string Name = "webSockets"; - private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5); private static readonly WebSocketAcceptContext _emptyContext = new WebSocketAcceptContext(); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 7d69685d4c..7bb7d4ebbe 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -32,6 +32,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddEndPoint(); + services.AddOptions(); context.RequestServices = services.BuildServiceProvider(); var ms = new MemoryStream(); context.Request.Path = "/negotiate"; @@ -62,6 +63,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddEndPoint(); + services.AddOptions(); context.RequestServices = services.BuildServiceProvider(); context.Request.Path = path; var values = new Dictionary(); @@ -90,6 +92,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); context.Response.Body = strm; var services = new ServiceCollection(); + services.AddOptions(); services.AddEndPoint(); context.RequestServices = services.BuildServiceProvider(); context.Request.Path = path; @@ -102,6 +105,40 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } + [Theory] + [InlineData(TransportType.LongPolling, 204)] + [InlineData(TransportType.WebSockets, 404)] + [InlineData(TransportType.ServerSentEvents, 404)] + public async Task EndPointThatOnlySupportsLongPollingRejectsOtherTransports(TransportType transportType, int status) + { + await CheckTransportSupported(TransportType.LongPolling, transportType, status); + } + + [Theory] + [InlineData(TransportType.ServerSentEvents, 200)] + [InlineData(TransportType.WebSockets, 404)] + [InlineData(TransportType.LongPolling, 404)] + public async Task EndPointThatOnlySupportsSSERejectsOtherTransports(TransportType transportType, int status) + { + await CheckTransportSupported(TransportType.ServerSentEvents, transportType, status); + } + + [Theory] + [InlineData(TransportType.WebSockets, 200)] + [InlineData(TransportType.ServerSentEvents, 404)] + [InlineData(TransportType.LongPolling, 404)] + public async Task EndPointThatOnlySupportsWebSockesRejectsOtherTransports(TransportType transportType, int status) + { + await CheckTransportSupported(TransportType.WebSockets, transportType, status); + } + + [Theory] + [InlineData(TransportType.LongPolling, 404)] + public async Task EndPointThatOnlySupportsWebSocketsAndSSERejectsLongPolling(TransportType transportType, int status) + { + await CheckTransportSupported(TransportType.WebSockets | TransportType.ServerSentEvents, transportType, status); + } + [Fact] public async Task CompletedEndPointEndsConnection() { @@ -387,6 +424,56 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(MessageType.Close, messages[3].Type); } + private static async Task CheckTransportSupported(TransportType supportedTransports, TransportType transportType, int status) + { + var path = ""; + switch (transportType) + { + case TransportType.WebSockets: + path = "/ws"; + break; + case TransportType.ServerSentEvents: + path = "/sse"; + break; + case TransportType.LongPolling: + path = "/poll"; + break; + default: + break; + } + + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + using (var strm = new MemoryStream()) + { + var context = new DefaultHttpContext(); + context.Response.Body = strm; + var services = new ServiceCollection(); + services.AddOptions(); + services.AddEndPoint(options => + { + options.Transports = supportedTransports; + }); + + context.RequestServices = services.BuildServiceProvider(); + context.Request.Path = path; + var values = new Dictionary(); + values["id"] = state.Connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + await dispatcher.ExecuteAsync("", context); + Assert.Equal(status, context.Response.StatusCode); + await strm.FlushAsync(); + + // Check the message for 404 + if (status == 404) + { + Assert.Equal($"{transportType} transport not supported by this end point type", Encoding.UTF8.GetString(strm.ToArray())); + } + } + } + private static async Task> RunSendTest(string contentType, string encoded, string format) { var manager = CreateConnectionManager(); @@ -420,6 +507,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddEndPoint(); + services.AddOptions(); context.RequestServices = services.BuildServiceProvider(); context.Request.Path = path; var values = new Dictionary();