Always close websocket on server (#567)

* refactor server websockets transport
This commit is contained in:
BrennanConroy 2017-06-20 16:00:32 -07:00 committed by GitHub
parent d169b96d2d
commit 0dd29b3050
5 changed files with 200 additions and 144 deletions

View File

@ -131,11 +131,12 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
catch (OperationCanceledException)
{
_logger.LogDebug("Receive loop canceled");
}
finally
{
_transportCts.Cancel();
_logger.LogInformation("Receive loop stopped");
_transportCts.Cancel();
}
}
@ -176,11 +177,12 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
catch (OperationCanceledException)
{
_logger.LogDebug("Send loop canceled");
}
finally
{
_transportCts.Cancel();
_logger.LogInformation("Send loop stopped");
_transportCts.Cancel();
}
}
@ -227,7 +229,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
if (_webSocket.State != WebSocketState.Closed)
{
_logger.LogInformation("Closing webSocket");
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, CancellationToken.None);
await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None);
}
}
catch (Exception ex)

View File

@ -66,69 +66,44 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
receiving,
sending);
// What happened?
var failed = trigger.IsCanceled || trigger.IsFaulted;
var task = Task.CompletedTask;
if (trigger == receiving)
{
if (receiving.IsCanceled || receiving.IsFaulted)
{
// The receiver faulted or cancelled. This means the socket is probably broken. Abort the socket and propagate the exception
receiving.GetAwaiter().GetResult();
// Should never get here because GetResult above will throw
Debug.Fail("GetResult didn't throw?");
return;
}
// Shutting down because we received a close frame from the client.
// Complete the input writer so that the application knows there won't be any more input.
_logger.ClientClosed(_connectionId, receiving.Result.CloseStatus, receiving.Result.CloseStatusDescription);
_application.Output.TryComplete();
// Wait for the application to finish sending.
task = sending;
_logger.WaitingForSend(_connectionId);
await sending;
// Send the server's close frame
await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
}
else
{
var failed = sending.IsFaulted || _application.Input.Completion.IsFaulted;
// The application finished sending. Close our end of the connection
if (failed)
{
_logger.FailedSending(_connectionId);
}
else
{
_logger.FinishedSending(_connectionId);
}
await socket.CloseOutputAsync(!failed ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.InternalServerError, "", CancellationToken.None);
// Now trigger the exception from the application, if there was one.
sending.GetAwaiter().GetResult();
task = receiving;
_logger.WaitingForClose(_connectionId);
// Wait for the client to close or wait for the close timeout
var resultTask = await Task.WhenAny(receiving, Task.Delay(_options.CloseTimeout));
// We timed out waiting for the transport to close so abort the connection so we don't attempt to write anything else
if (resultTask != receiving)
{
_logger.CloseTimedOut(_connectionId);
socket.Abort();
}
// We're done writing
_application.Output.TryComplete();
}
// We're done writing
_application.Output.TryComplete();
await socket.CloseOutputAsync(failed ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
var resultTask = await Task.WhenAny(task, Task.Delay(_options.CloseTimeout));
if (resultTask != task)
{
_logger.CloseTimedOut(_connectionId);
socket.Abort();
}
else
{
// Observe any exceptions from second completed task
task.GetAwaiter().GetResult();
}
// Observe any exceptions from original completed task
trigger.GetAwaiter().GetResult();
}
private async Task<WebSocketReceiveResult> StartReceiving(WebSocket socket)
{
// REVIEW: This code was copied from the client, it's highly unoptimized at the moment (especially
// REVIEW: This code was copied from the client, it's highly unoptimized at the moment (especially
// for server logic)
var incomingMessage = new List<ArraySegment<byte>>();
while (true)

View File

@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Sockets
var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport);
var connection = new DefaultConnectionContext(id, applicationSide, transportSide);
_connections.TryAdd(id, connection);
return connection;
}

View File

@ -8,16 +8,18 @@ using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
using Xunit.Abstractions;
namespace Microsoft.AspNetCore.SignalR.Tests
{
[Collection(EndToEndTestsCollection.Name)]
public class WebSocketsTransportTests
public class WebSocketsTransportTests : LoggedTest
{
private readonly ServerFixture _serverFixture;
public WebSocketsTransportTests(ServerFixture serverFixture)
public WebSocketsTransportTests(ServerFixture serverFixture, ITestOutputHelper output) : base(output)
{
if (serverFixture == null)
{
@ -27,53 +29,62 @@ namespace Microsoft.AspNetCore.SignalR.Tests
_serverFixture = serverFixture;
}
[ConditionalFact(Skip = "WebsocketClient.CloseAsync never returns - investigating")]
[ConditionalFact]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
public async Task WebSocketsTransportStopsSendAndReceiveLoopsWhenTransportIsStopped()
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
using (StartLog(out var loggerFactory))
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var webSocketsTransport = new WebSocketsTransport();
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
await webSocketsTransport.StopAsync();
await webSocketsTransport.Running.OrTimeout();
var webSocketsTransport = new WebSocketsTransport(loggerFactory);
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection).OrTimeout();
await webSocketsTransport.StopAsync().OrTimeout();
await webSocketsTransport.Running.OrTimeout();
}
}
[ConditionalFact]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
public async Task WebSocketsTransportStopsWhenConnectionChannelClosed()
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
using (StartLog(out var loggerFactory))
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var webSocketsTransport = new WebSocketsTransport();
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
connectionToTransport.Out.TryComplete();
await webSocketsTransport.Running.OrTimeout();
var webSocketsTransport = new WebSocketsTransport(loggerFactory);
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
connectionToTransport.Out.TryComplete();
await webSocketsTransport.Running.OrTimeout();
}
}
[ConditionalFact]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer()
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
using (StartLog(out var loggerFactory))
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<byte[]>();
var channelConnection = new ChannelConnection<SendMessage, byte[]>(connectionToTransport, transportToConnection);
var webSocketsTransport = new WebSocketsTransport();
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
var webSocketsTransport = new WebSocketsTransport(loggerFactory);
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
var sendTcs = new TaskCompletionSource<object>();
connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs));
await sendTcs.Task;
// The echo endpoint close the connection immediately after sending response which should stop the transport
await webSocketsTransport.Running.OrTimeout();
var sendTcs = new TaskCompletionSource<object>();
connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs));
await sendTcs.Task;
// The echo endpoint close the connection immediately after sending response which should stop the transport
await webSocketsTransport.Running.OrTimeout();
Assert.True(transportToConnection.In.TryRead(out var buffer));
Assert.Equal(new byte[] { 0x42 }, buffer);
Assert.True(transportToConnection.In.TryRead(out var buffer));
Assert.Equal(new byte[] { 0x42 }, buffer);
}
}
}
}

