diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index eaa81799c5..c1aeb69c78 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -27,6 +27,7 @@ namespace Microsoft.AspNetCore.SignalR { private static readonly Base64Encoder Base64Encoder = new Base64Encoder(); private static readonly PassThroughEncoder PassThroughEncoder = new PassThroughEncoder(); + private static readonly TimeSpan NegotiateTimeout = TimeSpan.FromSeconds(5); private readonly Dictionary _methods = new Dictionary(StringComparer.OrdinalIgnoreCase); @@ -108,35 +109,46 @@ namespace Microsoft.AspNetCore.SignalR private async Task ProcessNegotiate(HubConnectionContext connection) { - while (await connection.Input.WaitToReadAsync()) + try { - while (connection.Input.TryRead(out var buffer)) + using (var cts = new CancellationTokenSource()) { - if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage)) + cts.CancelAfter(NegotiateTimeout); + while (await connection.Input.WaitToReadAsync(cts.Token)) { - var protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); + while (connection.Input.TryRead(out var buffer)) + { + if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage)) + { + var protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); - var transportCapabilities = connection.Features.Get()?.TransportCapabilities - ?? throw new InvalidOperationException("Unable to read transport capabilities."); + var transportCapabilities = connection.Features.Get()?.TransportCapabilities + ?? throw new InvalidOperationException("Unable to read transport capabilities."); - var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0) - ? (IDataEncoder)Base64Encoder - : PassThroughEncoder; + var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0) + ? (IDataEncoder)Base64Encoder + : PassThroughEncoder; - var transferModeFeature = connection.Features.Get() ?? - throw new InvalidOperationException("Unable to read transfer mode."); + var transferModeFeature = connection.Features.Get() ?? + throw new InvalidOperationException("Unable to read transfer mode."); - transferModeFeature.TransferMode = - (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) != 0) - ? TransferMode.Binary - : TransferMode.Text; + transferModeFeature.TransferMode = + (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) != 0) + ? TransferMode.Binary + : TransferMode.Text; - connection.ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder); + connection.ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder); - return true; + return true; + } + } } } } + catch (OperationCanceledException) + { + _logger.LogDebug("Negotiate was canceled."); + } return false; } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 3c0bde7433..8dee3a565a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -59,6 +59,21 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task NegotiateTimesOut() + { + var serviceProvider = CreateServiceProvider(); + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + // TestClient automatically writes negotiate, for this test we want to assume negotiate never gets sent + client.Connection.Transport.In.TryRead(out var item); + + await endPoint.OnConnectedAsync(client.Connection).OrTimeout(TimeSpan.FromSeconds(10)); + } + } + [Fact] public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnConnectedAsyncThrows() {