diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 0f6f64a179..8afb75b247 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -152,16 +152,25 @@ namespace Microsoft.AspNetCore.Sockets.Client if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connected, ConnectionState.Connecting) == ConnectionState.Connecting) { - var ignore = _eventQueue.Enqueue(() => + _ = _eventQueue.Enqueue(async () => { _logger.RaiseConnected(_connectionId); - Connected?.Invoke(); - - return Task.CompletedTask; + var connectedEventHandler = Connected; + if (connectedEventHandler != null) + { + try + { + await connectedEventHandler.Invoke(); + } + catch (Exception ex) + { + _logger.ExceptionThrownFromEventHandler(_connectionId, nameof(Connected), ex); + } + } }); - ignore = Input.Completion.ContinueWith(async t => + _ = Input.Completion.ContinueWith(async t => { Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); @@ -183,9 +192,18 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger.RaiseClosed(_connectionId); - Closed?.Invoke(t.IsFaulted ? t.Exception.InnerException : null); - - return Task.CompletedTask; + var closedEventHandler = Closed; + if (closedEventHandler != null) + { + try + { + await closedEventHandler.Invoke(t.IsFaulted ? t.Exception.InnerException : null); + } + catch (Exception ex) + { + _logger.ExceptionThrownFromEventHandler(_connectionId, nameof(Closed), ex); + } + } }); // start receive loop only after the Connected event was raised to @@ -331,19 +349,22 @@ namespace Microsoft.AspNetCore.Sockets.Client if (Input.TryRead(out var buffer)) { _logger.ScheduleReceiveEvent(_connectionId); - _ = _eventQueue.Enqueue(() => + _ = _eventQueue.Enqueue(async () => { _logger.RaiseReceiveEvent(_connectionId); - // Making a copy of the Received handler to ensure that its not null - // Can't use the ? operator because we specifically want to check if the handler is null var receivedHandler = Received; if (receivedHandler != null) { - return receivedHandler(buffer); + try + { + await receivedHandler(buffer); + } + catch (Exception ex) + { + _logger.ExceptionThrownFromEventHandler(_connectionId, nameof(Received), ex); + } } - - return Task.CompletedTask; }); } else diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs index 32d8bd52ed..9ee629e20e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs @@ -150,6 +150,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal private static readonly Action _stoppingClient = LoggerMessage.Define(LogLevel.Information, 18, "{time}: Connection Id {connectionId}: Stopping client."); + private static readonly Action _exceptionThrownFromHandler = + LoggerMessage.Define(LogLevel.Error, 19, "{time}: Connection Id {connectionId}: An exception was thrown from the '{eventHandlerName}' event handler."); + + public static void StartTransport(this ILogger logger, string connectionId, TransferMode transferMode) { if (logger.IsEnabled(LogLevel.Information)) @@ -509,5 +513,13 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal _stoppingClient(logger, DateTime.Now, connectionId, null); } } + + public static void ExceptionThrownFromEventHandler(this ILogger logger, string connectionId, string eventHandlerName, Exception exception) + { + if (logger.IsEnabled(LogLevel.Error)) + { + _exceptionThrownFromHandler(logger, DateTime.Now, connectionId, eventHandlerName, exception); + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index c71f2794b8..68673ae54e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -476,7 +476,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests closedTcs.SetResult(null); return Task.CompletedTask; }; - + await connection.StartAsync(); channel.Out.TryWrite(Array.Empty()); @@ -746,6 +746,249 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests } } + [Fact] + public async Task CanReceiveDataEvenIfUserThrowsInConnectedEvent() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + + var content = string.Empty; + + if (request.Method == HttpMethod.Get) + { + content = "42"; + } + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + try + { + connection.Connected += () => Task.FromException(new InvalidOperationException()); + + var receiveTcs = new TaskCompletionSource(); + connection.Received += data => + { + receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + return Task.CompletedTask; + }; + + connection.Closed += e => + { + if (e != null) + { + receiveTcs.TrySetException(e); + } + else + { + receiveTcs.TrySetCanceled(); + } + return Task.CompletedTask; + }; + + await connection.StartAsync(); + + Assert.Equal("42", await receiveTcs.Task.OrTimeout()); + } + finally + { + await connection.DisposeAsync(); + } + } + + [Fact] + public async Task CanReceiveDataEvenIfUserThrowsSynchronouslyInConnectedEvent() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + + var content = string.Empty; + + if (request.Method == HttpMethod.Get) + { + content = "42"; + } + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + try + { + connection.Connected += () => + { + throw new InvalidOperationException(); + }; + + var receiveTcs = new TaskCompletionSource(); + connection.Received += data => + { + receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + return Task.CompletedTask; + }; + + connection.Closed += e => + { + if (e != null) + { + receiveTcs.TrySetException(e); + } + else + { + receiveTcs.TrySetCanceled(); + } + return Task.CompletedTask; + }; + + await connection.StartAsync(); + + Assert.Equal("42", await receiveTcs.Task.OrTimeout()); + } + finally + { + await connection.DisposeAsync(); + } + } + + [Fact] + public async Task CanReceiveDataEvenIfExceptionThrownFromPreviousReceivedEvent() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + + var content = string.Empty; + + if (request.Method == HttpMethod.Get) + { + content = "42"; + } + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + try + { + var receiveTcs = new TaskCompletionSource(); + + var receivedRaised = false; + connection.Received += data => + { + if (!receivedRaised) + { + receivedRaised = true; + return Task.FromException(new InvalidOperationException()); + } + + receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + return Task.CompletedTask; + }; + + connection.Closed += e => + { + if (e != null) + { + receiveTcs.TrySetException(e); + } + else + { + receiveTcs.TrySetCanceled(); + } + return Task.CompletedTask; + }; + + await connection.StartAsync(); + + Assert.Equal("42", await receiveTcs.Task.OrTimeout()); + } + finally + { + await connection.DisposeAsync(); + } + } + + [Fact] + public async Task CanReceiveDataEvenIfExceptionThrownSynchronouslyFromPreviousReceivedEvent() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + + var content = string.Empty; + + if (request.Method == HttpMethod.Get) + { + content = "42"; + } + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); + }); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + try + { + var receiveTcs = new TaskCompletionSource(); + + var receivedRaised = false; + connection.Received += data => + { + if (!receivedRaised) + { + receivedRaised = true; + throw new InvalidOperationException(); + } + + receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + return Task.CompletedTask; + }; + + connection.Closed += e => + { + if (e != null) + { + receiveTcs.TrySetException(e); + } + else + { + receiveTcs.TrySetCanceled(); + } + return Task.CompletedTask; + }; + + await connection.StartAsync(); + + Assert.Equal("42", await receiveTcs.Task.OrTimeout()); + } + finally + { + await connection.DisposeAsync(); + } + } + [Fact] public async Task CannotSendAfterReceiveThrewException() {