diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 46ae37698e..3eceb0c6d3 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Concurrent; using System.IO.Pipelines; using System.Threading; -using Microsoft.AspNetCore.Hosting; namespace Microsoft.AspNetCore.Sockets { @@ -14,13 +13,9 @@ namespace Microsoft.AspNetCore.Sockets private ConcurrentDictionary _connections = new ConcurrentDictionary(); private Timer _timer; - public ConnectionManager(IApplicationLifetime lifetime) + public ConnectionManager() { _timer = new Timer(Scan, this, 0, 1000); - - // We hook stopping because we need the requests to end, Dispose doesn't work since - // that happens after requests are drained - lifetime.ApplicationStopping.Register(CloseConnections); } public bool TryGetConnection(string id, out ConnectionState state) @@ -101,7 +96,7 @@ namespace Microsoft.AspNetCore.Sockets } } - private void CloseConnections() + public void CloseConnections() { // Stop firing the timer _timer.Dispose(); diff --git a/src/Microsoft.AspNetCore.Sockets/SocketsApplicationLifetimeEvents.cs b/src/Microsoft.AspNetCore.Sockets/SocketsApplicationLifetimeEvents.cs new file mode 100644 index 0000000000..f06f09ff47 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/SocketsApplicationLifetimeEvents.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; + +namespace Microsoft.AspNetCore.Sockets +{ + public class SocketsApplicationLifetimeEvents : IApplicationLifetimeEvents + { + private readonly ConnectionManager _connectionManager; + + public SocketsApplicationLifetimeEvents(ConnectionManager connectionManager) + { + _connectionManager = connectionManager; + } + + public void OnApplicationStarted() + { + + } + + public void OnApplicationStopped() + { + + } + + public void OnApplicationStopping() + { + _connectionManager.CloseConnections(); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs index 8b756520d4..fd20047a9a 100644 --- a/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs @@ -2,7 +2,7 @@ using System.Collections.Generic; using System.IO.Pipelines; using System.Linq; -using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -14,6 +14,7 @@ namespace Microsoft.Extensions.DependencyInjection { services.AddRouting(); services.TryAddSingleton(); + services.TryAddEnumerable(ServiceDescriptor.Singleton()); services.TryAddSingleton(); services.TryAddSingleton(); return services; diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 23b59d26b6..5c4ba864a9 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -2,11 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.IO.Pipelines; -using System.Linq; using System.Threading.Tasks; -using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Hosting.Internal; -using Microsoft.Extensions.Logging; using Xunit; namespace Microsoft.AspNetCore.Sockets.Tests @@ -16,8 +12,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public void ReservedConnectionsHaveConnectionId() { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var connectionManager = new ConnectionManager(lifetime); + var connectionManager = new ConnectionManager(); var state = connectionManager.ReserveConnection(); Assert.NotNull(state.Connection); @@ -30,8 +25,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public void ReservedConnectionsCanBeRetrieved() { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var connectionManager = new ConnectionManager(lifetime); + var connectionManager = new ConnectionManager(); var state = connectionManager.ReserveConnection(); Assert.NotNull(state.Connection); @@ -48,8 +42,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipelineFactory()) using (var connection = new HttpConnection(factory)) { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var connectionManager = new ConnectionManager(lifetime); + var connectionManager = new ConnectionManager(); var state = connectionManager.AddNewConnection(connection); Assert.NotNull(state.Connection); @@ -69,8 +62,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var factory = new PipelineFactory()) using (var connection = new HttpConnection(factory)) { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var connectionManager = new ConnectionManager(lifetime); + var connectionManager = new ConnectionManager(); var state = connectionManager.AddNewConnection(connection); Assert.NotNull(state.Connection); @@ -88,13 +80,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests } [Fact] - public async Task ApplicationStoppingClosesConnections() + public async Task CloseConnectionsEndsAllPendingConnections() { using (var factory = new PipelineFactory()) using (var connection = new HttpConnection(factory)) { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var connectionManager = new ConnectionManager(lifetime); + var connectionManager = new ConnectionManager(); var state = connectionManager.AddNewConnection(connection); var task = Task.Run(async () => @@ -104,7 +95,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.True(result.IsCompleted); }); - lifetime.StopApplication(); + connectionManager.CloseConnections(); await task; } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index ad7b1adc1f..7aacce726e 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -23,8 +23,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task GetIdReservesConnectionIdAndReturnsIt() { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var manager = new ConnectionManager(lifetime); + var manager = new ConnectionManager(); using (var factory = new PipelineFactory()) { var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null); @@ -45,8 +44,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SendingToReservedConnectionsThatHaveNotConnectedThrows() { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var manager = new ConnectionManager(lifetime); + var manager = new ConnectionManager(); var state = manager.ReserveConnection(); using (var factory = new PipelineFactory()) @@ -68,8 +66,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SendingToUnknownConnectionIdThrows() { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var manager = new ConnectionManager(lifetime); + var manager = new ConnectionManager(); using (var factory = new PipelineFactory()) { var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null); @@ -89,8 +86,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SendingWithoutConnectionIdThrows() { - var lifetime = new ApplicationLifetime(new Logger(new LoggerFactory()), Enumerable.Empty()); - var manager = new ConnectionManager(lifetime); + var manager = new ConnectionManager(); using (var factory = new PipelineFactory()) { var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null);