Abort connection on protocol error (#2654)
This commit is contained in:
parent
056da5114a
commit
433eeb6943
|
|
@ -36,6 +36,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
private bool _receivedMessageThisInterval = false;
|
||||
private ReadOnlyMemory<byte> _cachedPingMessage;
|
||||
private bool _clientTimeoutActive;
|
||||
private bool _connectedAborted;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="HubConnectionContext"/> class.
|
||||
|
|
@ -105,6 +106,11 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public virtual ValueTask WriteAsync(HubMessage message, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (_connectedAborted)
|
||||
{
|
||||
return default;
|
||||
}
|
||||
|
||||
// Try to grab the lock synchronously, if we fail, go to the slower path
|
||||
if (!_writeLock.Wait(0))
|
||||
{
|
||||
|
|
@ -135,6 +141,11 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
/// <returns></returns>
|
||||
public virtual ValueTask WriteAsync(SerializedHubMessage message, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (_connectedAborted)
|
||||
{
|
||||
return default;
|
||||
}
|
||||
|
||||
// Try to grab the lock synchronously, if we fail, go to the slower path
|
||||
if (!_writeLock.Wait(0))
|
||||
{
|
||||
|
|
@ -170,6 +181,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
Log.FailedWritingMessage(_logger, ex);
|
||||
|
||||
Abort();
|
||||
|
||||
return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, isCompleted: true));
|
||||
}
|
||||
}
|
||||
|
|
@ -187,6 +200,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
Log.FailedWritingMessage(_logger, ex);
|
||||
|
||||
Abort();
|
||||
|
||||
return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, isCompleted: true));
|
||||
}
|
||||
}
|
||||
|
|
@ -200,6 +215,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
catch (Exception ex)
|
||||
{
|
||||
Log.FailedWritingMessage(_logger, ex);
|
||||
|
||||
Abort();
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
|
@ -220,6 +237,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
catch (Exception ex)
|
||||
{
|
||||
Log.FailedWritingMessage(_logger, ex);
|
||||
Abort();
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
|
@ -239,6 +257,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
catch (Exception ex)
|
||||
{
|
||||
Log.FailedWritingMessage(_logger, ex);
|
||||
Abort();
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
|
@ -312,6 +331,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
return;
|
||||
}
|
||||
|
||||
_connectedAborted = true;
|
||||
|
||||
Input.CancelPendingRead();
|
||||
|
||||
// We fire and forget since this can trigger user code to run
|
||||
|
|
@ -531,7 +552,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
LoggerMessage.Define(LogLevel.Error, new EventId(5, "HandshakeFailed"), "Failed connection handshake.");
|
||||
|
||||
private static readonly Action<ILogger, Exception> _failedWritingMessage =
|
||||
LoggerMessage.Define(LogLevel.Debug, new EventId(6, "FailedWritingMessage"), "Failed writing message.");
|
||||
LoggerMessage.Define(LogLevel.Error, new EventId(6, "FailedWritingMessage"), "Failed writing message. Aborting connection.");
|
||||
|
||||
private static readonly Action<ILogger, string, int, Exception> _protocolVersionFailed =
|
||||
LoggerMessage.Define<string, int>(LogLevel.Warning, new EventId(7, "ProtocolVersionFailed"), "Server does not support version {Version} of the {Protocol} protocol.");
|
||||
|
|
|
|||
|
|
@ -159,6 +159,21 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
return Clients.Caller.SendAsync("Send", message);
|
||||
}
|
||||
|
||||
public Task ProtocolError()
|
||||
{
|
||||
return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef());
|
||||
}
|
||||
|
||||
private class SelfRef
|
||||
{
|
||||
public SelfRef()
|
||||
{
|
||||
Self = this;
|
||||
}
|
||||
|
||||
public SelfRef Self;
|
||||
}
|
||||
}
|
||||
|
||||
public abstract class TestHub : Hub
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ using Microsoft.AspNetCore.Http.Connections.Features;
|
|||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Protocol;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.DependencyInjection.Extensions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Moq;
|
||||
using Newtonsoft.Json;
|
||||
|
|
@ -102,9 +103,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||
|
||||
await client.InvokeAsync(nameof(AbortHub.Kill));
|
||||
await client.SendInvocationAsync(nameof(AbortHub.Kill)).OrTimeout();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
|
||||
Assert.Null(client.TryRead());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2358,6 +2361,26 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ConnectionAbortedIfSendFailsWithProtocolError()
|
||||
{
|
||||
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||
{
|
||||
services.AddSignalR(options => options.EnableDetailedErrors = true);
|
||||
});
|
||||
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
|
||||
|
||||
await client.SendInvocationAsync(nameof(MethodHub.ProtocolError)).OrTimeout();
|
||||
|
||||
await client.Connected.OrTimeout();
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
|
||||
{
|
||||
public int ReleaseCount;
|
||||
|
|
|
|||
Loading…
Reference in New Issue