diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 64ff041d5c..eac4268e4c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ Contributing ====== -Information on contributing to this repo is in the [Contributing Guide](https://github.com/aspnet/Home/blob/dev/CONTRIBUTING.md) in the Home repo. +Information on contributing to this repo is in the [Contributing Guide](https://github.com/aspnet/Home/blob/master/CONTRIBUTING.md) in the Home repo. diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 50b73dd824..5aaa164cea 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -36,6 +36,7 @@ namespace Microsoft.AspNetCore.SignalR private bool _receivedMessageThisInterval = false; private ReadOnlyMemory _cachedPingMessage; private bool _clientTimeoutActive; + private bool _connectedAborted; /// /// Initializes a new instance of the class. @@ -105,6 +106,11 @@ namespace Microsoft.AspNetCore.SignalR public virtual ValueTask WriteAsync(HubMessage message, CancellationToken cancellationToken = default) { + if (_connectedAborted) + { + return default; + } + // Try to grab the lock synchronously, if we fail, go to the slower path if (!_writeLock.Wait(0)) { @@ -135,6 +141,11 @@ namespace Microsoft.AspNetCore.SignalR /// public virtual ValueTask WriteAsync(SerializedHubMessage message, CancellationToken cancellationToken = default) { + if (_connectedAborted) + { + return default; + } + // Try to grab the lock synchronously, if we fail, go to the slower path if (!_writeLock.Wait(0)) { @@ -170,6 +181,8 @@ namespace Microsoft.AspNetCore.SignalR { Log.FailedWritingMessage(_logger, ex); + Abort(); + return new ValueTask(new FlushResult(isCanceled: false, isCompleted: true)); } } @@ -187,6 +200,8 @@ namespace Microsoft.AspNetCore.SignalR { Log.FailedWritingMessage(_logger, ex); + Abort(); + return new ValueTask(new FlushResult(isCanceled: false, isCompleted: true)); } } @@ -200,6 +215,8 @@ namespace Microsoft.AspNetCore.SignalR catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + + Abort(); } finally { @@ -220,6 +237,7 @@ namespace Microsoft.AspNetCore.SignalR catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + Abort(); } finally { @@ -239,6 +257,7 @@ namespace Microsoft.AspNetCore.SignalR catch (Exception ex) { Log.FailedWritingMessage(_logger, ex); + Abort(); } finally { @@ -311,6 +330,8 @@ namespace Microsoft.AspNetCore.SignalR return; } + _connectedAborted = true; + Input.CancelPendingRead(); // We fire and forget since this can trigger user code to run @@ -531,7 +552,7 @@ namespace Microsoft.AspNetCore.SignalR LoggerMessage.Define(LogLevel.Error, new EventId(5, "HandshakeFailed"), "Failed connection handshake."); private static readonly Action _failedWritingMessage = - LoggerMessage.Define(LogLevel.Debug, new EventId(6, "FailedWritingMessage"), "Failed writing message."); + LoggerMessage.Define(LogLevel.Error, new EventId(6, "FailedWritingMessage"), "Failed writing message. Aborting connection."); private static readonly Action _protocolVersionFailed = LoggerMessage.Define(LogLevel.Warning, new EventId(7, "ProtocolVersionFailed"), "Server does not support version {Version} of the {Protocol} protocol."); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs index f2dad3db45..7eb0aec498 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs @@ -159,6 +159,21 @@ namespace Microsoft.AspNetCore.SignalR.Tests { return Clients.Caller.SendAsync("Send", message); } + + public Task ProtocolError() + { + return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef()); + } + + private class SelfRef + { + public SelfRef() + { + Self = this; + } + + public SelfRef Self; + } } public abstract class TestHub : Hub diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index 26cd439555..04156ad03b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -16,6 +16,7 @@ using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; using Moq; using Newtonsoft.Json; @@ -102,9 +103,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var connectionHandlerTask = await client.ConnectAsync(connectionHandler); - await client.InvokeAsync(nameof(AbortHub.Kill)); + await client.SendInvocationAsync(nameof(AbortHub.Kill)).OrTimeout(); await connectionHandlerTask.OrTimeout(); + + Assert.Null(client.TryRead()); } } @@ -374,6 +377,32 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task ConnectionClosesOnServerWithPartialHandshakeMessageAndCompletedPipe() + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT)); + + using (var client = new TestClient()) + { + // partial handshake + var payload = Encoding.UTF8.GetBytes("{\"protocol\": \"json\",\"ver"); + await client.Connection.Application.Output.WriteAsync(payload).OrTimeout(); + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler, sendHandshakeRequestMessage: false, expectedHandshakeResponseMessage: false); + // Complete the pipe to 'close' the connection + client.Connection.Application.Output.Complete(); + + // This will never complete as the pipe was completed and nothing can be written to it + var handshakeReadTask = client.ReadAsync(true); + + // Check that the connection was closed on the server + await connectionHandlerTask.OrTimeout(); + Assert.False(handshakeReadTask.IsCompleted); + + client.Dispose(); + } + } + [Fact] public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnConnectedAsyncThrows() { @@ -2332,6 +2361,26 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task ConnectionAbortedIfSendFailsWithProtocolError() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => options.EnableDetailedErrors = true); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + await client.SendInvocationAsync(nameof(MethodHub.ProtocolError)).OrTimeout(); + + await client.Connected.OrTimeout(); + await connectionHandlerTask.OrTimeout(); + } + } + [Fact] public async Task ServerReportsProtocolMinorVersion() {