Add option to restrict the maximum hub message size (#8135)
- This change moves the limit checking from the transport layer to the protocol parsing layer. One nice side effect is that it gives us better control over error handling.
This commit is contained in:
parent
67d339ee3b
commit
9cb1185a5c
|
|
@ -196,6 +196,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
public bool? EnableDetailedErrors { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
|
||||
public System.TimeSpan? HandshakeTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
|
||||
public System.TimeSpan? KeepAliveInterval { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
|
||||
public long? MaximumReceiveMessageSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
|
||||
public System.Collections.Generic.IList<string> SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
|
||||
}
|
||||
public partial class HubOptions<THub> : Microsoft.AspNetCore.SignalR.HubOptions where THub : Microsoft.AspNetCore.SignalR.Hub
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
private readonly IUserIdProvider _userIdProvider;
|
||||
private readonly HubDispatcher<THub> _dispatcher;
|
||||
private readonly bool _enableDetailedErrors;
|
||||
private readonly long? _maximumMessageSize;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="HubConnectionHandler{THub}"/> class.
|
||||
|
|
@ -61,6 +62,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
_dispatcher = dispatcher;
|
||||
|
||||
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _globalHubOptions.EnableDetailedErrors ?? false;
|
||||
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize ?? _globalHubOptions.MaximumReceiveMessageSize;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
|
|
@ -69,7 +71,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
// We check to see if HubOptions<THub> are set because those take precedence over global hub options.
|
||||
// Then set the keepAlive and handshakeTimeout values to the defaults in HubOptionsSetup incase they were explicitly set to null.
|
||||
var keepAlive = _hubOptions.KeepAliveInterval ?? _globalHubOptions.KeepAliveInterval ?? HubOptionsSetup.DefaultKeepAliveInterval;
|
||||
var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval;
|
||||
var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval;
|
||||
var handshakeTimeout = _hubOptions.HandshakeTimeout ?? _globalHubOptions.HandshakeTimeout ?? HubOptionsSetup.DefaultHandshakeTimeout;
|
||||
var supportedProtocols = _hubOptions.SupportedProtocols ?? _globalHubOptions.SupportedProtocols;
|
||||
|
||||
|
|
@ -205,9 +207,47 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
connection.ResetClientTimeout();
|
||||
|
||||
while (protocol.TryParseMessage(ref buffer, binder, out var message))
|
||||
// No message limit, just parse and dispatch
|
||||
if (_maximumMessageSize == null)
|
||||
{
|
||||
await _dispatcher.DispatchMessageAsync(connection, message);
|
||||
while (protocol.TryParseMessage(ref buffer, binder, out var message))
|
||||
{
|
||||
await _dispatcher.DispatchMessageAsync(connection, message);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// We give the parser a sliding window of the default message size
|
||||
var maxMessageSize = _maximumMessageSize.Value;
|
||||
|
||||
while (!buffer.IsEmpty)
|
||||
{
|
||||
var segment = buffer;
|
||||
var overLength = false;
|
||||
|
||||
if (segment.Length > maxMessageSize)
|
||||
{
|
||||
segment = segment.Slice(segment.Start, maxMessageSize);
|
||||
overLength = true;
|
||||
}
|
||||
|
||||
if (protocol.TryParseMessage(ref segment, binder, out var message))
|
||||
{
|
||||
await _dispatcher.DispatchMessageAsync(connection, message);
|
||||
}
|
||||
else if (overLength)
|
||||
{
|
||||
throw new InvalidDataException($"The maximum message size of {maxMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
|
||||
}
|
||||
else
|
||||
{
|
||||
// No need to update the buffer since we didn't parse anything
|
||||
break;
|
||||
}
|
||||
|
||||
// Update the buffer to the remaining segment
|
||||
buffer = buffer.Slice(segment.Start);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
/// </summary>
|
||||
public IList<string> SupportedProtocols { get; set; } = null;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the maximum message size of a single incoming hub message. The default is 32KB.
|
||||
/// </summary>
|
||||
public long? MaximumReceiveMessageSize { get; set; } = null;
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets a value indicating whether detailed error messages are sent to the client.
|
||||
/// Detailed error messages include details from exceptions thrown on the server.
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
|
@ -16,6 +16,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
internal static TimeSpan DefaultClientTimeoutInterval => TimeSpan.FromSeconds(30);
|
||||
|
||||
internal const int DefaultMaximumMessageSize = 32 * 1024 * 1024;
|
||||
|
||||
private readonly List<string> _defaultProtocols = new List<string>();
|
||||
|
||||
public HubOptionsSetup(IEnumerable<IHubProtocol> protocols)
|
||||
|
|
@ -40,6 +42,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
options.HandshakeTimeout = DefaultHandshakeTimeout;
|
||||
}
|
||||
|
||||
if (options.MaximumReceiveMessageSize == null)
|
||||
{
|
||||
options.MaximumReceiveMessageSize = DefaultMaximumMessageSize;
|
||||
}
|
||||
|
||||
if (options.SupportedProtocols == null)
|
||||
{
|
||||
options.SupportedProtocols = new List<string>();
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return services.BuildServiceProvider();
|
||||
}
|
||||
|
||||
public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, Action<ServiceCollection> addServices = null, ILoggerFactory loggerFactory = null)
|
||||
public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, ILoggerFactory loggerFactory = null, Action<ServiceCollection> addServices = null)
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider(addServices, loggerFactory);
|
||||
return (Connections.ConnectionHandler)serviceProvider.GetService(GetConnectionHandlerType(hubType));
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ using System;
|
|||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Security.Claims;
|
||||
using System.Text;
|
||||
|
|
@ -452,6 +453,191 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task HubMessageOverTheMaxMessageSizeThrows()
|
||||
{
|
||||
var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e");
|
||||
var maximumMessageSize = payload.Length - 10;
|
||||
InvalidDataException exception = null;
|
||||
|
||||
bool ExpectedErrors(WriteContext writeContext)
|
||||
{
|
||||
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
|
||||
{
|
||||
exception = ide;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
using (StartVerifiableLog(ExpectedErrors))
|
||||
{
|
||||
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
|
||||
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||
|
||||
await client.Connection.Application.Output.WriteAsync(payload);
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
Assert.NotNull(exception);
|
||||
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ChunkedHubMessageOverTheMaxMessageSizeThrows()
|
||||
{
|
||||
var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e");
|
||||
var maximumMessageSize = payload.Length - 10;
|
||||
InvalidDataException exception = null;
|
||||
|
||||
bool ExpectedErrors(WriteContext writeContext)
|
||||
{
|
||||
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
|
||||
{
|
||||
exception = ide;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
using (StartVerifiableLog(ExpectedErrors))
|
||||
{
|
||||
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
|
||||
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||
|
||||
await client.Connection.Application.Output.WriteAsync(payload.AsMemory(0, payload.Length / 2));
|
||||
await client.Connection.Application.Output.WriteAsync(payload.AsMemory(payload.Length / 2));
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
Assert.NotNull(exception);
|
||||
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ManyHubMessagesOneOverTheMaxMessageSizeThrows()
|
||||
{
|
||||
var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e");
|
||||
var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e");
|
||||
var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e");
|
||||
|
||||
// Between the first and the second payload so we'll end up slicing with some remaining in the slice for
|
||||
// the next message
|
||||
var maximumMessageSize = payload1.Length + 1;
|
||||
InvalidDataException exception = null;
|
||||
|
||||
bool ExpectedErrors(WriteContext writeContext)
|
||||
{
|
||||
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
|
||||
{
|
||||
exception = ide;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
using (StartVerifiableLog(ExpectedErrors))
|
||||
{
|
||||
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
|
||||
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||
|
||||
client.Connection.Application.Output.Write(payload1);
|
||||
client.Connection.Application.Output.Write(payload2);
|
||||
client.Connection.Application.Output.Write(payload3);
|
||||
await client.Connection.Application.Output.FlushAsync();
|
||||
|
||||
// 2 invocations should be processed
|
||||
var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
|
||||
Assert.NotNull(completionMessage);
|
||||
Assert.Equal("1", completionMessage.InvocationId);
|
||||
Assert.Equal("one", completionMessage.Result);
|
||||
|
||||
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
|
||||
Assert.NotNull(completionMessage);
|
||||
Assert.Equal("2", completionMessage.InvocationId);
|
||||
Assert.Equal("two", completionMessage.Result);
|
||||
|
||||
// We never receive the 3rd message since it was over the maximum message size
|
||||
CloseMessage closeMessage = await client.ReadAsync().OrTimeout() as CloseMessage;
|
||||
Assert.NotNull(closeMessage);
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
Assert.NotNull(exception);
|
||||
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ManyHubMessagesUnderTheMessageSizeButConfiguredWithMax()
|
||||
{
|
||||
var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e");
|
||||
var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e");
|
||||
var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e");
|
||||
|
||||
// Bigger than all 3 messages
|
||||
var maximumMessageSize = payload3.Length + 10;
|
||||
|
||||
using (StartVerifiableLog())
|
||||
{
|
||||
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
|
||||
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||
|
||||
client.Connection.Application.Output.Write(payload1);
|
||||
client.Connection.Application.Output.Write(payload2);
|
||||
client.Connection.Application.Output.Write(payload3);
|
||||
await client.Connection.Application.Output.FlushAsync();
|
||||
|
||||
// 2 invocations should be processed
|
||||
var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
|
||||
Assert.NotNull(completionMessage);
|
||||
Assert.Equal("1", completionMessage.InvocationId);
|
||||
Assert.Equal("one", completionMessage.Result);
|
||||
|
||||
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
|
||||
Assert.NotNull(completionMessage);
|
||||
Assert.Equal("2", completionMessage.InvocationId);
|
||||
Assert.Equal("two", completionMessage.Result);
|
||||
|
||||
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
|
||||
Assert.NotNull(completionMessage);
|
||||
Assert.Equal("3", completionMessage.InvocationId);
|
||||
Assert.Equal("three", completionMessage.Result);
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task HandshakeFailureFromIncompatibleProtocolVersionSendsResponseWithError()
|
||||
{
|
||||
|
|
@ -2789,7 +2975,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
foreach (string id in ids)
|
||||
{
|
||||
await client.BeginUploadStreamAsync("invocation_"+id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty<object>());
|
||||
await client.BeginUploadStreamAsync("invocation_" + id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty<object>());
|
||||
}
|
||||
|
||||
var words = new[] { "zygapophyses", "qwerty", "abcd" };
|
||||
|
|
@ -2868,7 +3054,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
[Fact]
|
||||
public async Task ServerReportsProtocolMinorVersion()
|
||||
{
|
||||
|
|
@ -2881,7 +3067,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
testProtocol.Setup(m => m.TransferFormat).Returns(TransferFormat.Binary);
|
||||
|
||||
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT),
|
||||
(services) => services.AddSingleton<IHubProtocol>(testProtocol.Object), LoggerFactory);
|
||||
LoggerFactory, (services) => services.AddSingleton<IHubProtocol>(testProtocol.Object));
|
||||
|
||||
using (var client = new TestClient(protocol: testProtocol.Object))
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue