Add option to restrict the maximum hub message size (#8135)

- This change moves the limit checking from the transport layer to the protocol parsing layer. One nice side effect is that it gives us better control over error handling.
This commit is contained in:
David Fowler 2019-03-03 15:18:32 -08:00 committed by GitHub
parent 67d339ee3b
commit 9cb1185a5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 247 additions and 8 deletions

View File

@ -196,6 +196,7 @@ namespace Microsoft.AspNetCore.SignalR
public bool? EnableDetailedErrors { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public System.TimeSpan? HandshakeTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public System.TimeSpan? KeepAliveInterval { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public long? MaximumReceiveMessageSize { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
public System.Collections.Generic.IList<string> SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
}
public partial class HubOptions<THub> : Microsoft.AspNetCore.SignalR.HubOptions where THub : Microsoft.AspNetCore.SignalR.Hub

View File

@ -28,6 +28,7 @@ namespace Microsoft.AspNetCore.SignalR
private readonly IUserIdProvider _userIdProvider;
private readonly HubDispatcher<THub> _dispatcher;
private readonly bool _enableDetailedErrors;
private readonly long? _maximumMessageSize;
/// <summary>
/// Initializes a new instance of the <see cref="HubConnectionHandler{THub}"/> class.
@ -61,6 +62,7 @@ namespace Microsoft.AspNetCore.SignalR
_dispatcher = dispatcher;
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _globalHubOptions.EnableDetailedErrors ?? false;
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize ?? _globalHubOptions.MaximumReceiveMessageSize;
}
/// <inheritdoc />
@ -69,7 +71,7 @@ namespace Microsoft.AspNetCore.SignalR
// We check to see if HubOptions<THub> are set because those take precedence over global hub options.
// Then set the keepAlive and handshakeTimeout values to the defaults in HubOptionsSetup incase they were explicitly set to null.
var keepAlive = _hubOptions.KeepAliveInterval ?? _globalHubOptions.KeepAliveInterval ?? HubOptionsSetup.DefaultKeepAliveInterval;
var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval;
var clientTimeout = _hubOptions.ClientTimeoutInterval ?? _globalHubOptions.ClientTimeoutInterval ?? HubOptionsSetup.DefaultClientTimeoutInterval;
var handshakeTimeout = _hubOptions.HandshakeTimeout ?? _globalHubOptions.HandshakeTimeout ?? HubOptionsSetup.DefaultHandshakeTimeout;
var supportedProtocols = _hubOptions.SupportedProtocols ?? _globalHubOptions.SupportedProtocols;
@ -205,9 +207,47 @@ namespace Microsoft.AspNetCore.SignalR
{
connection.ResetClientTimeout();
while (protocol.TryParseMessage(ref buffer, binder, out var message))
// No message limit, just parse and dispatch
if (_maximumMessageSize == null)
{
await _dispatcher.DispatchMessageAsync(connection, message);
while (protocol.TryParseMessage(ref buffer, binder, out var message))
{
await _dispatcher.DispatchMessageAsync(connection, message);
}
}
else
{
// We give the parser a sliding window of the default message size
var maxMessageSize = _maximumMessageSize.Value;
while (!buffer.IsEmpty)
{
var segment = buffer;
var overLength = false;
if (segment.Length > maxMessageSize)
{
segment = segment.Slice(segment.Start, maxMessageSize);
overLength = true;
}
if (protocol.TryParseMessage(ref segment, binder, out var message))
{
await _dispatcher.DispatchMessageAsync(connection, message);
}
else if (overLength)
{
throw new InvalidDataException($"The maximum message size of {maxMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
}
else
{
// No need to update the buffer since we didn't parse anything
break;
}
// Update the buffer to the remaining segment
buffer = buffer.Slice(segment.Start);
}
}
}

View File

@ -36,6 +36,11 @@ namespace Microsoft.AspNetCore.SignalR
/// </summary>
public IList<string> SupportedProtocols { get; set; } = null;
/// <summary>
/// Gets or sets the maximum message size of a single incoming hub message. The default is 32KB.
/// </summary>
public long? MaximumReceiveMessageSize { get; set; } = null;
/// <summary>
/// Gets or sets a value indicating whether detailed error messages are sent to the client.
/// Detailed error messages include details from exceptions thrown on the server.

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@ -16,6 +16,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
internal static TimeSpan DefaultClientTimeoutInterval => TimeSpan.FromSeconds(30);
internal const int DefaultMaximumMessageSize = 32 * 1024 * 1024;
private readonly List<string> _defaultProtocols = new List<string>();
public HubOptionsSetup(IEnumerable<IHubProtocol> protocols)
@ -40,6 +42,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
options.HandshakeTimeout = DefaultHandshakeTimeout;
}
if (options.MaximumReceiveMessageSize == null)
{
options.MaximumReceiveMessageSize = DefaultMaximumMessageSize;
}
if (options.SupportedProtocols == null)
{
options.SupportedProtocols = new List<string>();

View File

@ -72,7 +72,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return services.BuildServiceProvider();
}
public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, Action<ServiceCollection> addServices = null, ILoggerFactory loggerFactory = null)
public static Connections.ConnectionHandler GetHubConnectionHandler(Type hubType, ILoggerFactory loggerFactory = null, Action<ServiceCollection> addServices = null)
{
var serviceProvider = CreateServiceProvider(addServices, loggerFactory);
return (Connections.ConnectionHandler)serviceProvider.GetService(GetConnectionHandlerType(hubType));

View File

@ -5,6 +5,7 @@ using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Security.Claims;
using System.Text;
@ -452,6 +453,191 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task HubMessageOverTheMaxMessageSizeThrows()
{
var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e");
var maximumMessageSize = payload.Length - 10;
InvalidDataException exception = null;
bool ExpectedErrors(WriteContext writeContext)
{
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
{
exception = ide;
return true;
}
return false;
}
using (StartVerifiableLog(ExpectedErrors))
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connection.Application.Output.WriteAsync(payload);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
Assert.NotNull(exception);
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
}
[Fact]
public async Task ChunkedHubMessageOverTheMaxMessageSizeThrows()
{
var payload = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e");
var maximumMessageSize = payload.Length - 10;
InvalidDataException exception = null;
bool ExpectedErrors(WriteContext writeContext)
{
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
{
exception = ide;
return true;
}
return false;
}
using (StartVerifiableLog(ExpectedErrors))
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connection.Application.Output.WriteAsync(payload.AsMemory(0, payload.Length / 2));
await client.Connection.Application.Output.WriteAsync(payload.AsMemory(payload.Length / 2));
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
Assert.NotNull(exception);
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
}
[Fact]
public async Task ManyHubMessagesOneOverTheMaxMessageSizeThrows()
{
var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e");
var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e");
var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e");
// Between the first and the second payload so we'll end up slicing with some remaining in the slice for
// the next message
var maximumMessageSize = payload1.Length + 1;
InvalidDataException exception = null;
bool ExpectedErrors(WriteContext writeContext)
{
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && (writeContext.Exception is InvalidDataException ide))
{
exception = ide;
return true;
}
return false;
}
using (StartVerifiableLog(ExpectedErrors))
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
client.Connection.Application.Output.Write(payload1);
client.Connection.Application.Output.Write(payload2);
client.Connection.Application.Output.Write(payload3);
await client.Connection.Application.Output.FlushAsync();
// 2 invocations should be processed
var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("1", completionMessage.InvocationId);
Assert.Equal("one", completionMessage.Result);
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("2", completionMessage.InvocationId);
Assert.Equal("two", completionMessage.Result);
// We never receive the 3rd message since it was over the maximum message size
CloseMessage closeMessage = await client.ReadAsync().OrTimeout() as CloseMessage;
Assert.NotNull(closeMessage);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
Assert.NotNull(exception);
Assert.Equal(exception.Message, $"The maximum message size of {maximumMessageSize}B was exceeded. The message size can be configured in AddHubOptions.");
}
[Fact]
public async Task ManyHubMessagesUnderTheMessageSizeButConfiguredWithMax()
{
var payload1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"one\"]}\u001e");
var payload2 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"2\", \"target\": \"Echo\", \"arguments\":[\"two\"]}\u001e");
var payload3 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"3\", \"target\": \"Echo\", \"arguments\":[\"three\"]}\u001e");
// Bigger than all 3 messages
var maximumMessageSize = payload3.Length + 10;
using (StartVerifiableLog())
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), LoggerFactory,
services => services.AddSignalR().AddHubOptions<HubT>(o => o.MaximumReceiveMessageSize = maximumMessageSize));
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
client.Connection.Application.Output.Write(payload1);
client.Connection.Application.Output.Write(payload2);
client.Connection.Application.Output.Write(payload3);
await client.Connection.Application.Output.FlushAsync();
// 2 invocations should be processed
var completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("1", completionMessage.InvocationId);
Assert.Equal("one", completionMessage.Result);
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("2", completionMessage.InvocationId);
Assert.Equal("two", completionMessage.Result);
completionMessage = await client.ReadAsync().OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("3", completionMessage.InvocationId);
Assert.Equal("three", completionMessage.Result);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task HandshakeFailureFromIncompatibleProtocolVersionSendsResponseWithError()
{
@ -2789,7 +2975,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
foreach (string id in ids)
{
await client.BeginUploadStreamAsync("invocation_"+id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty<object>());
await client.BeginUploadStreamAsync("invocation_" + id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty<object>());
}
var words = new[] { "zygapophyses", "qwerty", "abcd" };
@ -2868,7 +3054,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
}
[Fact]
public async Task ServerReportsProtocolMinorVersion()
{
@ -2881,7 +3067,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
testProtocol.Setup(m => m.TransferFormat).Returns(TransferFormat.Binary);
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT),
(services) => services.AddSingleton<IHubProtocol>(testProtocol.Object), LoggerFactory);
LoggerFactory, (services) => services.AddSingleton<IHubProtocol>(testProtocol.Object));
using (var client = new TestClient(protocol: testProtocol.Object))
{