Support for binary protocols on the server side

This commit is contained in:
Pawel Kadluczka 2017-07-13 17:30:24 -07:00
parent 8fc2cd98b6
commit ae815475b8
10 changed files with 94 additions and 47 deletions

View File

@ -5,7 +5,6 @@ using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR

View File

@ -8,5 +8,7 @@ namespace Microsoft.AspNetCore.Sockets.Features
public interface IConnectionTransportFeature
{
Channel<byte[]> Transport { get; set; }
TransferMode TransportCapabilities { get; set; }
}
}

View File

@ -1,10 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Text;
namespace Microsoft.AspNetCore.Sockets.Features
{

View File

@ -89,6 +89,9 @@ namespace Microsoft.AspNetCore.Sockets
_logger.EstablishedConnection(connection.ConnectionId, context.TraceIdentifier);
// ServerSentEvents is a text protocol only
connection.TransportCapabilities = TransferMode.Text;
// We only need to provide the Input channel since writing to the application is handled through /send.
var sse = new ServerSentEventsTransport(connection.Application.In, connection.ConnectionId, _loggerFactory);
@ -112,7 +115,7 @@ namespace Microsoft.AspNetCore.Sockets
_logger.EstablishedConnection(connection.ConnectionId, context.TraceIdentifier);
var ws = new WebSocketsTransport(options.WebSockets, connection.Application, connection.ConnectionId, _loggerFactory);
var ws = new WebSocketsTransport(options.WebSockets, connection.Application, connection, _loggerFactory);
await DoPersistentConnection(socketDelegate, ws, context, connection);
}
@ -330,7 +333,7 @@ namespace Microsoft.AspNetCore.Sockets
// Establish the connection
var connection = _manager.CreateConnection();
// Set the Connection ID on the logging scope so that logs from now on will have the
// Connection ID metadata set.
logScope.ConnectionId = connection.ConnectionId;
@ -433,6 +436,9 @@ namespace Microsoft.AspNetCore.Sockets
connection.User = context.User;
connection.SetHttpContext(context);
// this is the default setting which should be overwritten by transports that have different capabilities (e.g. SSE)
connection.TransportCapabilities = TransferMode.Binary | TransferMode.Text;
// Set the Connection ID on the logging scope so that logs from now on will have the
// Connection ID metadata set.
logScope.ConnectionId = connection.ConnectionId;

View File

@ -18,9 +18,9 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
private readonly WebSocketOptions _options;
private readonly ILogger _logger;
private readonly Channel<byte[]> _application;
private readonly string _connectionId;
private readonly DefaultConnectionContext _connection;
public WebSocketsTransport(WebSocketOptions options, Channel<byte[]> application, string connectionId, ILoggerFactory loggerFactory)
public WebSocketsTransport(WebSocketOptions options, Channel<byte[]> application, DefaultConnectionContext connection, ILoggerFactory loggerFactory)
{
if (options == null)
{
@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
_options = options;
_application = application;
_connectionId = connectionId;
_connection = connection;
_logger = loggerFactory.CreateLogger<WebSocketsTransport>();
}
@ -49,11 +49,11 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
using (var ws = await context.WebSockets.AcceptWebSocketAsync())
{
_logger.SocketOpened(_connectionId);
_logger.SocketOpened(_connection.ConnectionId);
await ProcessSocketAsync(ws);
}
_logger.SocketClosed(_connectionId);
_logger.SocketClosed(_connection.ConnectionId);
}
public async Task ProcessSocketAsync(WebSocket socket)
@ -72,12 +72,12 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
if (trigger == receiving)
{
task = sending;
_logger.WaitingForSend(_connectionId);
_logger.WaitingForSend(_connection.ConnectionId);
}
else
{
task = receiving;
_logger.WaitingForClose(_connectionId);
_logger.WaitingForClose(_connection.ConnectionId);
}
// We're done writing
@ -89,7 +89,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
if (resultTask != task)
{
_logger.CloseTimedOut(_connectionId);
_logger.CloseTimedOut(_connection.ConnectionId);
socket.Abort();
}
else
@ -123,7 +123,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
return receiveResult;
}
_logger.MessageReceived(_connectionId, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage);
_logger.MessageReceived(_connection.ConnectionId, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage);
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
incomingMessage.Add(truncBuffer);
@ -153,7 +153,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, messageBuffer, 0, incomingMessage[0].Count);
}
_logger.MessageToApplication(_connectionId, messageBuffer.Length);
_logger.MessageToApplication(_connection.ConnectionId, messageBuffer.Length);
while (await _application.Out.WaitToWriteAsync())
{
if (_application.Out.TryWrite(messageBuffer))
@ -176,22 +176,26 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
{
try
{
_logger.SendPayload(_connectionId, buffer.Length);
_logger.SendPayload(_connection.ConnectionId, buffer.Length);
var webSocketMessageType = (_connection.TransferMode == TransferMode.Binary
? WebSocketMessageType.Binary
: WebSocketMessageType.Text);
if (WebSocketCanSend(ws))
{
await ws.SendAsync(new ArraySegment<byte>(buffer), _options.WebSocketMessageType, endOfMessage: true, cancellationToken: CancellationToken.None);
await ws.SendAsync(new ArraySegment<byte>(buffer), webSocketMessageType, endOfMessage: true, cancellationToken: CancellationToken.None);
}
}
catch (WebSocketException socketException) when (!WebSocketCanSend(ws))
{
// this can happen when we send the CloseFrame to the client and try to write afterwards
_logger.SendFailed(_connectionId, socketException);
_logger.SendFailed(_connection.ConnectionId, socketException);
break;
}
catch (Exception ex)
{
_logger.ErrorWritingFrame(_connectionId, ex);
_logger.ErrorWritingFrame(_connection.ConnectionId, ex);
break;
}
}

View File

@ -2,14 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Net.WebSockets;
namespace Microsoft.AspNetCore.Sockets
{
public class WebSocketOptions
{
public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5);
public WebSocketMessageType WebSocketMessageType { get; set; } = WebSocketMessageType.Text;
}
}

