Supported transports (#294)

Supported transport spike
- Allow turning transports on or off per end point with a flags enum
- Added `TransportType` to Sockets.Common
- Added tests
This commit is contained in:
David Fowler 2017-03-20 12:23:00 -07:00 committed by GitHub
parent 9659c73e05
commit 63ce7f6160
8 changed files with 133 additions and 17 deletions

View File

@ -3,6 +3,7 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using SocketsSample.EndPoints;
@ -30,6 +31,7 @@ namespace SocketsSample
// .AddRedis();
services.AddEndPoint<MessagesEndPoint>();
services.AddSingleton<ProtobufSerializer>();
}

View File

@ -0,0 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Sockets
{
[Flags]
public enum TransportType
{
WebSockets = 1,
ServerSentEvents = 2,
LongPolling = 4,
All = WebSockets | ServerSentEvents | LongPolling
}
}

View File

@ -8,5 +8,7 @@ namespace Microsoft.AspNetCore.Sockets
public class EndPointOptions<TEndPoint> where TEndPoint : EndPoint
{
public AuthorizationPolicy Policy { get; set; }
public TransportType Transports { get; set; } = TransportType.All;
}
}

View File

@ -15,6 +15,7 @@ using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
namespace Microsoft.AspNetCore.Sockets
@ -36,10 +37,11 @@ namespace Microsoft.AspNetCore.Sockets
{
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
var options = context.RequestServices.GetRequiredService<IOptions<EndPointOptions<TEndPoint>>>().Value;
if (context.Request.Path.StartsWithSegments(path + "/negotiate"))
{
await ProcessNegotiate(context);
await ProcessNegotiate(context, options);
}
else if (context.Request.Path.StartsWithSegments(path + "/send"))
{
@ -47,12 +49,14 @@ namespace Microsoft.AspNetCore.Sockets
}
else
{
await ExecuteEndpointAsync(path, context, endpoint);
await ExecuteEndpointAsync(path, context, endpoint, options);
}
}
private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoint endpoint)
private async Task ExecuteEndpointAsync<TEndPoint>(string path, HttpContext context, EndPoint endpoint, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
{
var supportedTransports = options.Transports;
// Server sent events transport
if (context.Request.Path.StartsWithSegments(path + "/sse"))
{
@ -64,7 +68,7 @@ namespace Microsoft.AspNetCore.Sockets
return;
}
if (!await EnsureConnectionStateAsync(state, context, ServerSentEventsTransport.Name))
if (!await EnsureConnectionStateAsync(state, context, TransportType.ServerSentEvents, supportedTransports))
{
// Bad connection state. It's already set the response status code.
return;
@ -85,7 +89,7 @@ namespace Microsoft.AspNetCore.Sockets
return;
}
if (!await EnsureConnectionStateAsync(state, context, WebSocketsTransport.Name))
if (!await EnsureConnectionStateAsync(state, context, TransportType.WebSockets, supportedTransports))
{
// Bad connection state. It's already set the response status code.
return;
@ -105,7 +109,7 @@ namespace Microsoft.AspNetCore.Sockets
return;
}
if (!await EnsureConnectionStateAsync(state, context, LongPollingTransport.Name))
if (!await EnsureConnectionStateAsync(state, context, TransportType.LongPolling, supportedTransports))
{
// Bad connection state. It's already set the response status code.
return;
@ -158,7 +162,7 @@ namespace Microsoft.AspNetCore.Sockets
{
_logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", state.Connection.ConnectionId, state.RequestId);
state.Connection.Metadata["transport"] = LongPollingTransport.Name;
state.Connection.Metadata["transport"] = TransportType.LongPolling;
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
}
@ -297,7 +301,7 @@ namespace Microsoft.AspNetCore.Sockets
await endpoint.OnConnectedAsync(connection);
}
private Task ProcessNegotiate(HttpContext context)
private Task ProcessNegotiate<TEndPoint>(HttpContext context, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
{
// Establish the connection
var state = CreateConnection(context);
@ -369,16 +373,24 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private async Task<bool> EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, string transportName)
private async Task<bool> EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, TransportType transportType, TransportType supportedTransports)
{
if ((supportedTransports & transportType) == 0)
{
context.Response.StatusCode = StatusCodes.Status404NotFound;
await context.Response.WriteAsync($"{transportType} transport not supported by this end point type");
return false;
}
connectionState.Connection.User = context.User;
var transport = connectionState.Connection.Metadata.Get<string>("transport");
if (string.IsNullOrEmpty(transport))
var transport = connectionState.Connection.Metadata.Get<TransportType?>("transport");
if (transport == null)
{
connectionState.Connection.Metadata["transport"] = transportName;
connectionState.Connection.Metadata["transport"] = transportType;
}
else if (!string.Equals(transport, transportName, StringComparison.Ordinal))
else if (transport != transportType)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.WriteAsync("Cannot change transports mid-connection");

View File

@ -17,7 +17,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
public class LongPollingTransport : IHttpTransport
{
public static readonly string Name = "longPolling";
private readonly ReadableChannel<Message> _application;
private readonly ILogger _logger;

View File

@ -16,7 +16,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
public class ServerSentEventsTransport : IHttpTransport
{
public static readonly string Name = "serverSentEvents";
private readonly ReadableChannel<Message> _application;
private readonly ILogger _logger;

View File

@ -16,8 +16,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
public class WebSocketsTransport : IHttpTransport
{
public static readonly string Name = "webSockets";
private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5);
private static readonly WebSocketAcceptContext _emptyContext = new WebSocketAcceptContext();

View File

@ -32,6 +32,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
var ms = new MemoryStream();
context.Request.Path = "/negotiate";
@ -62,6 +63,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
var values = new Dictionary<string, StringValues>();
@ -90,6 +92,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
context.Response.Body = strm;
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
@ -102,6 +105,40 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
[Theory]
[InlineData(TransportType.LongPolling, 204)]
[InlineData(TransportType.WebSockets, 404)]
[InlineData(TransportType.ServerSentEvents, 404)]
public async Task EndPointThatOnlySupportsLongPollingRejectsOtherTransports(TransportType transportType, int status)
{
await CheckTransportSupported(TransportType.LongPolling, transportType, status);
}
[Theory]
[InlineData(TransportType.ServerSentEvents, 200)]
[InlineData(TransportType.WebSockets, 404)]
[InlineData(TransportType.LongPolling, 404)]
public async Task EndPointThatOnlySupportsSSERejectsOtherTransports(TransportType transportType, int status)
{
await CheckTransportSupported(TransportType.ServerSentEvents, transportType, status);
}
[Theory]
[InlineData(TransportType.WebSockets, 200)]
[InlineData(TransportType.ServerSentEvents, 404)]
[InlineData(TransportType.LongPolling, 404)]
public async Task EndPointThatOnlySupportsWebSockesRejectsOtherTransports(TransportType transportType, int status)
{
await CheckTransportSupported(TransportType.WebSockets, transportType, status);
}
[Theory]
[InlineData(TransportType.LongPolling, 404)]
public async Task EndPointThatOnlySupportsWebSocketsAndSSERejectsLongPolling(TransportType transportType, int status)
{
await CheckTransportSupported(TransportType.WebSockets | TransportType.ServerSentEvents, transportType, status);
}
[Fact]
public async Task CompletedEndPointEndsConnection()
{
@ -387,6 +424,56 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(MessageType.Close, messages[3].Type);
}
private static async Task CheckTransportSupported(TransportType supportedTransports, TransportType transportType, int status)
{
var path = "";
switch (transportType)
{
case TransportType.WebSockets:
path = "/ws";
break;
case TransportType.ServerSentEvents:
path = "/sse";
break;
case TransportType.LongPolling:
path = "/poll";
break;
default:
break;
}
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
using (var strm = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Response.Body = strm;
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<ImmediatelyCompleteEndPoint>(options =>
{
options.Transports = supportedTransports;
});
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("", context);
Assert.Equal(status, context.Response.StatusCode);
await strm.FlushAsync();
// Check the message for 404
if (status == 404)
{
Assert.Equal($"{transportType} transport not supported by this end point type", Encoding.UTF8.GetString(strm.ToArray()));
}
}
}
private static async Task<List<Message>> RunSendTest(string contentType, string encoded, string format)
{
var manager = CreateConnectionManager();
@ -420,6 +507,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddEndPoint<TEndPoint>();
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
var values = new Dictionary<string, StringValues>();