diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 8702373375..583fcba410 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -35,7 +35,6 @@ namespace Microsoft.AspNetCore.SignalR private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource(); private readonly TaskCompletionSource _abortCompletedTcs = new TaskCompletionSource(); private readonly long _keepAliveDuration; - private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1); private long _lastSendTimestamp = Stopwatch.GetTimestamp(); @@ -136,7 +135,7 @@ namespace Microsoft.AspNetCore.SignalR Task.Factory.StartNew(_abortedCallback, this); } - internal async Task NegotiateAsync(TimeSpan timeout, IHubProtocolResolver protocolResolver, IUserIdProvider userIdProvider) + internal async Task NegotiateAsync(TimeSpan timeout, IList supportedProtocols, IHubProtocolResolver protocolResolver, IUserIdProvider userIdProvider) { try { @@ -157,7 +156,7 @@ namespace Microsoft.AspNetCore.SignalR { if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage, out consumed, out examined)) { - var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, this); + var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, supportedProtocols, this); var transportCapabilities = Features.Get()?.TransportCapabilities ?? throw new InvalidOperationException("Unable to read transport capabilities."); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index 253b1ea9ed..6a544fb81f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -17,13 +17,15 @@ namespace Microsoft.AspNetCore.SignalR private readonly ILoggerFactory _loggerFactory; private readonly ILogger> _logger; private readonly IHubProtocolResolver _protocolResolver; - private readonly HubOptions _hubOptions; + private readonly HubOptions _hubOptions; + private readonly HubOptions _globalHubOptions; private readonly IUserIdProvider _userIdProvider; private readonly HubDispatcher _dispatcher; public HubEndPoint(HubLifetimeManager lifetimeManager, IHubProtocolResolver protocolResolver, - IOptions hubOptions, + IOptions globalHubOptions, + IOptions> hubOptions, ILoggerFactory loggerFactory, IUserIdProvider userIdProvider, HubDispatcher dispatcher) @@ -32,6 +34,7 @@ namespace Microsoft.AspNetCore.SignalR _lifetimeManager = lifetimeManager; _loggerFactory = loggerFactory; _hubOptions = hubOptions.Value; + _globalHubOptions = globalHubOptions.Value; _logger = loggerFactory.CreateLogger>(); _userIdProvider = userIdProvider; _dispatcher = dispatcher; @@ -39,9 +42,20 @@ namespace Microsoft.AspNetCore.SignalR public async Task OnConnectedAsync(ConnectionContext connection) { - var connectionContext = new HubConnectionContext(connection, _hubOptions.KeepAliveInterval, _loggerFactory); + // We check to see if HubOptions are set because those take precedence over global hub options. + // Then set the keepAlive and negotiateTimeout values to the defaults in HubOptionsSetup incase they were explicitly set to null. + var keepAlive = _hubOptions.KeepAliveInterval ?? _globalHubOptions.KeepAliveInterval ?? HubOptionsSetup.DefaultKeepAliveInterval; + var negotiateTimeout = _hubOptions.NegotiateTimeout ?? _globalHubOptions.NegotiateTimeout ?? HubOptionsSetup.DefaultNegotiateTimeout; + var supportedProtocols = _hubOptions.SupportedProtocols ?? _globalHubOptions.SupportedProtocols; - if (!await connectionContext.NegotiateAsync(_hubOptions.NegotiateTimeout, _protocolResolver, _userIdProvider)) + if (supportedProtocols != null && supportedProtocols.Count == 0) + { + throw new InvalidOperationException("There are no supported protocols"); + } + + var connectionContext = new HubConnectionContext(connection, keepAlive, _loggerFactory); + + if (!await connectionContext.NegotiateAsync(negotiateTimeout, supportedProtocols, _protocolResolver, _userIdProvider)) { return; } @@ -57,7 +71,6 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task RunHubAsync(HubConnectionContext connection) { try diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubOptions.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubOptions.cs index 320d0097b8..2a15354f20 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubOptions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubOptions.cs @@ -2,27 +2,20 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; namespace Microsoft.AspNetCore.SignalR { public class HubOptions { - /// - /// The default keep-alive interval. This is set to exactly half of the default client timeout window, - /// to ensure a ping can arrive in time to satisfy the client timeout. - /// - public static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.FromSeconds(15); + // NegotiateTimeout and KeepAliveInterval are set to null here to help identify when + // local hub options have been set. Global default values are set in HubOptionsSetup. + // SupportedProtocols being null is the true default value, and it represents support + // for all available protocols. + public TimeSpan? NegotiateTimeout { get; set; } = null; - public TimeSpan NegotiateTimeout { get; set; } = TimeSpan.FromSeconds(5); + public TimeSpan? KeepAliveInterval { get; set; } = null; - /// - /// The interval at which keep-alive messages should be sent. The default interval - /// is 15 seconds. - /// - /// - /// This interval is not used by the Long Polling transport as it has inherent keep-alive - /// functionality because of the polling mechanism. - /// - public TimeSpan KeepAliveInterval { get; set; } = DefaultKeepAliveInterval; + public IList SupportedProtocols { get; set; } = null; } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubOptionsSetup.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubOptionsSetup.cs new file mode 100644 index 0000000000..c2b3aee5e4 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubOptionsSetup.cs @@ -0,0 +1,54 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.SignalR +{ + public class HubOptionsSetup : IConfigureOptions + { + internal static TimeSpan DefaultNegotiateTimeout => TimeSpan.FromSeconds(5); + + internal static TimeSpan DefaultKeepAliveInterval => TimeSpan.FromSeconds(15); + + private readonly List _protocols = new List(); + + public HubOptionsSetup(IEnumerable protocols) + { + foreach (var hubProtocol in protocols) + { + _protocols.Add(hubProtocol.Name); + } + } + + public void Configure(HubOptions options) + { + if (options.SupportedProtocols == null) + { + options.SupportedProtocols = new List(); + } + + if (options.KeepAliveInterval == null) + { + // The default keep - alive interval.This is set to exactly half of the default client timeout window, + // to ensure a ping can arrive in time to satisfy the client timeout. + options.KeepAliveInterval = DefaultKeepAliveInterval; + } + + if (options.NegotiateTimeout == null) + { + options.NegotiateTimeout = DefaultNegotiateTimeout; + } + + foreach (var protocol in _protocols) + { + options.SupportedProtocols.Add(protocol); + } + } + } +} + diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubOptionsSetup`T.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubOptionsSetup`T.cs new file mode 100644 index 0000000000..2dba501a4d --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubOptionsSetup`T.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.Extensions.Options; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.SignalR +{ + public class HubOptionsSetup : IConfigureOptions> where THub : Hub + { + private readonly HubOptions _hubOptions; + public HubOptionsSetup(IOptions options) + { + _hubOptions = options.Value; + } + + public void Configure(HubOptions options) + { + options.SupportedProtocols = _hubOptions.SupportedProtocols; + options.KeepAliveInterval = _hubOptions.KeepAliveInterval; + options.NegotiateTimeout = _hubOptions.NegotiateTimeout; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubOptions`T.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubOptions`T.cs new file mode 100644 index 0000000000..95ddcf0790 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubOptions`T.cs @@ -0,0 +1,7 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR +{ + public class HubOptions : HubOptions where THub : Hub { } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubProtocolResolver.cs index bd021f1524..49a56d4c3d 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubProtocolResolver.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -12,13 +13,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public class DefaultHubProtocolResolver : IHubProtocolResolver { - private readonly IOptions _options; private readonly ILogger _logger; private readonly Dictionary _availableProtocols; - public DefaultHubProtocolResolver(IOptions options, IEnumerable availableProtocols, ILogger logger) + public DefaultHubProtocolResolver(IEnumerable availableProtocols, ILogger logger) { - _options = options ?? throw new ArgumentNullException(nameof(options)); _logger = logger ?? NullLogger.Instance; _availableProtocols = new Dictionary(StringComparer.OrdinalIgnoreCase); @@ -33,11 +32,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal } } - public IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection) + public IHubProtocol GetProtocol(string protocolName, IList supportedProtocols, HubConnectionContext connection) { protocolName = protocolName ?? throw new ArgumentNullException(nameof(protocolName)); - if (_availableProtocols.TryGetValue(protocolName, out var protocol)) + if (_availableProtocols.TryGetValue(protocolName, out var protocol) && (supportedProtocols == null || supportedProtocols.Contains(protocolName, StringComparer.OrdinalIgnoreCase))) { Log.FoundImplementationForProtocol(_logger, protocolName); return protocol; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/IHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/IHubProtocolResolver.cs index 29d1c392d8..8d45e69fc6 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/IHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/IHubProtocolResolver.cs @@ -1,12 +1,13 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; using Microsoft.AspNetCore.SignalR.Internal.Protocol; namespace Microsoft.AspNetCore.SignalR.Internal { public interface IHubProtocolResolver { - IHubProtocol GetProtocol(string protocolName, HubConnectionContext connection); + IHubProtocol GetProtocol(string protocolName, IList supportedProtocols, HubConnectionContext connection); } } diff --git a/src/Microsoft.AspNetCore.SignalR/HubOptionsDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/HubOptionsDependencyInjectionExtensions.cs new file mode 100644 index 0000000000..d8186fa566 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/HubOptionsDependencyInjectionExtensions.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Options; + +namespace Microsoft.Extensions.DependencyInjection +{ + public static class HubOptionsDependencyInjectionExtensions + { + public static ISignalRBuilder AddHubOptions(this ISignalRBuilder signalrBuilder, Action> options) where THub : Hub + { + signalrBuilder.Services.AddSingleton>, HubOptionsSetup>(); + signalrBuilder.Services.Configure(options); + return signalrBuilder; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs index 169ae0cc3d..ad547f20fa 100644 --- a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs @@ -3,6 +3,7 @@ using System; using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Options; namespace Microsoft.Extensions.DependencyInjection { @@ -10,15 +11,16 @@ namespace Microsoft.Extensions.DependencyInjection { public static ISignalRBuilder AddSignalR(this IServiceCollection services) { - return AddSignalR(services, _ => { }); - } - - public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action configure) - { - services.Configure(configure); services.AddSockets(); + services.AddSingleton, HubOptionsSetup>(); return services.AddSignalRCore() .AddJsonProtocol(); } + + public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action options) + { + return services.Configure(options) + .AddSignalR(); + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index e3192dd582..493670b0a9 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -275,9 +275,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var serviceProvider = HubEndPointTestUtils.CreateServiceProvider(services => { - services.Configure(hubOptions => + services.Configure(options => { - hubOptions.NegotiateTimeout = TimeSpan.FromMilliseconds(5); + options.NegotiateTimeout = TimeSpan.FromMilliseconds(5); }); }); var endPoint = serviceProvider.GetService>(); @@ -1327,9 +1327,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public async Task HubsCanStreamResponses(string method, IHubProtocol protocol) { var serviceProvider = HubEndPointTestUtils.CreateServiceProvider(); - var endPoint = serviceProvider.GetService>(); - var invocationBinder = new Mock(); invocationBinder.Setup(b => b.GetReturnType(It.IsAny())).Returns(typeof(string)); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs index ed825d32fa..0bda0ad878 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Internal/DefaultHubProtocolResolverTests.cs @@ -8,7 +8,6 @@ using Microsoft.AspNetCore.Protocols; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.Extensions.Options; using Moq; using Xunit; @@ -16,24 +15,54 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests { public class DefaultHubProtocolResolverTests { + private static readonly List AllProtocolNames = new List { "json", "messagepack" }; + private static readonly IList AllProtocols = new List() { new JsonHubProtocol(), new MessagePackHubProtocol() }; + + [Theory] + [MemberData(nameof(HubProtocols))] + public void DefaultHubProtocolResolverTestsCanCreateAllProtocols(IHubProtocol protocol) + { + var connection = new Mock(); + connection.Setup(m => m.Features).Returns(new FeatureCollection()); + var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; + var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger.Instance); + Assert.IsType( + protocol.GetType(), + resolver.GetProtocol(protocol.Name, AllProtocolNames, mockConnection.Object)); + } + + [Theory] + [MemberData(nameof(HubProtocols))] + public void DefaultHubProtocolResolverCreatesProtocolswhenSupoortedProtocolsIsNull(IHubProtocol protocol) + { + var connection = new Mock(); + connection.Setup(m => m.Features).Returns(new FeatureCollection()); + var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; + List supportedProtocols = null; + var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger.Instance); + Assert.IsType( + protocol.GetType(), + resolver.GetProtocol(protocol.Name, supportedProtocols, mockConnection.Object)); + } + [Theory] [MemberData(nameof(HubProtocols))] public void DefaultHubProtocolResolverTestsCanCreateSupportedProtocols(IHubProtocol protocol) { - var connection = new Mock(); connection.Setup(m => m.Features).Returns(new FeatureCollection()); var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; - var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger.Instance); + var supportedProtocols = new List { protocol.Name }; + var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger.Instance); Assert.IsType( protocol.GetType(), - resolver.GetProtocol(protocol.Name, mockConnection.Object)); + resolver.GetProtocol(protocol.Name, supportedProtocols, mockConnection.Object)); } [Fact] @@ -42,9 +71,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests var connection = new Mock(); connection.Setup(m => m.Features).Returns(new FeatureCollection()); var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; - var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger.Instance); + var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger.Instance); var exception = Assert.Throws( - () => resolver.GetProtocol(null, mockConnection.Object)); + () => resolver.GetProtocol(null, AllProtocolNames, mockConnection.Object)); Assert.Equal("protocolName", exception.ParamName); } @@ -55,20 +84,34 @@ namespace Microsoft.AspNetCore.SignalR.Common.Protocol.Tests var connection = new Mock(); connection.Setup(m => m.Features).Returns(new FeatureCollection()); var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; - var resolver = new DefaultHubProtocolResolver(Options.Create(new HubOptions()), AllProtocols, NullLogger.Instance); + var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger.Instance); var exception = Assert.Throws( - () => resolver.GetProtocol("notARealProtocol", mockConnection.Object)); + () => resolver.GetProtocol("notARealProtocol", AllProtocolNames, mockConnection.Object)); Assert.Equal("The protocol 'notARealProtocol' is not supported.", exception.Message); } + [Theory] + [MemberData(nameof(HubProtocols))] + public void DefaultHubProtocolResolverThrowsWhenNoProtocolsAreSupported(IHubProtocol protocol) + { + var connection = new Mock(); + connection.Setup(m => m.Features).Returns(new FeatureCollection()); + var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; + var supportedProtocols= new List(); + var resolver = new DefaultHubProtocolResolver(AllProtocols, NullLogger.Instance); + var exception = Assert.Throws( + () => resolver.GetProtocol(protocol.Name, supportedProtocols, mockConnection.Object)); + Assert.Equal($"The protocol '{protocol.Name}' is not supported.", exception.Message); + } + [Fact] public void RegisteringMultipleHubProtocolsFails() { var connection = new Mock(); connection.Setup(m => m.Features).Returns(new FeatureCollection()); var mockConnection = new Mock(connection.Object, TimeSpan.FromSeconds(30), NullLoggerFactory.Instance) { CallBase = true }; - var exception = Assert.Throws(() => new DefaultHubProtocolResolver(Options.Create(new HubOptions()), new[] { + var exception = Assert.Throws(() => new DefaultHubProtocolResolver(new[] { new JsonHubProtocol(), new JsonHubProtocol() }, NullLogger.Instance));