View File

@ -25,9 +25,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport);
using (var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
@ -69,9 +68,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport);
using (var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions() { WebSocketMessageType = webSocketMessageType }, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
@ -98,60 +96,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
[Theory]
[InlineData(WebSocketMessageType.Text)]
[InlineData(WebSocketMessageType.Binary)]
public async Task FrameReceivedAfterServerCloseSent(WebSocketMessageType webSocketMessageType)
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport);
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions() { WebSocketMessageType = webSocketMessageType }, transportSide,
connectionId: string.Empty, loggerFactory: new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
// Run the client socket
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Close the output and wait for the close frame
Assert.True(applicationSide.Output.Out.TryComplete());
await client;
// Send another frame. Then close
await feature.Client.SendAsync(
buffer: new ArraySegment<byte>(Encoding.UTF8.GetBytes("Hello")),
endOfMessage: true,
messageType: webSocketMessageType,
cancellationToken: CancellationToken.None);
await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
// Read that frame from the input
var buffer = await applicationSide.Input.In.ReadAsync();
Assert.Equal("Hello", Encoding.UTF8.GetString(buffer));
await transport;
}
}
[Fact]
public async Task TransportFailsWhenClientDisconnectsAbnormally()
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport);
using (var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var options = new WebSocketOptions()
{
CloseTimeout = TimeSpan.FromSeconds(1)
};
var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -163,7 +123,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
feature.Client.Abort();
// Wait for the transport
await Assert.ThrowsAsync<OperationCanceledException>(() => transport);
await Assert.ThrowsAsync<OperationCanceledException>(() => transport).OrTimeout();
var summary = await client.OrTimeout();
Assert.Equal(WebSocketCloseStatus.InternalServerError, summary.CloseResult.CloseStatus);
}
}
@ -173,9 +136,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport);
using (var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
@ -188,7 +150,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// Fail in the app
Assert.True(applicationSide.Output.Out.TryComplete(new InvalidOperationException("Catastrophic failure.")));
var clientSummary = await client;
var clientSummary = await client.OrTimeout();
Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.CloseStatus);
// Close from the client
@ -205,9 +167,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport);
using (var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
@ -232,5 +193,112 @@ namespace Microsoft.AspNetCore.Sockets.Tests
serverSocket.Dispose();
}
}
[Fact]
public async Task TransportFailsOnTimeoutWithErrorWhenApplicationFailsAndClientDoesNotSendCloseFrame()
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
using (var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
{
CloseTimeout = TimeSpan.FromSeconds(1)
};
var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(serverSocket);
// Run the client socket
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// fail the client to server channel
applicationToTransport.Out.TryComplete(new Exception());
await Assert.ThrowsAsync<Exception>(() => transport).OrTimeout();
Assert.Equal(WebSocketState.Aborted, serverSocket.State);
}
}
[Fact]
public async Task ServerGracefullyClosesWhenApplicationEndsThenClientSendsCloseFrame()
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
using (var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
{
// We want to verify behavior without timeout affecting it
CloseTimeout = TimeSpan.FromSeconds(20)
};
var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(serverSocket);
// Run the client socket
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// close the client to server channel
applicationToTransport.Out.TryComplete();
_ = await client.OrTimeout();
await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout();
await transport.OrTimeout();
Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus);
}
}
[Fact]
public async Task ServerGracefullyClosesWhenClientSendsCloseFrameThenApplicationEnds()
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
using (var transportSide = new ChannelConnection<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = new ChannelConnection<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
{
// We want to verify behavior without timeout affecting it
CloseTimeout = TimeSpan.FromSeconds(20)
};
var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(serverSocket);
// Run the client socket
var client = feature.Client.ExecuteAndCaptureFramesAsync();
await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout();
// close the client to server channel
applicationToTransport.Out.TryComplete();
_ = await client.OrTimeout();
await transport.OrTimeout();
Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus);
}
}
}
}