From 9d9a52119e4306f2ef88d281aa81a41cdc1f260f Mon Sep 17 00:00:00 2001 From: David Fowler Date: Tue, 23 May 2017 02:43:32 -0700 Subject: [PATCH] Progress towards splitting the layers (#473) * Progress towards splitting the layers - This is based on the work anurse did in anurse/endpoint-middleware-spike to introduce a connection middleware pipeline that mimics much of our http pipeline. The intent is that this layer will be generic enough to build both SignalR and Kestrel on top of but we're not there yet. This change makes incremental progress towards splitting apart sockets and http so that we can add the tcp transport without breaking everything all at once. - Created Microsoft.AspNetCore.Sockets.Abstractions where the primitives for sockets live. That includes, ConnectionContext (formerly Connection), EndPoint, ISocketBuilder, SocketDelegate, etc. - ConnectionContext isn't in it's final form as yet, it still very closely mirrors the original Connection object we had so that tests continue to pass. - The HttpConnectionDispatcher doesn't know about EndPoint anymore, it just cares about invoking the SocketDelegate. - EndPointOptions has been removed as part of this change as it coupled http specific configuration to the end point type. There's a new HttpSocketOptions that needs to be passed into MapSocket calls. - Updated the tests to deal with the API changes. --- SignalR.sln | 9 +- .../EchoEndPoint.cs | 2 +- .../Startup.cs | 2 +- samples/ChatSample/IUserTracker.cs | 4 +- samples/ChatSample/InMemoryUserTracker.cs | 8 +- .../ChatSample/PresenceHubLifetimeManager.cs | 8 +- samples/ChatSample/RedisUserTracker.cs | 8 +- .../PersistentConnectionLifeTimeManager.cs | 8 +- .../SocialWeather/SocialWeatherEndPoint.cs | 4 +- samples/SocialWeather/Startup.cs | 2 +- .../EndPoints/MessagesEndPoint.cs | 2 +- samples/SocketsSample/Startup.cs | 2 +- .../RedisHubLifetimeManager.cs | 10 +- .../DefaultHubLifetimeManager.cs | 12 +- .../HubCallerContext.cs | 4 +- .../HubEndPoint.cs | 24 +- .../HubLifetimeManager.cs | 8 +- .../Internal/DefaultHubProtocolResolver.cs | 2 +- .../Internal/IHubProtocolResolver.cs | 2 +- src/Microsoft.AspNetCore.SignalR/Proxies.cs | 4 +- .../SignalRAppBuilderExtensions.cs | 2 +- .../ConnectionContext.cs | 33 ++ .../ConnectionMetadata.cs | 0 .../DefaultConnectionContext.cs | 28 ++ .../EndPoint.cs | 4 +- .../ISocketBuilder.cs | 18 + ...oft.AspNetCore.Sockets.Abstractions.csproj | 23 ++ .../SocketBuilder.cs | 44 +++ .../SocketBuilderExtensions.cs | 23 ++ .../SocketDelegate.cs | 12 + .../Connection.cs | 29 -- .../ConnectionList.cs | 16 +- .../ConnectionManager.cs | 2 +- .../EndPointDependencyInjectionExtensions.cs | 10 - .../HttpConnectionDispatcher.cs | 29 +- .../HttpDispatcherAppBuilderExtensions.cs | 19 +- .../HttpSocketBuilderExtensions.cs | 23 ++ ...ndPointOptions.cs => HttpSocketOptions.cs} | 2 +- .../Internal/ConnectionState.cs | 4 +- .../Microsoft.AspNetCore.Sockets.csproj | 1 + .../EchoEndPoint.cs | 2 +- .../HubEndpointTests.cs | 14 +- .../ServerFixture.cs | 2 +- .../TestClient.cs | 4 +- .../HttpConnectionDispatcherTests.cs | 329 +++++++++++------- 45 files changed, 535 insertions(+), 263 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs rename src/{Microsoft.AspNetCore.Sockets => Microsoft.AspNetCore.Sockets.Abstractions}/ConnectionMetadata.cs (100%) create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs rename src/{Microsoft.AspNetCore.Sockets => Microsoft.AspNetCore.Sockets.Abstractions}/EndPoint.cs (90%) create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/ISocketBuilder.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilder.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilderExtensions.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/SocketDelegate.cs delete mode 100644 src/Microsoft.AspNetCore.Sockets/Connection.cs create mode 100644 src/Microsoft.AspNetCore.Sockets/HttpSocketBuilderExtensions.cs rename src/Microsoft.AspNetCore.Sockets/{EndPointOptions.cs => HttpSocketOptions.cs} (87%) diff --git a/SignalR.sln b/SignalR.sln index bdf0240d45..1c7fe2eabf 100644 --- a/SignalR.sln +++ b/SignalR.sln @@ -1,6 +1,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.26411.1 +VisualStudioVersion = 15.0.26510.0 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}" EndProject @@ -76,6 +76,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.Signal EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Common.Tests", "test\Microsoft.AspNetCore.SignalR.Common.Tests\Microsoft.AspNetCore.SignalR.Common.Tests.csproj", "{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Sockets.Abstractions", "src\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj", "{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -186,6 +188,10 @@ Global {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.Build.0 = Debug|Any CPU {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.ActiveCfg = Release|Any CPU {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.Build.0 = Release|Any CPU + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -218,5 +224,6 @@ Global {6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131} + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8} = {DA69F624-5398-4884-87E4-B816698CDE65} EndGlobalSection EndGlobal diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs index 7b930e5437..2852c885ae 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs @@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server { public class EchoEndPoint : EndPoint { - public async override Task OnConnectedAsync(Connection connection) + public async override Task OnConnectedAsync(ConnectionContext connection) { await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync()); } diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs index 494de9a17c..6c30114542 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs @@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server } app.UseFileServer(); - app.UseSockets(options => options.MapEndpoint("echo")); + app.UseSockets(options => options.MapEndPoint("echo")); app.UseSignalR(routes => { routes.MapHub("testhub"); diff --git a/samples/ChatSample/IUserTracker.cs b/samples/ChatSample/IUserTracker.cs index c55e22f8c5..ca46483f72 100644 --- a/samples/ChatSample/IUserTracker.cs +++ b/samples/ChatSample/IUserTracker.cs @@ -11,8 +11,8 @@ namespace ChatSample public interface IUserTracker { Task> UsersOnline(); - Task AddUser(Connection connection, UserDetails userDetails); - Task RemoveUser(Connection connection); + Task AddUser(ConnectionContext connection, UserDetails userDetails); + Task RemoveUser(ConnectionContext connection); event Action UsersJoined; event Action UsersLeft; diff --git a/samples/ChatSample/InMemoryUserTracker.cs b/samples/ChatSample/InMemoryUserTracker.cs index ce7877618c..27ce19d717 100644 --- a/samples/ChatSample/InMemoryUserTracker.cs +++ b/samples/ChatSample/InMemoryUserTracker.cs @@ -9,8 +9,8 @@ namespace ChatSample { public class InMemoryUserTracker : IUserTracker { - private readonly ConcurrentDictionary _usersOnline - = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _usersOnline + = new ConcurrentDictionary(); public event Action UsersJoined; public event Action UsersLeft; @@ -18,7 +18,7 @@ namespace ChatSample public Task> UsersOnline() => Task.FromResult(_usersOnline.Values.AsEnumerable()); - public Task AddUser(Connection connection, UserDetails userDetails) + public Task AddUser(ConnectionContext connection, UserDetails userDetails) { _usersOnline.TryAdd(connection, userDetails); UsersJoined(new[] { userDetails }); @@ -26,7 +26,7 @@ namespace ChatSample return Task.CompletedTask; } - public Task RemoveUser(Connection connection) + public Task RemoveUser(ConnectionContext connection) { if (_usersOnline.TryRemove(connection, out var userDetails)) { diff --git a/samples/ChatSample/PresenceHubLifetimeManager.cs b/samples/ChatSample/PresenceHubLifetimeManager.cs index c918b81e41..fcc168000d 100644 --- a/samples/ChatSample/PresenceHubLifetimeManager.cs +++ b/samples/ChatSample/PresenceHubLifetimeManager.cs @@ -57,14 +57,14 @@ namespace ChatSample _wrappedHubLifetimeManager = serviceProvider.GetRequiredService(); } - public override async Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(ConnectionContext connection) { await _wrappedHubLifetimeManager.OnConnectedAsync(connection); _connections.Add(connection); await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name)); } - public override async Task OnDisconnectedAsync(Connection connection) + public override async Task OnDisconnectedAsync(ConnectionContext connection) { await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection); _connections.Remove(connection); @@ -157,12 +157,12 @@ namespace ChatSample return _wrappedHubLifetimeManager.InvokeUserAsync(userId, methodName, args); } - public override Task AddGroupAsync(Connection connection, string groupName) + public override Task AddGroupAsync(ConnectionContext connection, string groupName) { return _wrappedHubLifetimeManager.AddGroupAsync(connection, groupName); } - public override Task RemoveGroupAsync(Connection connection, string groupName) + public override Task RemoveGroupAsync(ConnectionContext connection, string groupName) { return _wrappedHubLifetimeManager.RemoveGroupAsync(connection, groupName); } diff --git a/samples/ChatSample/RedisUserTracker.cs b/samples/ChatSample/RedisUserTracker.cs index 9083cae2fa..06753884ae 100644 --- a/samples/ChatSample/RedisUserTracker.cs +++ b/samples/ChatSample/RedisUserTracker.cs @@ -129,7 +129,7 @@ namespace ChatSample } } - public async Task AddUser(Connection connection, UserDetails userDetails) + public async Task AddUser(ConnectionContext connection, UserDetails userDetails) { var key = GetUserRedisKey(connection); var user = SerializeUser(connection); @@ -156,7 +156,7 @@ namespace ChatSample } } - public async Task RemoveUser(Connection connection) + public async Task RemoveUser(ConnectionContext connection) { await _userSyncSempaphore.WaitAsync(); try @@ -180,7 +180,7 @@ namespace ChatSample } } - private static string GetUserRedisKey(Connection connection) => $"user:{connection.ConnectionId}"; + private static string GetUserRedisKey(ConnectionContext connection) => $"user:{connection.ConnectionId}"; private static void Scan(object state) { @@ -319,7 +319,7 @@ namespace ChatSample } } - private static string SerializeUser(Connection connection) => + private static string SerializeUser(ConnectionContext connection) => $"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}"; private static UserDetails DeserializerUser(string userJson) => diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs index 899a3fc625..1cbd68c282 100644 --- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs +++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs @@ -20,13 +20,13 @@ namespace SocialWeather _formatterResolver = formatterResolver; } - public void OnConnectedAsync(Connection connection) + public void OnConnectedAsync(ConnectionContext connection) { connection.Metadata[ConnectionMetadataNames.Format] = "json"; _connectionList.Add(connection); } - public void OnDisconnectedAsync(Connection connection) + public void OnDisconnectedAsync(ConnectionContext connection) { _connectionList.Remove(connection); } @@ -64,7 +64,7 @@ namespace SocialWeather throw new NotImplementedException(); } - public void AddGroupAsync(Connection connection, string groupName) + public void AddGroupAsync(ConnectionContext connection, string groupName) { var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet()); lock (groups) @@ -73,7 +73,7 @@ namespace SocialWeather } } - public void RemoveGroupAsync(Connection connection, string groupName) + public void RemoveGroupAsync(ConnectionContext connection, string groupName) { var groups = connection.Metadata.Get>("groups"); if (groups != null) diff --git a/samples/SocialWeather/SocialWeatherEndPoint.cs b/samples/SocialWeather/SocialWeatherEndPoint.cs index a78141fa7e..2e9a0fb8ee 100644 --- a/samples/SocialWeather/SocialWeatherEndPoint.cs +++ b/samples/SocialWeather/SocialWeatherEndPoint.cs @@ -23,14 +23,14 @@ namespace SocialWeather _logger = logger; } - public async override Task OnConnectedAsync(Connection connection) + public async override Task OnConnectedAsync(ConnectionContext connection) { _lifetimeManager.OnConnectedAsync(connection); await ProcessRequests(connection); _lifetimeManager.OnDisconnectedAsync(connection); } - public async Task ProcessRequests(Connection connection) + public async Task ProcessRequests(ConnectionContext connection) { var formatter = _formatterResolver.GetFormatter( connection.Metadata.Get("formatType")); diff --git a/samples/SocialWeather/Startup.cs b/samples/SocialWeather/Startup.cs index a950ef629f..fa2d62bb43 100644 --- a/samples/SocialWeather/Startup.cs +++ b/samples/SocialWeather/Startup.cs @@ -32,7 +32,7 @@ namespace SocialWeather app.UseDeveloperExceptionPage(); } - app.UseSockets(o => { o.MapEndpoint("weather"); }); + app.UseSockets(o => { o.MapEndPoint("weather"); }); app.UseFileServer(); var formatterResolver = app.ApplicationServices.GetRequiredService(); diff --git a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs index 3f68fa8f7d..f6ed46097e 100644 --- a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs +++ b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs @@ -13,7 +13,7 @@ namespace SocketsSample.EndPoints { public ConnectionList Connections { get; } = new ConnectionList(); - public override async Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(ConnectionContext connection) { Connections.Add(connection); diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index 034284f9b1..1acae62ef4 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -53,7 +53,7 @@ namespace SocketsSample app.UseSockets(routes => { - routes.MapEndpoint("chat"); + routes.MapEndPoint("chat"); }); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index cb847095a6..64e3aa6bbb 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -128,7 +128,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis await _bus.PublishAsync(channel, payload); } - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet()); var connectionTask = Task.CompletedTask; @@ -173,7 +173,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return Task.WhenAll(connectionTask, userTask); } - public override Task OnDisconnectedAsync(Connection connection) + public override Task OnDisconnectedAsync(ConnectionContext connection) { _connections.Remove(connection); @@ -204,7 +204,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return Task.WhenAll(tasks); } - public override async Task AddGroupAsync(Connection connection, string groupName) + public override async Task AddGroupAsync(ConnectionContext connection, string groupName) { var groupChannel = typeof(THub).FullName + ".group." + groupName; @@ -255,7 +255,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - public override async Task RemoveGroupAsync(Connection connection, string groupName) + public override async Task RemoveGroupAsync(ConnectionContext connection, string groupName) { var groupChannel = typeof(THub).FullName + ".group." + groupName; @@ -297,7 +297,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis _redisServerConnection.Dispose(); } - private async Task WriteAsync(Connection connection, HubMessage hubMessage) + private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage) { var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); var data = await protocol.WriteToArrayAsync(hubMessage); diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index c8b83fe4a2..5105d4ac6c 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.SignalR private long _nextInvocationId = 0; private readonly ConnectionList _connections = new ConnectionList(); - public override Task AddGroupAsync(Connection connection, string groupName) + public override Task AddGroupAsync(ConnectionContext connection, string groupName) { var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet()); @@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - public override Task RemoveGroupAsync(Connection connection, string groupName) + public override Task RemoveGroupAsync(ConnectionContext connection, string groupName) { var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); @@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.SignalR return InvokeAllWhere(methodName, args, c => true); } - private Task InvokeAllWhere(string methodName, object[] args, Func include) + private Task InvokeAllWhere(string methodName, object[] args, Func include) { var tasks = new List(_connections.Count); var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); @@ -94,19 +94,19 @@ namespace Microsoft.AspNetCore.SignalR }); } - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { _connections.Add(connection); return Task.CompletedTask; } - public override Task OnDisconnectedAsync(Connection connection) + public override Task OnDisconnectedAsync(ConnectionContext connection) { _connections.Remove(connection); return Task.CompletedTask; } - private async Task WriteAsync(Connection connection, HubMessage hubMessage) + private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage) { var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); var payload = await protocol.WriteToArrayAsync(hubMessage); diff --git a/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs b/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs index 50df41dc6a..f5e7b50cfa 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs @@ -8,12 +8,12 @@ namespace Microsoft.AspNetCore.SignalR { public class HubCallerContext { - public HubCallerContext(Connection connection) + public HubCallerContext(ConnectionContext connection) { Connection = connection; } - public Connection Connection { get; } + public ConnectionContext Connection { get; } public ClaimsPrincipal User => Connection.User; diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 82af05346f..65c3f14e7f 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -22,10 +22,9 @@ namespace Microsoft.AspNetCore.SignalR public HubEndPoint(HubLifetimeManager lifetimeManager, IHubProtocolResolver protocolResolver, IHubContext hubContext, - IOptions>> endPointOptions, ILogger> logger, IServiceScopeFactory serviceScopeFactory) - : base(lifetimeManager, protocolResolver, hubContext, endPointOptions, logger, serviceScopeFactory) + : base(lifetimeManager, protocolResolver, hubContext, logger, serviceScopeFactory) { } } @@ -43,7 +42,6 @@ namespace Microsoft.AspNetCore.SignalR public HubEndPoint(HubLifetimeManager lifetimeManager, IHubProtocolResolver protocolResolver, IHubContext hubContext, - IOptions>> endPointOptions, ILogger> logger, IServiceScopeFactory serviceScopeFactory) { @@ -56,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR DiscoverHubMethods(); } - public override async Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(ConnectionContext connection) { try { @@ -74,7 +72,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task RunHubAsync(Connection connection) + private async Task RunHubAsync(ConnectionContext connection) { await HubOnConnectedAsync(connection); @@ -92,7 +90,7 @@ namespace Microsoft.AspNetCore.SignalR await HubOnDisconnectedAsync(connection, null); } - private async Task HubOnConnectedAsync(Connection connection) + private async Task HubOnConnectedAsync(ConnectionContext connection) { try { @@ -118,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task HubOnDisconnectedAsync(Connection connection, Exception exception) + private async Task HubOnDisconnectedAsync(ConnectionContext connection, Exception exception) { try { @@ -144,7 +142,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task DispatchMessagesAsync(Connection connection) + private async Task DispatchMessagesAsync(ConnectionContext connection) { // We use these for error handling. Since we dispatch multiple hub invocations // in parallel, we need a way to communicate failure back to the main processing loop. The @@ -190,7 +188,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task ProcessInvocation(Connection connection, + private async Task ProcessInvocation(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage, CancellationTokenSource dispatcherCancellation, @@ -212,7 +210,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task Execute(Connection connection, IHubProtocol protocol, InvocationMessage invocationMessage) + private async Task Execute(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) { HubMethodDescriptor descriptor; if (!_methods.TryGetValue(invocationMessage.Target, out descriptor)) @@ -228,7 +226,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task SendMessageAsync(Connection connection, IHubProtocol protocol, HubMessage hubMessage) + private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage) { var payload = await protocol.WriteToArrayAsync(hubMessage); var message = new Message(payload, protocol.MessageType, endOfMessage: true); @@ -246,7 +244,7 @@ namespace Microsoft.AspNetCore.SignalR throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); } - private async Task Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationMessage invocationMessage) + private async Task Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, InvocationMessage invocationMessage) { var methodExecutor = descriptor.MethodExecutor; @@ -295,7 +293,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private void InitializeHub(THub hub, Connection connection) + private void InitializeHub(THub hub, ConnectionContext connection) { hub.Clients = _hubContext.Clients; hub.Context = new HubCallerContext(connection); diff --git a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs index 51a2b4d4f7..2e6f34a01d 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs @@ -8,9 +8,9 @@ namespace Microsoft.AspNetCore.SignalR { public abstract class HubLifetimeManager { - public abstract Task OnConnectedAsync(Connection connection); + public abstract Task OnConnectedAsync(ConnectionContext connection); - public abstract Task OnDisconnectedAsync(Connection connection); + public abstract Task OnDisconnectedAsync(ConnectionContext connection); public abstract Task InvokeAllAsync(string methodName, object[] args); @@ -20,9 +20,9 @@ namespace Microsoft.AspNetCore.SignalR public abstract Task InvokeUserAsync(string userId, string methodName, object[] args); - public abstract Task AddGroupAsync(Connection connection, string groupName); + public abstract Task AddGroupAsync(ConnectionContext connection, string groupName); - public abstract Task RemoveGroupAsync(Connection connection, string groupName); + public abstract Task RemoveGroupAsync(ConnectionContext connection, string groupName); } } diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs index 2d06b8bcee..c4e8957522 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs @@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public class DefaultHubProtocolResolver : IHubProtocolResolver { - public IHubProtocol GetProtocol(Connection connection) + public IHubProtocol GetProtocol(ConnectionContext connection) { // TODO: Allow customization of this serializer! return new JsonHubProtocol(new JsonSerializer()); diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs index c9627d0d59..9d2c1f1bb8 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs @@ -8,6 +8,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public interface IHubProtocolResolver { - IHubProtocol GetProtocol(Connection connection); + IHubProtocol GetProtocol(ConnectionContext connection); } } diff --git a/src/Microsoft.AspNetCore.SignalR/Proxies.cs b/src/Microsoft.AspNetCore.SignalR/Proxies.cs index dfed4a80ff..f0b8423580 100644 --- a/src/Microsoft.AspNetCore.SignalR/Proxies.cs +++ b/src/Microsoft.AspNetCore.SignalR/Proxies.cs @@ -75,10 +75,10 @@ namespace Microsoft.AspNetCore.SignalR public class GroupManager : IGroupManager { - private readonly Connection _connection; + private readonly ConnectionContext _connection; private readonly HubLifetimeManager _lifetimeManager; - public GroupManager(Connection connection, HubLifetimeManager lifetimeManager) + public GroupManager(ConnectionContext connection, HubLifetimeManager lifetimeManager) { _connection = connection; _lifetimeManager = lifetimeManager; diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRAppBuilderExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRAppBuilderExtensions.cs index 1ded6ec809..00cca9cdbd 100644 --- a/src/Microsoft.AspNetCore.SignalR/SignalRAppBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRAppBuilderExtensions.cs @@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Builder public void MapHub(string path) where THub : Hub { - _routes.MapEndpoint>(path); + _routes.MapEndPoint>(path); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs new file mode 100644 index 0000000000..7300980331 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs @@ -0,0 +1,33 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Security.Claims; +using System.Text; +using System.Threading; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Sockets +{ + public abstract class ConnectionContext : IDisposable + { + public abstract string ConnectionId { get; } + + public abstract IFeatureCollection Features { get; } + + public abstract ClaimsPrincipal User { get; set; } + + // REVIEW: Should this be changed to items + public abstract ConnectionMetadata Metadata { get; } + + // TEMPORARY + public abstract IChannelConnection Transport { get; set; } + + // TEMPORARY + public void Dispose() + { + Transport?.Dispose(); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs similarity index 100% rename from src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs rename to src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs new file mode 100644 index 0000000000..b353216d37 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Security.Claims; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Sockets +{ + public class DefaultConnectionContext : ConnectionContext + { + public DefaultConnectionContext(string id, IChannelConnection transport) + { + Transport = transport; + ConnectionId = id; + } + + public override string ConnectionId { get; } + + public override IFeatureCollection Features { get; } = new FeatureCollection(); + + public override ClaimsPrincipal User { get; set; } + + public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); + + public override IChannelConnection Transport { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/EndPoint.cs similarity index 90% rename from src/Microsoft.AspNetCore.Sockets/EndPoint.cs rename to src/Microsoft.AspNetCore.Sockets.Abstractions/EndPoint.cs index a975839e34..9108e43d83 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/EndPoint.cs @@ -14,8 +14,8 @@ namespace Microsoft.AspNetCore.Sockets /// /// Called when a new connection is accepted to the endpoint /// - /// The new + /// The new /// A that represents the connection lifetime. When the task completes, the connection is complete. - public abstract Task OnConnectedAsync(Connection connection); + public abstract Task OnConnectedAsync(ConnectionContext connection); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ISocketBuilder.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ISocketBuilder.cs new file mode 100644 index 0000000000..f7d2f98513 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ISocketBuilder.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Sockets +{ + public interface ISocketBuilder + { + IServiceProvider ApplicationServices { get; } + + ISocketBuilder Use(Func middleware); + + SocketDelegate Build(); + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj new file mode 100644 index 0000000000..ea17b5e2ea --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj @@ -0,0 +1,23 @@ + + + + + + Components for providing real-time bi-directional communication across the Web. + netcoreapp2.0 + $(NoWarn);CS1591 + true + aspnetcore;signalr + false + + + + + + + + + + + + diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilder.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilder.cs new file mode 100644 index 0000000000..8dfc21b848 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilder.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets +{ + public class SocketBuilder : ISocketBuilder + { + private readonly IList> _components = new List>(); + + public IServiceProvider ApplicationServices { get; } + + public SocketBuilder(IServiceProvider applicationServices) + { + ApplicationServices = applicationServices; + } + + public ISocketBuilder Use(Func middleware) + { + _components.Add(middleware); + return this; + } + + public SocketDelegate Build() + { + SocketDelegate app = features => + { + return Task.CompletedTask; + }; + + foreach (var component in _components.Reverse()) + { + app = component(app); + } + + return app; + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilderExtensions.cs new file mode 100644 index 0000000000..82635ec4da --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilderExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets +{ + public static class SocketBuilderExtensions + { + public static ISocketBuilder Use(this ISocketBuilder socketBuilder, Func, Task> middleware) + { + return socketBuilder.Use(next => + { + return context => + { + Func simpleNext = () => next(context); + return middleware(context, simpleNext); + }; + }); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketDelegate.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketDelegate.cs new file mode 100644 index 0000000000..45af07cd77 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketDelegate.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets +{ + public delegate Task SocketDelegate(ConnectionContext connection); +} diff --git a/src/Microsoft.AspNetCore.Sockets/Connection.cs b/src/Microsoft.AspNetCore.Sockets/Connection.cs deleted file mode 100644 index 1508a95eb4..0000000000 --- a/src/Microsoft.AspNetCore.Sockets/Connection.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Security.Claims; - -namespace Microsoft.AspNetCore.Sockets -{ - public class Connection : IDisposable - { - public string ConnectionId { get; } - - public ClaimsPrincipal User { get; set; } - public ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); - - public IChannelConnection Transport { get; } - - public Connection(string id, IChannelConnection transport) - { - Transport = transport; - ConnectionId = id; - } - - public void Dispose() - { - Transport.Dispose(); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs index 5f5efaadb8..323c0e442e 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs @@ -8,15 +8,15 @@ using System.Collections.Generic; namespace Microsoft.AspNetCore.Sockets { - public class ConnectionList : IReadOnlyCollection + public class ConnectionList : IReadOnlyCollection { - private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); - public Connection this[string connectionId] + public ConnectionContext this[string connectionId] { get { - Connection connection; + ConnectionContext connection; if (_connections.TryGetValue(connectionId, out connection)) { return connection; @@ -27,18 +27,18 @@ namespace Microsoft.AspNetCore.Sockets public int Count => _connections.Count; - public void Add(Connection connection) + public void Add(ConnectionContext connection) { _connections.TryAdd(connection.ConnectionId, connection); } - public void Remove(Connection connection) + public void Remove(ConnectionContext connection) { - Connection dummy; + ConnectionContext dummy; _connections.TryRemove(connection.ConnectionId, out dummy); } - public IEnumerator GetEnumerator() + public IEnumerator GetEnumerator() { foreach (var item in _connections) { diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 20b3d1c866..201fb02c8c 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -57,7 +57,7 @@ namespace Microsoft.AspNetCore.Sockets var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); var state = new ConnectionState( - new Connection(id, applicationSide), + new DefaultConnectionContext(id, applicationSide), transportSide); _connections.TryAdd(id, state); diff --git a/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs index c09c6118fd..dfe59a108c 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs @@ -14,15 +14,5 @@ namespace Microsoft.Extensions.DependencyInjection return services; } - - public static IServiceCollection AddEndPoint(this IServiceCollection services, - Action> setupAction) where TEndPoint : EndPoint - { - services.AddEndPoint(); - - services.Configure(setupAction); - - return services; - } } } diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index bd38c6c0ee..55ce652fb7 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -14,9 +14,7 @@ using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.AspNetCore.Sockets.Transports; using Microsoft.AspNetCore.WebSockets.Internal; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Sockets @@ -34,10 +32,8 @@ namespace Microsoft.AspNetCore.Sockets _logger = _loggerFactory.CreateLogger(); } - public async Task ExecuteAsync(HttpContext context) where TEndPoint : EndPoint + public async Task ExecuteAsync(HttpContext context, HttpSocketOptions options, SocketDelegate socketDelegate) { - var options = context.RequestServices.GetRequiredService>>().Value; - // TODO: Authorize attribute on EndPoint if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationPolicyNames)) { return; @@ -56,10 +52,7 @@ namespace Microsoft.AspNetCore.Sockets else if (HttpMethods.IsGet(context.Request.Method)) { // GET /{path} - - // Get the end point mapped to this http connection - var endpoint = (EndPoint)context.RequestServices.GetRequiredService(); - await ExecuteEndpointAsync(context, endpoint, options); + await ExecuteEndpointAsync(context, socketDelegate, options); } else { @@ -67,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets } } - private async Task ExecuteEndpointAsync(HttpContext context, EndPoint endpoint, EndPointOptions options) where TEndPoint : EndPoint + private async Task ExecuteEndpointAsync(HttpContext context, SocketDelegate socketDelegate, HttpSocketOptions options) { var supportedTransports = options.Transports; @@ -94,7 +87,7 @@ namespace Microsoft.AspNetCore.Sockets // We only need to provide the Input channel since writing to the application is handled through /send. var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory); - await DoPersistentConnection(endpoint, sse, context, state); + await DoPersistentConnection(socketDelegate, sse, context, state); } else if (context.Features.Get()?.IsWebSocketRequest == true) { @@ -114,7 +107,7 @@ namespace Microsoft.AspNetCore.Sockets var ws = new WebSocketsTransport(options.WebSockets, state.Application, _loggerFactory); - await DoPersistentConnection(endpoint, ws, context, state); + await DoPersistentConnection(socketDelegate, ws, context, state); } else { @@ -183,7 +176,7 @@ namespace Microsoft.AspNetCore.Sockets state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling; - state.ApplicationTask = ExecuteApplication(endpoint, state.Connection); + state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection); } else { @@ -275,7 +268,7 @@ namespace Microsoft.AspNetCore.Sockets return state; } - private async Task DoPersistentConnection(EndPoint endpoint, + private async Task DoPersistentConnection(SocketDelegate socketDelegate, IHttpTransport transport, HttpContext context, ConnectionState state) @@ -310,7 +303,7 @@ namespace Microsoft.AspNetCore.Sockets state.RequestId = context.TraceIdentifier; // Call into the end point passing the connection - state.ApplicationTask = ExecuteApplication(endpoint, state.Connection); + state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection); // Start the transport state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); @@ -326,17 +319,17 @@ namespace Microsoft.AspNetCore.Sockets await _manager.DisposeAndRemoveAsync(state); } - private async Task ExecuteApplication(EndPoint endpoint, Connection connection) + private async Task ExecuteApplication(SocketDelegate socketDelegate, ConnectionContext connection) { // Jump onto the thread pool thread so blocking user code doesn't block the setup of the // connection and transport await AwaitableThreadPool.Yield(); // Running this in an async method turns sync exceptions into async ones - await endpoint.OnConnectedAsync(connection); + await socketDelegate(connection); } - private Task ProcessNegotiate(HttpContext context, EndPointOptions options) where TEndPoint : EndPoint + private Task ProcessNegotiate(HttpContext context, HttpSocketOptions options) { // Set the allowed headers for this resource context.Response.Headers.AppendCommaSeparatedValues("Allow", "GET", "POST", "OPTIONS"); diff --git a/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs index a9f1b87888..732b0050af 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; @@ -36,9 +35,23 @@ namespace Microsoft.AspNetCore.Builder _dispatcher = dispatcher; } - public void MapEndpoint(string path) where TEndPoint : EndPoint + public void MapSocket(string path, Action socketConfig) => + MapSocket(path, new HttpSocketOptions(), socketConfig); + + public void MapSocket(string path, HttpSocketOptions options, Action socketConfig) { - _routes.MapRoute(path, _dispatcher.ExecuteAsync); + var socketBuilder = new SocketBuilder(_routes.ServiceProvider); + socketConfig(socketBuilder); + var socket = socketBuilder.Build(); + _routes.MapRoute(path, c => _dispatcher.ExecuteAsync(c, options, socket)); + } + + public void MapEndPoint(string path) where TEndPoint : EndPoint + { + MapSocket(path, builder => + { + builder.UseEndPoint(); + }); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/HttpSocketBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets/HttpSocketBuilderExtensions.cs new file mode 100644 index 0000000000..9aa946ebe3 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/HttpSocketBuilderExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Sockets +{ + public static class HttpSocketBuilderExtensions + { + public static ISocketBuilder UseEndPoint(this ISocketBuilder socketBuilder) where TEndPoint : EndPoint + { + // This is a terminal middleware, so there's no need to use the 'next' parameter + return socketBuilder.Use((connection, _) => + { + var endpoint = socketBuilder.ApplicationServices.GetRequiredService(); + return endpoint.OnConnectedAsync(connection); + }); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs b/src/Microsoft.AspNetCore.Sockets/HttpSocketOptions.cs similarity index 87% rename from src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs rename to src/Microsoft.AspNetCore.Sockets/HttpSocketOptions.cs index ec939ff762..5b676f648c 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpSocketOptions.cs @@ -5,7 +5,7 @@ using System.Collections.Generic; namespace Microsoft.AspNetCore.Sockets { - public class EndPointOptions where TEndPoint : EndPoint + public class HttpSocketOptions { public IList AuthorizationPolicyNames { get; } = new List(); diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index c73423b87f..7c4db1a773 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal // on the same task private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); - public Connection Connection { get; set; } + public ConnectionContext Connection { get; set; } public IChannelConnection Application { get; } public CancellationTokenSource Cancellation { get; set; } @@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal public DateTime LastSeenUtc { get; set; } public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive; - public ConnectionState(Connection connection, IChannelConnection application) + public ConnectionState(ConnectionContext connection, IChannelConnection application) { Connection = connection; Application = application; diff --git a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj index 9f37687e10..4cceabdecd 100644 --- a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj +++ b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj @@ -12,6 +12,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs index 8b733dd154..2575514352 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs @@ -8,7 +8,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public class EchoEndPoint : EndPoint { - public async override Task OnConnectedAsync(Connection connection) + public async override Task OnConnectedAsync(ConnectionContext connection) { await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync()); } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 91b51ba380..bf8536a2e8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var mockLifetimeManager = new Mock>(); mockLifetimeManager - .Setup(m => m.OnConnectedAsync(It.IsAny())) + .Setup(m => m.OnConnectedAsync(It.IsAny())) .Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed.")); var mockHubActivator = new Mock>(); @@ -60,8 +60,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests client.Dispose(); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); // No hubs should be created since the connection is terminated mockHubActivator.Verify(m => m.Create(), Times.Never); mockHubActivator.Verify(m => m.Release(It.IsAny()), Times.Never); @@ -87,8 +87,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var exception = await Assert.ThrowsAsync(async () => await endPointTask); Assert.Equal("Hub OnConnected failed.", exception.Message); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); } } @@ -111,8 +111,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var exception = await Assert.ThrowsAsync(async () => await endPointTask); Assert.Equal("Hub OnDisconnected failed.", exception.Message); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/ServerFixture.cs b/test/Microsoft.AspNetCore.SignalR.Tests/ServerFixture.cs index 60dd69183d..5fa3d2d629 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/ServerFixture.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/ServerFixture.cs @@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public void Configure(IApplicationBuilder app, IHostingEnvironment env) { - app.UseSockets(options => options.MapEndpoint("echo")); + app.UseSockets(options => options.MapEndPoint("echo")); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 1139ff5295..e9edcdc268 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests private IHubProtocol _protocol; private CancellationTokenSource _cts; - public Connection Connection; + public ConnectionContext Connection; public IChannelConnection Application { get; } public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; @@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests Application = ChannelConnection.Create(input: applicationToTransport, output: transportToApplication); var transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); - Connection = new Connection(Guid.NewGuid().ToString(), transport); + Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport); Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) })); Connection.Metadata["ConnectedTask"] = new TaskCompletionSource(); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 4176b8f32a..9e635bf156 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -37,12 +37,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddEndPoint(); services.AddOptions(); - context.RequestServices = services.BuildServiceProvider(); var ms = new MemoryStream(); context.Request.Path = "/foo"; context.Request.Method = "OPTIONS"; context.Response.Body = ms; - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); var id = Encoding.UTF8.GetString(ms.ToArray()); @@ -68,7 +70,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddEndPoint(); services.AddOptions(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); @@ -77,7 +78,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Query = qs; SetTransport(context, transportType); - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); await strm.FlushAsync(); @@ -100,7 +104,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddEndPoint(); services.AddOptions(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); @@ -108,7 +111,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests var qs = new QueryCollection(values); context.Request.Query = qs; - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); await strm.FlushAsync(); @@ -130,13 +136,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddOptions(); services.AddEndPoint(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; SetTransport(context, transportType); - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); await strm.FlushAsync(); @@ -156,11 +164,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddOptions(); services.AddEndPoint(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "POST"; - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); await strm.FlushAsync(); @@ -180,14 +190,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddOptions(); services.AddEndPoint(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Request.QueryString = new QueryString($"?id={connectionState.Connection.ConnectionId}"); context.Request.ContentType = "text/plain"; context.Response.Body = strm; - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); await strm.FlushAsync(); @@ -237,10 +249,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, TransportType.ServerSentEvents); - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); @@ -256,11 +273,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var state = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, TransportType.ServerSentEvents); - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); @@ -276,10 +297,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests var state = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + var context = MakeRequest("/foo", state); - var context = MakeRequest("/foo", state); - - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); @@ -296,9 +321,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); @@ -315,10 +345,18 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, TransportType.WebSockets); - var task = dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1); + + var task = dispatcher.ExecuteAsync(context, options, app); await task.OrTimeout(); } @@ -333,15 +371,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context1 = MakeRequest("/foo", state); - var context2 = MakeRequest("/foo", state); + var context1 = MakeRequest("/foo", state); + var context2 = MakeRequest("/foo", state); SetTransport(context1, transportType); SetTransport(context2, transportType); - var request1 = dispatcher.ExecuteAsync(context1); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var request1 = dispatcher.ExecuteAsync(context1, options, app); - await dispatcher.ExecuteAsync(context2); + await dispatcher.ExecuteAsync(context2, options, app); Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode); @@ -369,11 +413,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context1 = MakeRequest("/foo", state); - var context2 = MakeRequest("/foo", state); + var context1 = MakeRequest("/foo", state); + var context2 = MakeRequest("/foo", state); - var request1 = dispatcher.ExecuteAsync(context1); - var request2 = dispatcher.ExecuteAsync(context2); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var request1 = dispatcher.ExecuteAsync(context1, options, app); + var request2 = dispatcher.ExecuteAsync(context2, options, app); await request1; @@ -398,10 +448,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, transportType); - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + await dispatcher.ExecuteAsync(context, options, app); + Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); } @@ -414,9 +471,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); - var task = dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -439,10 +502,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, TransportType.ServerSentEvents); - var task = dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -465,9 +534,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); - var task = dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -490,10 +565,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context1 = MakeRequest("/foo", state); - var task1 = dispatcher.ExecuteAsync(context1); - var context2 = MakeRequest("/foo", state); - var task2 = dispatcher.ExecuteAsync(context2); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + + var context1 = MakeRequest("/foo", state); + var task1 = dispatcher.ExecuteAsync(context1, options, app); + var context2 = MakeRequest("/foo", state); + var task2 = dispatcher.ExecuteAsync(context2, options, app); // Task 1 should finish when request 2 arrives await task1.OrTimeout(); @@ -555,19 +637,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); + o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -576,8 +655,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + // would hang if EndPoint was running - await dispatcher.ExecuteAsync(context).OrTimeout(); + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); } @@ -591,22 +676,19 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => + o.AddPolicy("test", policy => { policy.RequireClaim(ClaimTypes.NameIdentifier); }); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -616,10 +698,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + // "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - var endPointTask = dispatcher.ExecuteAsync(context); + var endPointTask = dispatcher.ExecuteAsync(context, options, app); await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); await endPointTask.OrTimeout(); @@ -637,21 +725,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - options.AuthorizationPolicyNames.Add("secondPolicy"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); - options.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress)); + o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); + o.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress)); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -661,18 +745,25 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + options.AuthorizationPolicyNames.Add("secondPolicy"); + // partialy "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); // would hang if EndPoint was running - await dispatcher.ExecuteAsync(context).OrTimeout(); + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); // fully "authorize" user context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") })); - var endPointTask = dispatcher.ExecuteAsync(context); + var endPointTask = dispatcher.ExecuteAsync(context, options, app); await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); await endPointTask.OrTimeout(); @@ -689,23 +780,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => + o.AddPolicy("test", policy => { policy.RequireClaim(ClaimTypes.NameIdentifier); policy.AddAuthenticationSchemes("Default"); }); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -715,10 +803,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + // "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - var endPointTask = dispatcher.ExecuteAsync(context); + var endPointTask = dispatcher.ExecuteAsync(context, options, app); await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); await endPointTask.OrTimeout(); @@ -736,23 +830,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => + o.AddPolicy("test", policy => { policy.RequireClaim(ClaimTypes.NameIdentifier); policy.AddAuthenticationSchemes("Default"); }); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -762,11 +853,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + // "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); // would block if EndPoint was executed - await dispatcher.ExecuteAsync(context).OrTimeout(); + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); } @@ -829,20 +926,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Response.Body = strm; var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => - { - options.Transports = supportedTransports; - }); - + services.AddEndPoint(); SetTransport(context, transportType); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; - await dispatcher.ExecuteAsync(context); + + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.Transports = supportedTransports; + + await dispatcher.ExecuteAsync(context, options, app); Assert.Equal(status, context.Response.StatusCode); await strm.FlushAsync(); @@ -861,10 +960,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state, format); + var context = MakeRequest("/foo", state, format); context.Request.Method = "POST"; context.Request.ContentType = contentType; - var endPoint = context.RequestServices.GetRequiredService(); + + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); var buffer = contentType == BinaryContentType ? Convert.FromBase64String(encoded) : @@ -872,7 +976,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var messages = new List(); using (context.Request.Body = new MemoryStream(buffer, writable: false)) { - await dispatcher.ExecuteAsync(context).OrTimeout(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app).OrTimeout(); } while (state.Connection.Transport.Input.TryRead(out var message)) @@ -883,18 +987,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests return messages; } - private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null) where TEndPoint : EndPoint + private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null) { var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddEndPoint(o => - { - // Make the close timeout less than the default for OrTimeout() test helper - o.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1); - }); - - services.AddOptions(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = path; context.Request.Method = "GET"; var values = new Dictionary(); @@ -942,7 +1037,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class NerverEndingEndPoint : EndPoint { - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { var tcs = new TaskCompletionSource(); return tcs.Task; @@ -951,7 +1046,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class BlockingEndPoint : EndPoint { - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { connection.Transport.Input.WaitToReadAsync().Wait(); return Task.CompletedTask; @@ -960,7 +1055,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class SynchronusExceptionEndPoint : EndPoint { - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { throw new InvalidOperationException(); } @@ -968,7 +1063,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class ImmediatelyCompleteEndPoint : EndPoint { - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { return Task.CompletedTask; } @@ -976,7 +1071,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class TestEndPoint : EndPoint { - public override async Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(ConnectionContext connection) { while (await connection.Transport.Input.WaitToReadAsync()) {