From 8fc2cd98b613b5562be4d9b4045b9bb1b2d5683d Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Wed, 19 Jul 2017 11:47:47 -0700 Subject: [PATCH] Add timeout to Event Queue drain (#619) --- .../HttpConnection.cs | 3 + .../Internal/TaskExtensions.cs | 29 +++++ .../HttpConnectionTests.cs | 102 ++++++++++++++++++ 3 files changed, 134 insertions(+) create mode 100644 src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskExtensions.cs diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 5ddf582865..30f9041a83 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -11,6 +11,7 @@ using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Client.Internal; +using Microsoft.AspNetCore.Sockets.Http.Internal; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -32,6 +33,7 @@ namespace Microsoft.AspNetCore.Sockets.Client private TaskQueue _eventQueue = new TaskQueue(); private readonly ITransportFactory _transportFactory; private string _connectionId; + private readonly TimeSpan _eventQueueDrainTimeout = TimeSpan.FromSeconds(5); private ReadableChannel Input => _transportChannel.In; private WritableChannel Output => _transportChannel.Out; @@ -173,6 +175,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger.DrainEvents(_connectionId); await _eventQueue.Drain(); + await Task.WhenAny(_eventQueue.Drain().NoThrow(), Task.Delay(_eventQueueDrainTimeout)); _httpClient.Dispose(); _logger.RaiseClosed(_connectionId); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskExtensions.cs new file mode 100644 index 0000000000..abd244c6bd --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskExtensions.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets.Http.Internal +{ + public static class TaskExtensions + { + public static async Task NoThrow(this Task task) + { + await new NoThrowAwaiter(task); + } + } + + internal struct NoThrowAwaiter : ICriticalNotifyCompletion + { + private readonly Task _task; + public NoThrowAwaiter(Task task) { _task = task; } + public NoThrowAwaiter GetAwaiter() => this; + public bool IsCompleted => _task.IsCompleted; + // Observe exception + public void GetResult() { _ = _task.Exception; } + public void OnCompleted(Action continuation) => _task.GetAwaiter().OnCompleted(continuation); + public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation); + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 97a8722cda..c71f2794b8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -433,6 +433,108 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests await connection.DisposeAsync(); } + [Fact] + public async Task EventQueueTimeout() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + var mockTransport = new Mock(); + Channel channel = null; + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string>((url, c, transferMode, connectionId) => + { + channel = c; + return Task.CompletedTask; + }); + mockTransport.Setup(t => t.StopAsync()) + .Returns(() => + { + channel.Out.TryComplete(); + return Task.CompletedTask; + }); + mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); + + var blockReceiveCallbackTcs = new TaskCompletionSource(); + var closedTcs = new TaskCompletionSource(); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + connection.Received += + async m => + { + await blockReceiveCallbackTcs.Task; + }; + connection.Closed += _ => { + closedTcs.SetResult(null); + return Task.CompletedTask; + }; + + await connection.StartAsync(); + channel.Out.TryWrite(Array.Empty()); + + // Ensure that SignalR isn't blocked by the receive callback + Assert.False(channel.In.TryRead(out var message)); + + await connection.DisposeAsync(); + } + + [Fact] + public async Task EventQueueTimeoutWithException() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + var mockTransport = new Mock(); + Channel channel = null; + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string>((url, c, transferMode, connectionId) => + { + channel = c; + return Task.CompletedTask; + }); + mockTransport.Setup(t => t.StopAsync()) + .Returns(() => + { + channel.Out.TryComplete(); + return Task.CompletedTask; + }); + mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); + + var callbackInvokedTcs = new TaskCompletionSource(); + var closedTcs = new TaskCompletionSource(); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + connection.Received += + m => + { + throw new OperationCanceledException(); + }; + + await connection.StartAsync(); + channel.Out.TryWrite(Array.Empty()); + + // Ensure that SignalR isn't blocked by the receive callback + Assert.False(channel.In.TryRead(out var message)); + + await connection.DisposeAsync(); + } + [Fact] public async Task ClosedEventNotRaisedWhenTheClientIsStoppedButWasNeverStarted() {