View File

@ -16,7 +16,8 @@ namespace Microsoft.AspNetCore.Sockets
IConnectionIdFeature,
IConnectionMetadataFeature,
IConnectionTransportFeature,
IConnectionUserFeature
IConnectionUserFeature,
ITransferModeFeature
{
// This tcs exists so that multiple calls to DisposeAsync all wait asynchronously
// on the same task
@ -35,6 +36,7 @@ namespace Microsoft.AspNetCore.Sockets
Features.Set<IConnectionMetadataFeature>(this);
Features.Set<IConnectionIdFeature>(this);
Features.Set<IConnectionTransportFeature>(this);
Features.Set<ITransferModeFeature>(this);
}
public CancellationTokenSource Cancellation { get; set; }
@ -61,6 +63,10 @@ namespace Microsoft.AspNetCore.Sockets
public override Channel<byte[]> Transport { get; set; }
public TransferMode TransportCapabilities { get; set; }
public TransferMode TransferMode { get; set; }
public async Task DisposeAsync()
{
Task disposeTask = Task.CompletedTask;

View File

@ -14,7 +14,7 @@
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Client\Microsoft.AspNetCore.SignalR.Client.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.Sockets.Client.Http\Microsoft.AspNetCore.Sockets.Client.Http.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.Sockets.Client.Http\Microsoft.AspNetCore.Sockets.Client.Http.csproj" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Testing" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(TestSdkVersion)" />

View File

@ -118,7 +118,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
[Fact]
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvidedForPost()
{
@ -584,6 +583,33 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal("Hello, World", GetContentAsString(context2.Response.Body));
}
[Theory]
[InlineData(TransportType.LongPolling, TransferMode.Binary | TransferMode.Text)]
[InlineData(TransportType.ServerSentEvents, TransferMode.Text)]
[InlineData(TransportType.WebSockets, TransferMode.Binary | TransferMode.Text)]
public async Task TransportCapabilitiesSet(TransportType transportType, TransferMode expectedTransportCapabilities)
{
var manager = CreateConnectionManager();
var connection = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", connection);
SetTransport(context, transportType);
var services = new ServiceCollection();
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(0);
await dispatcher.ExecuteAsync(context, options, app);
Assert.Equal(expectedTransportCapabilities, connection.TransportCapabilities);
}
[Fact]
public async Task UnauthorizedConnectionFailsToStartEndPoint()
{
@ -599,7 +625,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddAuthenticationCore(o =>
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
@ -641,7 +667,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddAuthenticationCore(o =>
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
@ -690,7 +716,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
@ -747,7 +773,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
@ -822,7 +848,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
});
services.AddAuthorizationPolicyEvaluator();
services.AddLogging();
services.AddAuthenticationCore(o =>
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
@ -875,7 +901,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
});
services.AddAuthorizationPolicyEvaluator();
services.AddLogging();
services.AddAuthenticationCore(o =>
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(RejectHandler));

View File

@ -29,7 +29,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -61,9 +62,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
[Theory]
[InlineData(WebSocketMessageType.Text)]
[InlineData(WebSocketMessageType.Binary)]
public async Task DataWrittenToOutputPipelineAreSentAsFrames(WebSocketMessageType webSocketMessageType)
[InlineData(TransferMode.Text, WebSocketMessageType.Text)]
[InlineData(TransferMode.Binary, WebSocketMessageType.Binary)]
public async Task WebSocketTransportSetsMessageTypeBasedOnTransferModeFeature(TransferMode transferMode, WebSocketMessageType expectedMessageType)
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
@ -72,7 +73,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions() { WebSocketMessageType = webSocketMessageType }, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var connectionContext = new DefaultConnectionContext(string.Empty, null, null) { TransferMode = transferMode };
var ws = new WebSocketsTransport(new WebSocketOptions(),
transportSide, connectionContext, loggerFactory: new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -91,7 +94,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(1, clientSummary.Received.Count);
Assert.True(clientSummary.Received[0].EndOfMessage);
Assert.Equal(webSocketMessageType, clientSummary.Received[0].MessageType);
Assert.Equal(expectedMessageType, clientSummary.Received[0].MessageType);
Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Buffer));
}
}
@ -115,7 +118,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
applicationSide.Out.TryComplete();
}
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -148,7 +152,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory: new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -184,7 +189,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
CloseTimeout = TimeSpan.FromSeconds(1)
};
var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory());
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
@ -212,12 +218,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
var options = new WebSocketOptions
{
CloseTimeout = TimeSpan.FromSeconds(1)
};
var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory());
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
@ -245,12 +252,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
var options = new WebSocketOptions
{
// We want to verify behavior without timeout affecting it
CloseTimeout = TimeSpan.FromSeconds(20)
};
var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory());
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
@ -282,12 +291,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
var options = new WebSocketOptions
{
// We want to verify behavior without timeout affecting it
CloseTimeout = TimeSpan.FromSeconds(20)
};
var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory: new LoggerFactory());
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it