Restrict HubProtocol on Server + HubOptions<THub> (#1492)

This commit is contained in:
Mikael Mengistu 2018-03-07 01:31:56 +00:00 committed by GitHub
parent 0eb2b96c45
commit 1b9313287b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 200 additions and 48 deletions

View File

@ -35,7 +35,6 @@ namespace Microsoft.AspNetCore.SignalR
private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource();
private readonly TaskCompletionSource<object> _abortCompletedTcs = new TaskCompletionSource<object>();
private readonly long _keepAliveDuration;
private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1);
private long _lastSendTimestamp = Stopwatch.GetTimestamp();
@ -136,7 +135,7 @@ namespace Microsoft.AspNetCore.SignalR
Task.Factory.StartNew(_abortedCallback, this);
}
internal async Task<bool> NegotiateAsync(TimeSpan timeout, IHubProtocolResolver protocolResolver, IUserIdProvider userIdProvider)
internal async Task<bool> NegotiateAsync(TimeSpan timeout, IList<string> supportedProtocols, IHubProtocolResolver protocolResolver, IUserIdProvider userIdProvider)
{
try
{
@ -157,7 +156,7 @@ namespace Microsoft.AspNetCore.SignalR
{
if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage, out consumed, out examined))
{
var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, this);
var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, supportedProtocols, this);
var transportCapabilities = Features.Get<IConnectionTransportFeature>()?.TransportCapabilities
?? throw new InvalidOperationException("Unable to read transport capabilities.");

View File

