From 0dd29b305002604db19b4a5607f2e622b82d06b3 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Tue, 20 Jun 2017 16:00:32 -0700 Subject: [PATCH] Always close websocket on server (#567) * refactor server websockets transport --- .../WebSocketsTransport.cs | 8 +- .../Transports/WebSocketsTransport.cs | 77 +++---- .../ConnectionManager.cs | 2 +- .../WebSocketsTransportTests.cs | 69 ++++--- .../WebSocketsTests.cs | 188 ++++++++++++------ 5 files changed, 200 insertions(+), 144 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs index 64f9ae376d..8b7030213a 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs @@ -131,11 +131,12 @@ namespace Microsoft.AspNetCore.Sockets.Client } catch (OperationCanceledException) { + _logger.LogDebug("Receive loop canceled"); } finally { - _transportCts.Cancel(); _logger.LogInformation("Receive loop stopped"); + _transportCts.Cancel(); } } @@ -176,11 +177,12 @@ namespace Microsoft.AspNetCore.Sockets.Client } catch (OperationCanceledException) { + _logger.LogDebug("Send loop canceled"); } finally { - _transportCts.Cancel(); _logger.LogInformation("Send loop stopped"); + _transportCts.Cancel(); } } @@ -227,7 +229,7 @@ namespace Microsoft.AspNetCore.Sockets.Client if (_webSocket.State != WebSocketState.Closed) { _logger.LogInformation("Closing webSocket"); - await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, CancellationToken.None); + await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None); } } catch (Exception ex) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs index 45d8ac372e..9f78d2209c 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs @@ -66,69 +66,44 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports receiving, sending); - // What happened? + var failed = trigger.IsCanceled || trigger.IsFaulted; + var task = Task.CompletedTask; if (trigger == receiving) { - if (receiving.IsCanceled || receiving.IsFaulted) - { - // The receiver faulted or cancelled. This means the socket is probably broken. Abort the socket and propagate the exception - receiving.GetAwaiter().GetResult(); - - // Should never get here because GetResult above will throw - Debug.Fail("GetResult didn't throw?"); - return; - } - - // Shutting down because we received a close frame from the client. - // Complete the input writer so that the application knows there won't be any more input. - _logger.ClientClosed(_connectionId, receiving.Result.CloseStatus, receiving.Result.CloseStatusDescription); - _application.Output.TryComplete(); - - // Wait for the application to finish sending. + task = sending; _logger.WaitingForSend(_connectionId); - await sending; - - // Send the server's close frame - await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); } else { - var failed = sending.IsFaulted || _application.Input.Completion.IsFaulted; - - // The application finished sending. Close our end of the connection - if (failed) - { - _logger.FailedSending(_connectionId); - } - else - { - _logger.FinishedSending(_connectionId); - } - await socket.CloseOutputAsync(!failed ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.InternalServerError, "", CancellationToken.None); - - // Now trigger the exception from the application, if there was one. - sending.GetAwaiter().GetResult(); - + task = receiving; _logger.WaitingForClose(_connectionId); - - // Wait for the client to close or wait for the close timeout - var resultTask = await Task.WhenAny(receiving, Task.Delay(_options.CloseTimeout)); - - // We timed out waiting for the transport to close so abort the connection so we don't attempt to write anything else - if (resultTask != receiving) - { - _logger.CloseTimedOut(_connectionId); - socket.Abort(); - } - - // We're done writing - _application.Output.TryComplete(); } + + // We're done writing + _application.Output.TryComplete(); + + await socket.CloseOutputAsync(failed ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + + var resultTask = await Task.WhenAny(task, Task.Delay(_options.CloseTimeout)); + + if (resultTask != task) + { + _logger.CloseTimedOut(_connectionId); + socket.Abort(); + } + else + { + // Observe any exceptions from second completed task + task.GetAwaiter().GetResult(); + } + + // Observe any exceptions from original completed task + trigger.GetAwaiter().GetResult(); } private async Task StartReceiving(WebSocket socket) { - // REVIEW: This code was copied from the client, it's highly unoptimized at the moment (especially + // REVIEW: This code was copied from the client, it's highly unoptimized at the moment (especially // for server logic) var incomingMessage = new List>(); while (true) diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index f9fde07f58..3dcac4856b 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Sockets var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); var connection = new DefaultConnectionContext(id, applicationSide, transportSide); - + _connections.TryAdd(id, connection); return connection; } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index 5947c2533b..44ace932c9 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -8,16 +8,18 @@ using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.Logging.Testing; using Xunit; +using Xunit.Abstractions; namespace Microsoft.AspNetCore.SignalR.Tests { [Collection(EndToEndTestsCollection.Name)] - public class WebSocketsTransportTests + public class WebSocketsTransportTests : LoggedTest { private readonly ServerFixture _serverFixture; - public WebSocketsTransportTests(ServerFixture serverFixture) + public WebSocketsTransportTests(ServerFixture serverFixture, ITestOutputHelper output) : base(output) { if (serverFixture == null) { @@ -27,53 +29,62 @@ namespace Microsoft.AspNetCore.SignalR.Tests _serverFixture = serverFixture; } - [ConditionalFact(Skip = "WebsocketClient.CloseAsync never returns - investigating")] + [ConditionalFact] [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] public async Task WebSocketsTransportStopsSendAndReceiveLoopsWhenTransportIsStopped() { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + using (StartLog(out var loggerFactory)) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - var webSocketsTransport = new WebSocketsTransport(); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection); - await webSocketsTransport.StopAsync(); - await webSocketsTransport.Running.OrTimeout(); + var webSocketsTransport = new WebSocketsTransport(loggerFactory); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection).OrTimeout(); + await webSocketsTransport.StopAsync().OrTimeout(); + await webSocketsTransport.Running.OrTimeout(); + } } [ConditionalFact] [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] public async Task WebSocketsTransportStopsWhenConnectionChannelClosed() { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + using (StartLog(out var loggerFactory)) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - var webSocketsTransport = new WebSocketsTransport(); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection); - connectionToTransport.Out.TryComplete(); - await webSocketsTransport.Running.OrTimeout(); + var webSocketsTransport = new WebSocketsTransport(loggerFactory); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection); + connectionToTransport.Out.TryComplete(); + await webSocketsTransport.Running.OrTimeout(); + } } [ConditionalFact] [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer() { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + using (StartLog(out var loggerFactory)) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - var webSocketsTransport = new WebSocketsTransport(); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection); + var webSocketsTransport = new WebSocketsTransport(loggerFactory); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection); - var sendTcs = new TaskCompletionSource(); - connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs)); - await sendTcs.Task; - // The echo endpoint close the connection immediately after sending response which should stop the transport - await webSocketsTransport.Running.OrTimeout(); + var sendTcs = new TaskCompletionSource(); + connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs)); + await sendTcs.Task; + // The echo endpoint close the connection immediately after sending response which should stop the transport + await webSocketsTransport.Running.OrTimeout(); - Assert.True(transportToConnection.In.TryRead(out var buffer)); - Assert.Equal(new byte[] { 0x42 }, buffer); + Assert.True(transportToConnection.In.TryRead(out var buffer)); + Assert.Equal(new byte[] { 0x42 }, buffer); + } } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index af5c7fdc02..6d76940540 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -25,9 +25,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); - var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); - var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - + using (var transportSide = new ChannelConnection(applicationToTransport, transportToApplication)) + using (var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); @@ -69,9 +68,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); - var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); - var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - + using (var transportSide = new ChannelConnection(applicationToTransport, transportToApplication)) + using (var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var ws = new WebSocketsTransport(new WebSocketOptions() { WebSocketMessageType = webSocketMessageType }, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); @@ -98,60 +96,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } - [Theory] - [InlineData(WebSocketMessageType.Text)] - [InlineData(WebSocketMessageType.Binary)] - public async Task FrameReceivedAfterServerCloseSent(WebSocketMessageType webSocketMessageType) - { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); - - var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); - var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - - using (var feature = new TestWebSocketConnectionFeature()) - { - var ws = new WebSocketsTransport(new WebSocketOptions() { WebSocketMessageType = webSocketMessageType }, transportSide, - connectionId: string.Empty, loggerFactory: new LoggerFactory()); - - // Give the server socket to the transport and run it - var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); - - // Run the client socket - var client = feature.Client.ExecuteAndCaptureFramesAsync(); - - // Close the output and wait for the close frame - Assert.True(applicationSide.Output.Out.TryComplete()); - await client; - - // Send another frame. Then close - await feature.Client.SendAsync( - buffer: new ArraySegment(Encoding.UTF8.GetBytes("Hello")), - endOfMessage: true, - messageType: webSocketMessageType, - cancellationToken: CancellationToken.None); - await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); - - // Read that frame from the input - var buffer = await applicationSide.Input.In.ReadAsync(); - Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); - - await transport; - } - } - [Fact] public async Task TransportFailsWhenClientDisconnectsAbnormally() { var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); - var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); - var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - + using (var transportSide = new ChannelConnection(applicationToTransport, transportToApplication)) + using (var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { - var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var options = new WebSocketOptions() + { + CloseTimeout = TimeSpan.FromSeconds(1) + }; + + var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -163,7 +123,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests feature.Client.Abort(); // Wait for the transport - await Assert.ThrowsAsync(() => transport); + await Assert.ThrowsAsync(() => transport).OrTimeout(); + + var summary = await client.OrTimeout(); + Assert.Equal(WebSocketCloseStatus.InternalServerError, summary.CloseResult.CloseStatus); } } @@ -173,9 +136,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); - var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); - var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - + using (var transportSide = new ChannelConnection(applicationToTransport, transportToApplication)) + using (var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); @@ -188,7 +150,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests // Fail in the app Assert.True(applicationSide.Output.Out.TryComplete(new InvalidOperationException("Catastrophic failure."))); - var clientSummary = await client; + var clientSummary = await client.OrTimeout(); Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.CloseStatus); // Close from the client @@ -205,9 +167,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); - var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); - var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - + using (var transportSide = new ChannelConnection(applicationToTransport, transportToApplication)) + using (var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { var options = new WebSocketOptions() @@ -232,5 +193,112 @@ namespace Microsoft.AspNetCore.Sockets.Tests serverSocket.Dispose(); } } + + [Fact] + public async Task TransportFailsOnTimeoutWithErrorWhenApplicationFailsAndClientDoesNotSendCloseFrame() + { + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + using (var transportSide = new ChannelConnection(applicationToTransport, transportToApplication)) + using (var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) + { + var options = new WebSocketOptions() + { + CloseTimeout = TimeSpan.FromSeconds(1) + }; + + var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + + var serverSocket = await feature.AcceptAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(serverSocket); + + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); + + // fail the client to server channel + applicationToTransport.Out.TryComplete(new Exception()); + + await Assert.ThrowsAsync(() => transport).OrTimeout(); + + Assert.Equal(WebSocketState.Aborted, serverSocket.State); + } + } + + [Fact] + public async Task ServerGracefullyClosesWhenApplicationEndsThenClientSendsCloseFrame() + { + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + using (var transportSide = new ChannelConnection(applicationToTransport, transportToApplication)) + using (var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) + { + var options = new WebSocketOptions() + { + // We want to verify behavior without timeout affecting it + CloseTimeout = TimeSpan.FromSeconds(20) + }; + var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + + var serverSocket = await feature.AcceptAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(serverSocket); + + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); + + // close the client to server channel + applicationToTransport.Out.TryComplete(); + + _ = await client.OrTimeout(); + + await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout(); + + await transport.OrTimeout(); + + Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus); + } + } + + [Fact] + public async Task ServerGracefullyClosesWhenClientSendsCloseFrameThenApplicationEnds() + { + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + using (var transportSide = new ChannelConnection(applicationToTransport, transportToApplication)) + using (var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport)) + using (var feature = new TestWebSocketConnectionFeature()) + { + var options = new WebSocketOptions() + { + // We want to verify behavior without timeout affecting it + CloseTimeout = TimeSpan.FromSeconds(20) + }; + var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + + var serverSocket = await feature.AcceptAsync(); + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(serverSocket); + + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); + + await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout(); + + // close the client to server channel + applicationToTransport.Out.TryComplete(); + + _ = await client.OrTimeout(); + + await transport.OrTimeout(); + + Assert.Equal(WebSocketCloseStatus.NormalClosure, serverSocket.CloseStatus); + } + } } }