diff --git a/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs b/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs index d00b14a1a4..8185d616b5 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using Microsoft.AspNetCore.Authorization; namespace Microsoft.AspNetCore.Sockets @@ -10,5 +11,7 @@ namespace Microsoft.AspNetCore.Sockets public AuthorizationPolicy Policy { get; set; } public TransportType Transports { get; set; } = TransportType.All; + + public WebSocketOptions WebSockets { get; } = new WebSocketOptions(); } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 0a58e615ff..a21280e839 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -100,7 +100,7 @@ namespace Microsoft.AspNetCore.Sockets return; } - var ws = new WebSocketsTransport(state.Application, _loggerFactory); + var ws = new WebSocketsTransport(options.WebSockets, state.Application, _loggerFactory); await DoPersistentConnection(endpoint, ws, context, state); } diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index 7714d14cff..6e3bd605bb 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -80,7 +80,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal Lock.Release(); } - // REVIEW: Add a timeout so we don't wait forever await disposeTask; } diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs index ad409fada1..1a3fe2519b 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs @@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports { public class WebSocketsTransport : IHttpTransport { - private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5); + private readonly WebSocketOptions _options; private static readonly WebSocketAcceptContext _emptyContext = new WebSocketAcceptContext(); private WebSocketOpcode _lastOpcode = WebSocketOpcode.Continuation; @@ -25,17 +25,24 @@ namespace Microsoft.AspNetCore.Sockets.Transports private readonly ILogger _logger; private readonly IChannelConnection _application; - public WebSocketsTransport(IChannelConnection application, ILoggerFactory loggerFactory) + public WebSocketsTransport(WebSocketOptions options, IChannelConnection application, ILoggerFactory loggerFactory) { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + if (application == null) { throw new ArgumentNullException(nameof(application)); } + if (loggerFactory == null) { throw new ArgumentNullException(nameof(loggerFactory)); } + _options = options; _application = application; _logger = loggerFactory.CreateLogger(); } @@ -107,9 +114,17 @@ namespace Microsoft.AspNetCore.Sockets.Transports _logger.LogDebug("Waiting for the client to close the socket"); - // Wait for the client to close. - // TODO: Consider timing out here and cancelling the receive loop. - await receiving; + // Wait for the client to close or wait for the close timeout + var resultTask = await Task.WhenAny(receiving, Task.Delay(_options.CloseTimeout)); + + // We timed out waiting for the transport to close so abort the connection so we don't attempt to write anything else + if (resultTask != receiving) + { + _logger.LogDebug("Timed out waiting for client to send the close frame, aborting the connection."); + socket.Abort(); + } + + // We're done writing _application.Output.TryComplete(); } } diff --git a/src/Microsoft.AspNetCore.Sockets/WebSocketOptions.cs b/src/Microsoft.AspNetCore.Sockets/WebSocketOptions.cs new file mode 100644 index 0000000000..0157392996 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/WebSocketOptions.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Sockets +{ + public class WebSocketOptions + { + public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5); + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 7b80480a84..f55d813f49 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -217,6 +217,58 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.False(exists); } + [Theory(Skip = "Timeouts have not been implemented as yet")] + [InlineData("/ws", true)] + [InlineData("/sse", false)] + [InlineData("/poll", false)] + public async Task NeverEndingEndPointCompletesWithTimeoutWhenTransportCloses(string path, bool isWebSocketRequest) + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest(path, state, isWebSocketRequest: isWebSocketRequest); + + var task = dispatcher.ExecuteAsync("", context); + var webSocketTask = Task.CompletedTask; + + Assert.False(task.IsCompleted); + + if (isWebSocketRequest) + { + var ws = (TestWebSocketConnectionFeature)context.Features.Get(); + webSocketTask = ws.Client.ExecuteAsync(frame => Task.CompletedTask); + await ws.Client.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure), CancellationToken.None); + } + + // Shut the application down so the transport begins to unwind + state.Application.Dispose(); + + // Make sure the transport unwinds + await state.TransportTask.OrTimeout(); + + await webSocketTask.OrTimeout(); + + // The task should be cancelled because of the timeout + await Assert.ThrowsAsync(async () => await task.OrTimeout()); + } + + [Fact] + public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/ws", state, isWebSocketRequest: true); + + var task = dispatcher.ExecuteAsync("", context); + + await task.OrTimeout(); + } + [Theory] [InlineData("/ws", true)] [InlineData("/sse", false)] @@ -519,7 +571,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests { var context = new DefaultHttpContext(); var services = new ServiceCollection(); - services.AddEndPoint(); + services.AddEndPoint(o => + { + // Make the close timeout less than the default for OrTimeout() test helper + o.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1); + }); + services.AddOptions(); context.RequestServices = services.BuildServiceProvider(); context.Request.Path = path; @@ -558,6 +615,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } + public class NerverEndingEndPoint : EndPoint + { + public override Task OnConnectedAsync(Connection connection) + { + var tcs = new TaskCompletionSource(); + return tcs.Task; + } + } + public class BlockingEndPoint : EndPoint { public override Task OnConnectedAsync(Connection connection) diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index 75b92e3514..47a89eda83 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipeFactory()) using (var pair = WebSocketPair.Create(factory)) { - var ws = new WebSocketsTransport(transportSide, new LoggerFactory()); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(pair.ServerSocket); @@ -78,7 +78,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipeFactory()) using (var pair = WebSocketPair.Create(factory)) { - var ws = new WebSocketsTransport(transportSide, new LoggerFactory()); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(pair.ServerSocket); @@ -133,7 +133,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipeFactory()) using (var pair = WebSocketPair.Create(factory)) { - var ws = new WebSocketsTransport(transportSide, new LoggerFactory()); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(pair.ServerSocket); @@ -181,7 +181,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipeFactory()) using (var pair = WebSocketPair.Create(factory)) { - var ws = new WebSocketsTransport(transportSide, new LoggerFactory()); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(pair.ServerSocket); @@ -222,7 +222,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipeFactory()) using (var pair = WebSocketPair.Create(factory)) { - var ws = new WebSocketsTransport(transportSide, new LoggerFactory()); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(pair.ServerSocket); @@ -263,7 +263,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipeFactory()) using (var pair = WebSocketPair.Create(factory)) { - var ws = new WebSocketsTransport(transportSide, new LoggerFactory()); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(pair.ServerSocket); @@ -291,7 +291,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipeFactory()) using (var pair = WebSocketPair.Create(factory)) { - var ws = new WebSocketsTransport(transportSide, new LoggerFactory()); + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(pair.ServerSocket); @@ -310,5 +310,42 @@ namespace Microsoft.AspNetCore.Sockets.Tests await transport.OrTimeout(); } } + + [Fact] + public async Task TransportClosesOnCloseTimeoutIfClientDoesNotSendCloseFrame() + { + var transportToApplication = Channel.CreateUnbounded(); + var applicationToTransport = Channel.CreateUnbounded(); + + var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); + var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); + + using (var factory = new PipeFactory()) + using (var pair = WebSocketPair.Create(factory)) + { + var options = new WebSocketOptions() + { + CloseTimeout = TimeSpan.FromSeconds(1) + }; + + var ws = new WebSocketsTransport(options, transportSide, new LoggerFactory()); + + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(pair.ServerSocket); + + // End the app + applicationSide.Dispose(); + + await transport.OrTimeout(); + + // We're still in the closed sent state since the client never sent the close frame + Assert.Equal(WebSocketConnectionState.CloseSent, pair.ServerSocket.State); + + pair.ServerSocket.Dispose(); + + // Now we're closed + Assert.Equal(WebSocketConnectionState.Closed, pair.ServerSocket.State); + } + } } }