diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index 35a84ccae5..e93d81f04c 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -16,18 +16,17 @@ namespace SocketsSample // For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940 public void ConfigureServices(IServiceCollection services) { - services.AddRouting(); - services.AddSingleton(); services.AddSingleton(); - services.AddSignalR() - .AddSignalROptions(options => + services.AddSockets(); + + services.AddSignalR(options => { options.RegisterInvocationAdapter("protobuf"); options.RegisterInvocationAdapter("line"); }); - // .AddRedis(); + // .AddRedis(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/Microsoft.AspNetCore.SignalR/DependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs similarity index 72% rename from src/Microsoft.AspNetCore.SignalR/DependencyInjectionExtensions.cs rename to src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs index 73cdef1eae..6991f13ddd 100644 --- a/src/Microsoft.AspNetCore.SignalR/DependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs @@ -8,10 +8,11 @@ using Microsoft.Extensions.Options; namespace Microsoft.Extensions.DependencyInjection { - public static class DependencyInjectionExtensions + public static class SignalRDependencyInjectionExtensions { public static ISignalRBuilder AddSignalR(this IServiceCollection services) { + services.AddSockets(); services.AddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>)); services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>)); services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>)); @@ -22,9 +23,14 @@ namespace Microsoft.Extensions.DependencyInjection return new SignalRBuilder(services); } - public static ISignalRBuilder AddSignalROptions(this ISignalRBuilder builder, Action configure) + public static ISignalRBuilder AddSignalR(this IServiceCollection services, Action setupAction) { - builder.Services.Configure(configure); + return services.AddSignalR().AddSignalROptions(setupAction); + } + + public static ISignalRBuilder AddSignalROptions(this ISignalRBuilder builder, Action setupAction) + { + builder.Services.Configure(setupAction); return builder; } } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 5052116bc9..46ae37698e 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -5,17 +5,22 @@ using System; using System.Collections.Concurrent; using System.IO.Pipelines; using System.Threading; +using Microsoft.AspNetCore.Hosting; namespace Microsoft.AspNetCore.Sockets { - public class ConnectionManager : IDisposable + public class ConnectionManager { private ConcurrentDictionary _connections = new ConcurrentDictionary(); private Timer _timer; - public ConnectionManager() + public ConnectionManager(IApplicationLifetime lifetime) { _timer = new Timer(Scan, this, 0, 1000); + + // We hook stopping because we need the requests to end, Dispose doesn't work since + // that happens after requests are drained + lifetime.ApplicationStopping.Register(CloseConnections); } public bool TryGetConnection(string id, out ConnectionState state) @@ -96,7 +101,7 @@ namespace Microsoft.AspNetCore.Sockets } } - public void Dispose() + private void CloseConnections() { // Stop firing the timer _timer.Dispose(); diff --git a/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs index 99936c4080..f3c98f944a 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs @@ -2,13 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; -using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Routing; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Builder { @@ -16,15 +13,7 @@ namespace Microsoft.AspNetCore.Builder { public static IApplicationBuilder UseSockets(this IApplicationBuilder app, Action callback) { - var manager = new ConnectionManager(); - var factory = new PipelineFactory(); - - var loggerFactory = app.ApplicationServices.GetService(); - var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory); - - // Dispose the connection manager when application shutdown is triggered - var lifetime = app.ApplicationServices.GetRequiredService(); - lifetime.ApplicationStopping.Register(state => ((IDisposable)state).Dispose(), manager); + var dispatcher = app.ApplicationServices.GetRequiredService(); var routes = new RouteBuilder(app); diff --git a/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs new file mode 100644 index 0000000000..8b756520d4 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.Extensions.DependencyInjection +{ + public static class SocketsDependencyInjectionExtensions + { + public static IServiceCollection AddSockets(this IServiceCollection services) + { + services.AddRouting(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + return services; + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 4cc9f9e86c..88056fc93a 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -1,11 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; using System.IO.Pipelines; -using System.Linq; using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Internal; using Xunit; namespace Microsoft.AspNetCore.Sockets.Tests @@ -15,7 +13,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public void ReservedConnectionsHaveConnectionId() { - var connectionManager = new ConnectionManager(); + var lifetime = new ApplicationLifetime(); + var connectionManager = new ConnectionManager(lifetime); var state = connectionManager.ReserveConnection(); Assert.NotNull(state.Connection); @@ -28,7 +27,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public void ReservedConnectionsCanBeRetrieved() { - var connectionManager = new ConnectionManager(); + var lifetime = new ApplicationLifetime(); + var connectionManager = new ConnectionManager(lifetime); var state = connectionManager.ReserveConnection(); Assert.NotNull(state.Connection); @@ -43,10 +43,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests public void AddNewConnection() { using (var factory = new PipelineFactory()) - using (var channel = new HttpConnection(factory)) + using (var connection = new HttpConnection(factory)) { - var connectionManager = new ConnectionManager(); - var state = connectionManager.AddNewConnection(channel); + var lifetime = new ApplicationLifetime(); + var connectionManager = new ConnectionManager(lifetime); + var state = connectionManager.AddNewConnection(connection); Assert.NotNull(state.Connection); Assert.NotNull(state.Connection.ConnectionId); @@ -55,7 +56,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests ConnectionState newState; Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); Assert.Same(newState, state); - Assert.Same(channel, newState.Connection.Channel); + Assert.Same(connection, newState.Connection.Channel); } } @@ -63,10 +64,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests public void RemoveConnection() { using (var factory = new PipelineFactory()) - using (var channel = new HttpConnection(factory)) + using (var connection = new HttpConnection(factory)) { - var connectionManager = new ConnectionManager(); - var state = connectionManager.AddNewConnection(channel); + var lifetime = new ApplicationLifetime(); + var connectionManager = new ConnectionManager(lifetime); + var state = connectionManager.AddNewConnection(connection); Assert.NotNull(state.Connection); Assert.NotNull(state.Connection.ConnectionId); @@ -75,11 +77,34 @@ namespace Microsoft.AspNetCore.Sockets.Tests ConnectionState newState; Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); Assert.Same(newState, state); - Assert.Same(channel, newState.Connection.Channel); + Assert.Same(connection, newState.Connection.Channel); connectionManager.RemoveConnection(state.Connection.ConnectionId); Assert.False(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); } } + + [Fact] + public async Task ApplicationStoppingClosesConnections() + { + using (var factory = new PipelineFactory()) + using (var connection = new HttpConnection(factory)) + { + var lifetime = new ApplicationLifetime(); + var connectionManager = new ConnectionManager(lifetime); + var state = connectionManager.AddNewConnection(connection); + + var task = Task.Run(async () => + { + var result = await connection.Input.ReadAsync(); + + Assert.True(result.IsCompleted); + }); + + lifetime.StopApplication(); + + await task; + } + } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 6547ec7a67..aa7a731fb3 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -8,6 +8,7 @@ using System.IO.Pipelines; using System.Linq; using System.Text; using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Internal; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Internal; using Microsoft.Extensions.Primitives; @@ -20,7 +21,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task GetIdReservesConnectionIdAndReturnsIt() { - var manager = new ConnectionManager(); + var lifetime = new ApplicationLifetime(); + var manager = new ConnectionManager(lifetime); using (var factory = new PipelineFactory()) { var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null); @@ -41,7 +43,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SendingToReservedConnectionsThatHaveNotConnectedThrows() { - var manager = new ConnectionManager(); + var lifetime = new ApplicationLifetime(); + var manager = new ConnectionManager(lifetime); var state = manager.ReserveConnection(); using (var factory = new PipelineFactory()) @@ -63,7 +66,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SendingToUnknownConnectionIdThrows() { - var manager = new ConnectionManager(); + var lifetime = new ApplicationLifetime(); + var manager = new ConnectionManager(lifetime); using (var factory = new PipelineFactory()) { var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null); @@ -83,7 +87,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SendingWithoutConnectionIdThrows() { - var manager = new ConnectionManager(); + var lifetime = new ApplicationLifetime(); + var manager = new ConnectionManager(lifetime); using (var factory = new PipelineFactory()) { var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null);