From 9b1cbedffcc48154c7c52655efe67927ae933e32 Mon Sep 17 00:00:00 2001 From: Eric Date: Thu, 13 Jul 2017 17:37:20 +0200 Subject: [PATCH] WebSocketClient.ConnectAsync throws when the provided CancellationToken is cancelled. --- samples/SampleStartups/FakeServer.cs | 2 +- .../WebSocketClient.cs | 21 ++++++-- .../TestClientTests.cs | 50 ++++++++++++++++++- 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/samples/SampleStartups/FakeServer.cs b/samples/SampleStartups/FakeServer.cs index a43790c352..ebdfdb383e 100644 --- a/samples/SampleStartups/FakeServer.cs +++ b/samples/SampleStartups/FakeServer.cs @@ -22,7 +22,7 @@ namespace SampleStartups } } - public static class FakeServerWebHostBuliderExtensions + public static class FakeServerWebHostBuilderExtensions { public static IWebHostBuilder UseFakeServer(this IWebHostBuilder builder) { diff --git a/src/Microsoft.AspNetCore.TestHost/WebSocketClient.cs b/src/Microsoft.AspNetCore.TestHost/WebSocketClient.cs index 206c5a6980..3d267f3a1b 100644 --- a/src/Microsoft.AspNetCore.TestHost/WebSocketClient.cs +++ b/src/Microsoft.AspNetCore.TestHost/WebSocketClient.cs @@ -87,6 +87,7 @@ namespace Microsoft.AspNetCore.TestHost { private readonly IHttpApplication _application; private TaskCompletionSource _clientWebSocketTcs; + private CancellationTokenRegistration _cancellationTokenRegistration; private WebSocket _serverWebSocket; public Context Context { get; private set; } @@ -95,6 +96,8 @@ namespace Microsoft.AspNetCore.TestHost public RequestState(Uri uri, PathString pathBase, CancellationToken cancellationToken, IHttpApplication application) { _clientWebSocketTcs = new TaskCompletionSource(); + _cancellationTokenRegistration = cancellationToken.Register( + () => _clientWebSocketTcs.TrySetCanceled(cancellationToken)); _application = application; // HttpContext @@ -181,12 +184,20 @@ namespace Microsoft.AspNetCore.TestHost Task IHttpWebSocketFeature.AcceptAsync(WebSocketAcceptContext context) { - Context.HttpContext.Response.StatusCode = 101; // Switching Protocols - var websockets = TestWebSocket.CreatePair(context.SubProtocol); - _clientWebSocketTcs.SetResult(websockets.Item1); - _serverWebSocket = websockets.Item2; - return Task.FromResult(_serverWebSocket); + if (_clientWebSocketTcs.TrySetResult(websockets.Item1)) + { + Context.HttpContext.Response.StatusCode = StatusCodes.Status101SwitchingProtocols; + _serverWebSocket = websockets.Item2; + return Task.FromResult(_serverWebSocket); + } + else + { + Context.HttpContext.Response.StatusCode = StatusCodes.Status500InternalServerError; + websockets.Item1.Dispose(); + websockets.Item2.Dispose(); + return _clientWebSocketTcs.Task; // Canceled or Faulted - no result + } } } } diff --git a/test/Microsoft.AspNetCore.TestHost.Tests/TestClientTests.cs b/test/Microsoft.AspNetCore.TestHost.Tests/TestClientTests.cs index 283d60ccd9..748e79c841 100644 --- a/test/Microsoft.AspNetCore.TestHost.Tests/TestClientTests.cs +++ b/test/Microsoft.AspNetCore.TestHost.Tests/TestClientTests.cs @@ -11,7 +11,6 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Hosting.Internal; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.DependencyInjection; @@ -173,7 +172,7 @@ namespace Microsoft.AspNetCore.TestHost public async Task WebSocketWorks() { // Arrange - // This logger will attempt to access information from HttpRequest once the HttpContext is createds + // This logger will attempt to access information from HttpRequest once the HttpContext is created var logger = new VerifierLogger(); RequestDelegate appDelegate = async ctx => { @@ -239,6 +238,53 @@ namespace Microsoft.AspNetCore.TestHost clientSocket.Dispose(); } + [ConditionalFact] + public async Task WebSocketAcceptThrowsWhenCancelled() + { + // Arrange + // This logger will attempt to access information from HttpRequest once the HttpContext is created + var logger = new VerifierLogger(); + RequestDelegate appDelegate = async ctx => + { + if (ctx.WebSockets.IsWebSocketRequest) + { + var websocket = await ctx.WebSockets.AcceptWebSocketAsync(); + var receiveArray = new byte[1024]; + while (true) + { + var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment(receiveArray), CancellationToken.None); + if (receiveResult.MessageType == WebSocketMessageType.Close) + { + await websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None); + break; + } + else + { + var sendBuffer = new System.ArraySegment(receiveArray, 0, receiveResult.Count); + await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None); + } + } + } + }; + var builder = new WebHostBuilder() + .ConfigureServices(services => + { + services.AddSingleton>(logger); + }) + .Configure(app => + { + app.Run(appDelegate); + }); + var server = new TestServer(builder); + + // Act + var client = server.CreateWebSocketClient(); + var tokenSource = new CancellationTokenSource(); + tokenSource.Cancel(); + + // Assert + await Assert.ThrowsAnyAsync(async () => await client.ConnectAsync(new System.Uri("http://localhost"), tokenSource.Token)); + } private class VerifierLogger : ILogger {