Supported transports (#294)
Supported transport spike - Allow turning transports on or off per end point with a flags enum - Added `TransportType` to Sockets.Common - Added tests
This commit is contained in:
parent
9659c73e05
commit
63ce7f6160
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Hosting;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using SocketsSample.EndPoints;
|
||||
|
|
@ -30,6 +31,7 @@ namespace SocketsSample
|
|||
// .AddRedis();
|
||||
|
||||
services.AddEndPoint<MessagesEndPoint>();
|
||||
|
||||
services.AddSingleton<ProtobufSerializer>();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
[Flags]
|
||||
public enum TransportType
|
||||
{
|
||||
WebSockets = 1,
|
||||
ServerSentEvents = 2,
|
||||
LongPolling = 4,
|
||||
All = WebSockets | ServerSentEvents | LongPolling
|
||||
}
|
||||
}
|
||||
|
|
@ -8,5 +8,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
public class EndPointOptions<TEndPoint> where TEndPoint : EndPoint
|
||||
{
|
||||
public AuthorizationPolicy Policy { get; set; }
|
||||
|
||||
public TransportType Transports { get; set; } = TransportType.All;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
|||
using Microsoft.AspNetCore.Sockets.Transports;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.Extensions.Primitives;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
|
|
@ -36,10 +37,11 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
// Get the end point mapped to this http connection
|
||||
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
|
||||
var options = context.RequestServices.GetRequiredService<IOptions<EndPointOptions<TEndPoint>>>().Value;
|
||||
|
||||
if (context.Request.Path.StartsWithSegments(path + "/negotiate"))
|
||||
{
|
||||
await ProcessNegotiate(context);
|
||||
await ProcessNegotiate(context, options);
|
||||
}
|
||||
else if (context.Request.Path.StartsWithSegments(path + "/send"))
|
||||
{
|
||||
|
|
@ -47,12 +49,14 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
else
|
||||
{
|
||||
await ExecuteEndpointAsync(path, context, endpoint);
|
||||
await ExecuteEndpointAsync(path, context, endpoint, options);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoint endpoint)
|
||||
private async Task ExecuteEndpointAsync<TEndPoint>(string path, HttpContext context, EndPoint endpoint, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
|
||||
{
|
||||
var supportedTransports = options.Transports;
|
||||
|
||||
// Server sent events transport
|
||||
if (context.Request.Path.StartsWithSegments(path + "/sse"))
|
||||
{
|
||||
|
|
@ -64,7 +68,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
return;
|
||||
}
|
||||
|
||||
if (!await EnsureConnectionStateAsync(state, context, ServerSentEventsTransport.Name))
|
||||
if (!await EnsureConnectionStateAsync(state, context, TransportType.ServerSentEvents, supportedTransports))
|
||||
{
|
||||
// Bad connection state. It's already set the response status code.
|
||||
return;
|
||||
|
|
@ -85,7 +89,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
return;
|
||||
}
|
||||
|
||||
if (!await EnsureConnectionStateAsync(state, context, WebSocketsTransport.Name))
|
||||
if (!await EnsureConnectionStateAsync(state, context, TransportType.WebSockets, supportedTransports))
|
||||
{
|
||||
// Bad connection state. It's already set the response status code.
|
||||
return;
|
||||
|
|
@ -105,7 +109,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
return;
|
||||
}
|
||||
|
||||
if (!await EnsureConnectionStateAsync(state, context, LongPollingTransport.Name))
|
||||
if (!await EnsureConnectionStateAsync(state, context, TransportType.LongPolling, supportedTransports))
|
||||
{
|
||||
// Bad connection state. It's already set the response status code.
|
||||
return;
|
||||
|
|
@ -158,7 +162,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
_logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
|
||||
|
||||
state.Connection.Metadata["transport"] = LongPollingTransport.Name;
|
||||
state.Connection.Metadata["transport"] = TransportType.LongPolling;
|
||||
|
||||
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
|
||||
}
|
||||
|
|
@ -297,7 +301,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
await endpoint.OnConnectedAsync(connection);
|
||||
}
|
||||
|
||||
private Task ProcessNegotiate(HttpContext context)
|
||||
private Task ProcessNegotiate<TEndPoint>(HttpContext context, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
|
||||
{
|
||||
// Establish the connection
|
||||
var state = CreateConnection(context);
|
||||
|
|
@ -369,16 +373,24 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
}
|
||||
|
||||
private async Task<bool> EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, string transportName)
|
||||
private async Task<bool> EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, TransportType transportType, TransportType supportedTransports)
|
||||
{
|
||||
if ((supportedTransports & transportType) == 0)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
await context.Response.WriteAsync($"{transportType} transport not supported by this end point type");
|
||||
return false;
|
||||
}
|
||||
|
||||
connectionState.Connection.User = context.User;
|
||||
|
||||
var transport = connectionState.Connection.Metadata.Get<string>("transport");
|
||||
if (string.IsNullOrEmpty(transport))
|
||||
var transport = connectionState.Connection.Metadata.Get<TransportType?>("transport");
|
||||
|
||||
if (transport == null)
|
||||
{
|
||||
connectionState.Connection.Metadata["transport"] = transportName;
|
||||
connectionState.Connection.Metadata["transport"] = transportType;
|
||||
}
|
||||
else if (!string.Equals(transport, transportName, StringComparison.Ordinal))
|
||||
else if (transport != transportType)
|
||||
{
|
||||
context.Response.StatusCode = StatusCodes.Status400BadRequest;
|
||||
await context.Response.WriteAsync("Cannot change transports mid-connection");
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports
|
|||
{
|
||||
public class LongPollingTransport : IHttpTransport
|
||||
{
|
||||
public static readonly string Name = "longPolling";
|
||||
private readonly ReadableChannel<Message> _application;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports
|
|||
{
|
||||
public class ServerSentEventsTransport : IHttpTransport
|
||||
{
|
||||
public static readonly string Name = "serverSentEvents";
|
||||
private readonly ReadableChannel<Message> _application;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports
|
|||
{
|
||||
public class WebSocketsTransport : IHttpTransport
|
||||
{
|
||||
public static readonly string Name = "webSockets";
|
||||
|
||||
private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5);
|
||||
private static readonly WebSocketAcceptContext _emptyContext = new WebSocketAcceptContext();
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddOptions();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
var ms = new MemoryStream();
|
||||
context.Request.Path = "/negotiate";
|
||||
|
|
@ -62,6 +63,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddOptions();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = path;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
|
|
@ -90,6 +92,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var context = new DefaultHttpContext();
|
||||
context.Response.Body = strm;
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = path;
|
||||
|
|
@ -102,6 +105,40 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(TransportType.LongPolling, 204)]
|
||||
[InlineData(TransportType.WebSockets, 404)]
|
||||
[InlineData(TransportType.ServerSentEvents, 404)]
|
||||
public async Task EndPointThatOnlySupportsLongPollingRejectsOtherTransports(TransportType transportType, int status)
|
||||
{
|
||||
await CheckTransportSupported(TransportType.LongPolling, transportType, status);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(TransportType.ServerSentEvents, 200)]
|
||||
[InlineData(TransportType.WebSockets, 404)]
|
||||
[InlineData(TransportType.LongPolling, 404)]
|
||||
public async Task EndPointThatOnlySupportsSSERejectsOtherTransports(TransportType transportType, int status)
|
||||
{
|
||||
await CheckTransportSupported(TransportType.ServerSentEvents, transportType, status);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(TransportType.WebSockets, 200)]
|
||||
[InlineData(TransportType.ServerSentEvents, 404)]
|
||||
[InlineData(TransportType.LongPolling, 404)]
|
||||
public async Task EndPointThatOnlySupportsWebSockesRejectsOtherTransports(TransportType transportType, int status)
|
||||
{
|
||||
await CheckTransportSupported(TransportType.WebSockets, transportType, status);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(TransportType.LongPolling, 404)]
|
||||
public async Task EndPointThatOnlySupportsWebSocketsAndSSERejectsLongPolling(TransportType transportType, int status)
|
||||
{
|
||||
await CheckTransportSupported(TransportType.WebSockets | TransportType.ServerSentEvents, transportType, status);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompletedEndPointEndsConnection()
|
||||
{
|
||||
|
|
@ -387,6 +424,56 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
Assert.Equal(MessageType.Close, messages[3].Type);
|
||||
}
|
||||
|
||||
private static async Task CheckTransportSupported(TransportType supportedTransports, TransportType transportType, int status)
|
||||
{
|
||||
var path = "";
|
||||
switch (transportType)
|
||||
{
|
||||
case TransportType.WebSockets:
|
||||
path = "/ws";
|
||||
break;
|
||||
case TransportType.ServerSentEvents:
|
||||
path = "/sse";
|
||||
break;
|
||||
case TransportType.LongPolling:
|
||||
path = "/poll";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
using (var strm = new MemoryStream())
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
context.Response.Body = strm;
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<ImmediatelyCompleteEndPoint>(options =>
|
||||
{
|
||||
options.Transports = supportedTransports;
|
||||
});
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = path;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("", context);
|
||||
Assert.Equal(status, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
|
||||
// Check the message for 404
|
||||
if (status == 404)
|
||||
{
|
||||
Assert.Equal($"{transportType} transport not supported by this end point type", Encoding.UTF8.GetString(strm.ToArray()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<List<Message>> RunSendTest(string contentType, string encoded, string format)
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
|
|
@ -420,6 +507,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TEndPoint>();
|
||||
services.AddOptions();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = path;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
|
|
|
|||
Loading…
Reference in New Issue