@ -17,13 +17,15 @@ namespace Microsoft.AspNetCore.SignalR
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger<HubEndPoint<THub>> _logger;
private readonly IHubProtocolResolver _protocolResolver;
private readonly HubOptions _hubOptions;
private readonly HubOptions<THub> _hubOptions;
private readonly HubOptions _globalHubOptions;
private readonly IUserIdProvider _userIdProvider;
private readonly HubDispatcher<THub> _dispatcher;
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
IHubProtocolResolver protocolResolver,
IOptions<HubOptions> hubOptions,
IOptions<HubOptions> globalHubOptions,
IOptions<HubOptions<THub>> hubOptions,
ILoggerFactory loggerFactory,
IUserIdProvider userIdProvider,
HubDispatcher<THub> dispatcher)
@ -32,6 +34,7 @@ namespace Microsoft.AspNetCore.SignalR
_lifetimeManager = lifetimeManager;
_loggerFactory = loggerFactory;
_hubOptions = hubOptions.Value;
_globalHubOptions = globalHubOptions.Value;
_logger = loggerFactory.CreateLogger<HubEndPoint<THub>>();
_userIdProvider = userIdProvider;
_dispatcher = dispatcher;
@ -39,9 +42,20 @@ namespace Microsoft.AspNetCore.SignalR
public async Task OnConnectedAsync(ConnectionContext connection)
{
var connectionContext = new HubConnectionContext(connection, _hubOptions.KeepAliveInterval, _loggerFactory);
// We check to see if HubOptions<THub> are set because those take precedence over global hub options.
// Then set the keepAlive and negotiateTimeout values to the defaults in HubOptionsSetup incase they were explicitly set to null.
var keepAlive = _hubOptions.KeepAliveInterval ?? _globalHubOptions.KeepAliveInterval ?? HubOptionsSetup.DefaultKeepAliveInterval;
var negotiateTimeout = _hubOptions.NegotiateTimeout ?? _globalHubOptions.NegotiateTimeout ?? HubOptionsSetup.DefaultNegotiateTimeout;
var supportedProtocols = _hubOptions.SupportedProtocols ?? _globalHubOptions.SupportedProtocols;
if (!await connectionContext.NegotiateAsync(_hubOptions.NegotiateTimeout, _protocolResolver, _userIdProvider))
if (supportedProtocols != null && supportedProtocols.Count == 0)
{
throw new InvalidOperationException("There are no supported protocols");
}
var connectionContext = new HubConnectionContext(connection, keepAlive, _loggerFactory);
if (!await connectionContext.NegotiateAsync(negotiateTimeout, supportedProtocols, _protocolResolver, _userIdProvider))
{
return;
}
@ -57,7 +71,6 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task RunHubAsync(HubConnectionContext connection)
{
try

View File

@ -2,27 +2,20 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR
{
public class HubOptions
{
/// <summary>
/// The default keep-alive interval. This is set to exactly half of the default client timeout window,
/// to ensure a ping can arrive in time to satisfy the client timeout.
/// </summary>
public static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.FromSeconds(15);
// NegotiateTimeout and KeepAliveInterval are set to null here to help identify when
// local hub options have been set. Global default values are set in HubOptionsSetup.
// SupportedProtocols being null is the true default value, and it represents support
// for all available protocols.
public TimeSpan? NegotiateTimeout { get; set; } = null;
public TimeSpan NegotiateTimeout { get; set; } = TimeSpan.FromSeconds(5);
public TimeSpan? KeepAliveInterval { get; set; } = null;
/// <summary>
/// The interval at which keep-alive messages should be sent. The default interval
/// is 15 seconds.
/// </summary>
/// <remarks>
/// This interval is not used by the Long Polling transport as it has inherent keep-alive
/// functionality because of the polling mechanism.
/// </remarks>
public TimeSpan KeepAliveInterval { get; set; } = DefaultKeepAliveInterval;
public IList<string> SupportedProtocols { get; set; } = null;
}
}

View File

@ -0,0 +1,54 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.SignalR
{
public class HubOptionsSetup : IConfigureOptions<HubOptions>
{
internal static TimeSpan DefaultNegotiateTimeout => TimeSpan.FromSeconds(5);
internal static TimeSpan DefaultKeepAliveInterval => TimeSpan.FromSeconds(15);
private readonly List<string> _protocols = new List<string>();
public HubOptionsSetup(IEnumerable<IHubProtocol> protocols)
{
foreach (var hubProtocol in protocols)
{
_protocols.Add(hubProtocol.Name);
}
}
public void Configure(HubOptions options)
{
if (options.SupportedProtocols == null)
{
options.SupportedProtocols = new List<string>();
}
if (options.KeepAliveInterval == null)
{
// The default keep - alive interval.This is set to exactly half of the default client timeout window,
// to ensure a ping can arrive in time to satisfy the client timeout.
options.KeepAliveInterval = DefaultKeepAliveInterval;
}
if (options.NegotiateTimeout == null)
{
options.NegotiateTimeout = DefaultNegotiateTimeout;
}
foreach (var protocol in _protocols)
{
options.SupportedProtocols.Add(protocol);
}
}
}
}

View File

@ -0,0 +1,24 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.Extensions.Options;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR
{
public class HubOptionsSetup<THub> : IConfigureOptions<HubOptions<THub>> where THub : Hub
{
private readonly HubOptions _hubOptions;
public HubOptionsSetup(IOptions<HubOptions> options)
{
_hubOptions = options.Value;
}
public void Configure(HubOptions<THub> options)
{
options.SupportedProtocols = _hubOptions.SupportedProtocols;
options.KeepAliveInterval = _hubOptions.KeepAliveInterval;
options.NegotiateTimeout = _hubOptions.NegotiateTimeout;
}
}
}

View File

@ -0,0 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR
{
public class HubOptions<THub> : HubOptions where THub : Hub { }
}

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
@ -12,13 +13,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
public class DefaultHubProtocolResolver : IHubProtocolResolver
{
private readonly IOptions<HubOptions> _options;
private readonly ILogger<DefaultHubProtocolResolver> _logger;
private readonly Dictionary<string, IHubProtocol> _availableProtocols;
public DefaultHubProtocolResolver(IOptions<HubOptions> options, IEnumerable<IHubProtocol> availableProtocols, ILogger<DefaultHubProtocolResolver> logger)
public DefaultHubProtocolResolver(IEnumerable<IHubProtocol> availableProtocols, ILogger<DefaultHubProtocolResolver> logger)
{
_options = options ?? throw new ArgumentNullException(nameof(options));
_logger = logger ?? NullLogger<DefaultHubProtocolResolver>.Instance;
_availableProtocols = new Dictionary<string, IHubProtocol>(StringComparer.OrdinalIgnoreCase);
@ -33,11 +32,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
public IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection)
public IHubProtocol GetProtocol(string protocolName, IList<string> supportedProtocols, HubConnectionContext connection)
{
protocolName = protocolName ?? throw new ArgumentNullException(nameof(protocolName));
if (_availableProtocols.TryGetValue(protocolName, out var protocol))
if (_availableProtocols.TryGetValue(protocolName, out var protocol) && (supportedProtocols == null || supportedProtocols.Contains(protocolName, StringComparer.OrdinalIgnoreCase)))
{
Log.FoundImplementationForProtocol(_logger, protocolName);
return protocol;

View File

@ -1,12 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public interface IHubProtocolResolver
{
IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection);
IHubProtocol GetProtocol(string protocolName, IList<string> supportedProtocols, HubConnectionContext connection);
}
}

View File

@ -0,0 +1,19 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Options;
namespace Microsoft.Extensions.DependencyInjection
{
public static class HubOptionsDependencyInjectionExtensions
{
public static ISignalRBuilder AddHubOptions<THub>(this ISignalRBuilder signalrBuilder, Action<HubOptions<THub>> options) where THub : Hub
{
signalrBuilder.Services.AddSingleton<IConfigureOptions<HubOptions<THub>>, HubOptionsSetup<THub>>();
signalrBuilder.Services.Configure(options);
return signalrBuilder;
}
}
}

View File

@ -3,6 +3,7 @@
using System;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Options;
namespace Microsoft.Extensions.DependencyInjection
{
@ -10,15 +11,16 @@ namespace Microsoft.Extensions.DependencyInjection
{
public static ISignalRBuilder AddSignalR(this IServiceCollection services)
{
return AddSignalR(services, _ => { });
}
public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action<HubOptions> configure)
{
services.Configure(configure);
services.AddSockets();
services.AddSingleton<IConfigureOptions<HubOptions>, HubOptionsSetup>();
return services.AddSignalRCore()
.AddJsonProtocol();
}
public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action<HubOptions> options)
{
return services.Configure(options)
.AddSignalR();
}
}
}

View File

@ -275,9 +275,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var serviceProvider = HubEndPointTestUtils.CreateServiceProvider(services =>
{
services.Configure<HubOptions>(hubOptions =>
services.Configure<HubOptions>(options =>
{
hubOptions.NegotiateTimeout = TimeSpan.FromMilliseconds(5);
options.NegotiateTimeout = TimeSpan.FromMilliseconds(5);
});
});
var endPoint = serviceProvider.GetService<HubEndPoint<SimpleHub>>();
@ -1327,9 +1327,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public async Task HubsCanStreamResponses(string method, IHubProtocol protocol)
{
var serviceProvider = HubEndPointTestUtils.CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<StreamingHub>>();
var invocationBinder = new Mock<IInvocationBinder>();
invocationBinder.Setup(b => b.GetReturnType(It.IsAny<string>())).Returns(typeof(string));

View File

@ -8,7 +8,6 @@ using Microsoft.AspNetCore.Protocols;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using Xunit;
@ -16,24 +15,54 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests
{
public class DefaultHubProtocolResolverTests
{
private static readonly List<string> AllProtocolNames = new List<string> { "json", "messagepack" };
private static readonly IList<IHubProtocol> AllProtocols = new List<IHubProtocol>()
{
new JsonHubProtocol(),
new MessagePackHubProtocol()
};
[Theory]
[MemberData(nameof(HubProtocols))]
public void DefaultHubProtocolResolverTestsCanCreateAllProtocols(IHubProtocol protocol)
{
var connection = new Mock<ConnectionContext>();
connection.Setup(m => m.Features).Returns(new FeatureCollection());
var mockConnection = new Mock<HubConnectionContext>(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true };
var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
Assert.IsType(
protocol.GetType(),
resolver.GetProtocol(protocol.Name, AllProtocolNames, mockConnection.Object));
}
[Theory]
[MemberData(nameof(HubProtocols))]
public void DefaultHubProtocolResolverCreatesProtocolswhenSupoortedProtocolsIsNull(IHubProtocol protocol)
{
var connection = new Mock<ConnectionContext>();
connection.Setup(m => m.Features).Returns(new FeatureCollection());
var mockConnection = new Mock<HubConnectionContext>(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true };
List<string> supportedProtocols = null;
var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
Assert.IsType(
protocol.GetType(),
resolver.GetProtocol(protocol.Name, supportedProtocols, mockConnection.Object));
}
[Theory]
[MemberData(nameof(HubProtocols))]
public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol)
{
var connection = new Mock<ConnectionContext>();
connection.Setup(m => m.Features).Returns(new FeatureCollection());
var mockConnection = new Mock<HubConnectionContext>(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true };
var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
var supportedProtocols = new List<string> { protocol.Name };
var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
Assert.IsType(
protocol.GetType(),
resolver.GetProtocol(protocol.Name, mockConnection.Object));
resolver.GetProtocol(protocol.Name, supportedProtocols, mockConnection.Object));
}
[Fact]
@ -42,9 +71,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests
var connection = new Mock<ConnectionContext>();
connection.Setup(m => m.Features).Returns(new FeatureCollection());
var mockConnection = new Mock<HubConnectionContext>(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true };
var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
var exception = Assert.Throws<ArgumentNullException>(
() => resolver.GetProtocol(null, mockConnection.Object));
() => resolver.GetProtocol(null, AllProtocolNames, mockConnection.Object));
Assert.Equal("protocolName", exception.ParamName);
}
@ -55,20 +84,34 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests
var connection = new Mock<ConnectionContext>();
connection.Setup(m => m.Features).Returns(new FeatureCollection());
var mockConnection = new Mock<HubConnectionContext>(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true };
var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
var exception = Assert.Throws<NotSupportedException>(
() => resolver.GetProtocol("notARealProtocol", mockConnection.Object));
() => resolver.GetProtocol("notARealProtocol", AllProtocolNames, mockConnection.Object));
Assert.Equal("The protocol 'notARealProtocol' is not supported.", exception.Message);
}
[Theory]
[MemberData(nameof(HubProtocols))]
public void DefaultHubProtocolResolverThrowsWhenNoProtocolsAreSupported(IHubProtocol protocol)
{
var connection = new Mock<ConnectionContext>();
connection.Setup(m => m.Features).Returns(new FeatureCollection());
var mockConnection = new Mock<HubConnectionContext>(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true };
var supportedProtocols= new List<string>();
var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
var exception = Assert.Throws<NotSupportedException>(
() => resolver.GetProtocol(protocol.Name, supportedProtocols, mockConnection.Object));
Assert.Equal($"The protocol '{protocol.Name}' is not supported.", exception.Message);
}
[Fact]
public void RegisteringMultipleHubProtocolsFails()
{
var connection = new Mock<ConnectionContext>();
connection.Setup(m => m.Features).Returns(new FeatureCollection());
var mockConnection = new Mock<HubConnectionContext>(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true };
var exception = Assert.Throws<InvalidOperationException>(() => new DefaultHubProtocolResolver(Options.Create(new HubOptions()), new[] {
var exception = Assert.Throws<InvalidOperationException>(() => new DefaultHubProtocolResolver(new[] {
new JsonHubProtocol(),
new JsonHubProtocol()
}, NullLogger<DefaultHubProtocolResolver>.Instance));