diff --git a/SignalR.sln b/SignalR.sln index bdf0240d45..1c7fe2eabf 100644 --- a/SignalR.sln +++ b/SignalR.sln @@ -1,6 +1,6 @@ Microsoft Visual Studio Solution File, Format Version 12.00 # Visual Studio 15 -VisualStudioVersion = 15.0.26411.1 +VisualStudioVersion = 15.0.26510.0 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}" EndProject @@ -76,6 +76,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.Signal EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Common.Tests", "test\Microsoft.AspNetCore.SignalR.Common.Tests\Microsoft.AspNetCore.SignalR.Common.Tests.csproj", "{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Sockets.Abstractions", "src\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj", "{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -186,6 +188,10 @@ Global {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.Build.0 = Debug|Any CPU {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.ActiveCfg = Release|Any CPU {75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.Build.0 = Release|Any CPU + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -218,5 +224,6 @@ Global {6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131} + {F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8} = {DA69F624-5398-4884-87E4-B816698CDE65} EndGlobalSection EndGlobal diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs index 7b930e5437..2852c885ae 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs @@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server { public class EchoEndPoint : EndPoint { - public async override Task OnConnectedAsync(Connection connection) + public async override Task OnConnectedAsync(ConnectionContext connection) { await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync()); } diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs index 494de9a17c..6c30114542 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/Startup.cs @@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server } app.UseFileServer(); - app.UseSockets(options => options.MapEndpoint("echo")); + app.UseSockets(options => options.MapEndPoint("echo")); app.UseSignalR(routes => { routes.MapHub("testhub"); diff --git a/samples/ChatSample/IUserTracker.cs b/samples/ChatSample/IUserTracker.cs index c55e22f8c5..ca46483f72 100644 --- a/samples/ChatSample/IUserTracker.cs +++ b/samples/ChatSample/IUserTracker.cs @@ -11,8 +11,8 @@ namespace ChatSample public interface IUserTracker { Task> UsersOnline(); - Task AddUser(Connection connection, UserDetails userDetails); - Task RemoveUser(Connection connection); + Task AddUser(ConnectionContext connection, UserDetails userDetails); + Task RemoveUser(ConnectionContext connection); event Action UsersJoined; event Action UsersLeft; diff --git a/samples/ChatSample/InMemoryUserTracker.cs b/samples/ChatSample/InMemoryUserTracker.cs index ce7877618c..27ce19d717 100644 --- a/samples/ChatSample/InMemoryUserTracker.cs +++ b/samples/ChatSample/InMemoryUserTracker.cs @@ -9,8 +9,8 @@ namespace ChatSample { public class InMemoryUserTracker : IUserTracker { - private readonly ConcurrentDictionary _usersOnline - = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _usersOnline + = new ConcurrentDictionary(); public event Action UsersJoined; public event Action UsersLeft; @@ -18,7 +18,7 @@ namespace ChatSample public Task> UsersOnline() => Task.FromResult(_usersOnline.Values.AsEnumerable()); - public Task AddUser(Connection connection, UserDetails userDetails) + public Task AddUser(ConnectionContext connection, UserDetails userDetails) { _usersOnline.TryAdd(connection, userDetails); UsersJoined(new[] { userDetails }); @@ -26,7 +26,7 @@ namespace ChatSample return Task.CompletedTask; } - public Task RemoveUser(Connection connection) + public Task RemoveUser(ConnectionContext connection) { if (_usersOnline.TryRemove(connection, out var userDetails)) { diff --git a/samples/ChatSample/PresenceHubLifetimeManager.cs b/samples/ChatSample/PresenceHubLifetimeManager.cs index c918b81e41..fcc168000d 100644 --- a/samples/ChatSample/PresenceHubLifetimeManager.cs +++ b/samples/ChatSample/PresenceHubLifetimeManager.cs @@ -57,14 +57,14 @@ namespace ChatSample _wrappedHubLifetimeManager = serviceProvider.GetRequiredService(); } - public override async Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(ConnectionContext connection) { await _wrappedHubLifetimeManager.OnConnectedAsync(connection); _connections.Add(connection); await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name)); } - public override async Task OnDisconnectedAsync(Connection connection) + public override async Task OnDisconnectedAsync(ConnectionContext connection) { await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection); _connections.Remove(connection); @@ -157,12 +157,12 @@ namespace ChatSample return _wrappedHubLifetimeManager.InvokeUserAsync(userId, methodName, args); } - public override Task AddGroupAsync(Connection connection, string groupName) + public override Task AddGroupAsync(ConnectionContext connection, string groupName) { return _wrappedHubLifetimeManager.AddGroupAsync(connection, groupName); } - public override Task RemoveGroupAsync(Connection connection, string groupName) + public override Task RemoveGroupAsync(ConnectionContext connection, string groupName) { return _wrappedHubLifetimeManager.RemoveGroupAsync(connection, groupName); } diff --git a/samples/ChatSample/RedisUserTracker.cs b/samples/ChatSample/RedisUserTracker.cs index 9083cae2fa..06753884ae 100644 --- a/samples/ChatSample/RedisUserTracker.cs +++ b/samples/ChatSample/RedisUserTracker.cs @@ -129,7 +129,7 @@ namespace ChatSample } } - public async Task AddUser(Connection connection, UserDetails userDetails) + public async Task AddUser(ConnectionContext connection, UserDetails userDetails) { var key = GetUserRedisKey(connection); var user = SerializeUser(connection); @@ -156,7 +156,7 @@ namespace ChatSample } } - public async Task RemoveUser(Connection connection) + public async Task RemoveUser(ConnectionContext connection) { await _userSyncSempaphore.WaitAsync(); try @@ -180,7 +180,7 @@ namespace ChatSample } } - private static string GetUserRedisKey(Connection connection) => $"user:{connection.ConnectionId}"; + private static string GetUserRedisKey(ConnectionContext connection) => $"user:{connection.ConnectionId}"; private static void Scan(object state) { @@ -319,7 +319,7 @@ namespace ChatSample } } - private static string SerializeUser(Connection connection) => + private static string SerializeUser(ConnectionContext connection) => $"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}"; private static UserDetails DeserializerUser(string userJson) => diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs index 899a3fc625..1cbd68c282 100644 --- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs +++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs @@ -20,13 +20,13 @@ namespace SocialWeather _formatterResolver = formatterResolver; } - public void OnConnectedAsync(Connection connection) + public void OnConnectedAsync(ConnectionContext connection) { connection.Metadata[ConnectionMetadataNames.Format] = "json"; _connectionList.Add(connection); } - public void OnDisconnectedAsync(Connection connection) + public void OnDisconnectedAsync(ConnectionContext connection) { _connectionList.Remove(connection); } @@ -64,7 +64,7 @@ namespace SocialWeather throw new NotImplementedException(); } - public void AddGroupAsync(Connection connection, string groupName) + public void AddGroupAsync(ConnectionContext connection, string groupName) { var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet()); lock (groups) @@ -73,7 +73,7 @@ namespace SocialWeather } } - public void RemoveGroupAsync(Connection connection, string groupName) + public void RemoveGroupAsync(ConnectionContext connection, string groupName) { var groups = connection.Metadata.Get>("groups"); if (groups != null) diff --git a/samples/SocialWeather/SocialWeatherEndPoint.cs b/samples/SocialWeather/SocialWeatherEndPoint.cs index a78141fa7e..2e9a0fb8ee 100644 --- a/samples/SocialWeather/SocialWeatherEndPoint.cs +++ b/samples/SocialWeather/SocialWeatherEndPoint.cs @@ -23,14 +23,14 @@ namespace SocialWeather _logger = logger; } - public async override Task OnConnectedAsync(Connection connection) + public async override Task OnConnectedAsync(ConnectionContext connection) { _lifetimeManager.OnConnectedAsync(connection); await ProcessRequests(connection); _lifetimeManager.OnDisconnectedAsync(connection); } - public async Task ProcessRequests(Connection connection) + public async Task ProcessRequests(ConnectionContext connection) { var formatter = _formatterResolver.GetFormatter( connection.Metadata.Get("formatType")); diff --git a/samples/SocialWeather/Startup.cs b/samples/SocialWeather/Startup.cs index a950ef629f..fa2d62bb43 100644 --- a/samples/SocialWeather/Startup.cs +++ b/samples/SocialWeather/Startup.cs @@ -32,7 +32,7 @@ namespace SocialWeather app.UseDeveloperExceptionPage(); } - app.UseSockets(o => { o.MapEndpoint("weather"); }); + app.UseSockets(o => { o.MapEndPoint("weather"); }); app.UseFileServer(); var formatterResolver = app.ApplicationServices.GetRequiredService(); diff --git a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs index 3f68fa8f7d..f6ed46097e 100644 --- a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs +++ b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs @@ -13,7 +13,7 @@ namespace SocketsSample.EndPoints { public ConnectionList Connections { get; } = new ConnectionList(); - public override async Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(ConnectionContext connection) { Connections.Add(connection); diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index 034284f9b1..1acae62ef4 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -53,7 +53,7 @@ namespace SocketsSample app.UseSockets(routes => { - routes.MapEndpoint("chat"); + routes.MapEndPoint("chat"); }); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index cb847095a6..64e3aa6bbb 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -128,7 +128,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis await _bus.PublishAsync(channel, payload); } - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet()); var connectionTask = Task.CompletedTask; @@ -173,7 +173,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return Task.WhenAll(connectionTask, userTask); } - public override Task OnDisconnectedAsync(Connection connection) + public override Task OnDisconnectedAsync(ConnectionContext connection) { _connections.Remove(connection); @@ -204,7 +204,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return Task.WhenAll(tasks); } - public override async Task AddGroupAsync(Connection connection, string groupName) + public override async Task AddGroupAsync(ConnectionContext connection, string groupName) { var groupChannel = typeof(THub).FullName + ".group." + groupName; @@ -255,7 +255,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - public override async Task RemoveGroupAsync(Connection connection, string groupName) + public override async Task RemoveGroupAsync(ConnectionContext connection, string groupName) { var groupChannel = typeof(THub).FullName + ".group." + groupName; @@ -297,7 +297,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis _redisServerConnection.Dispose(); } - private async Task WriteAsync(Connection connection, HubMessage hubMessage) + private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage) { var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); var data = await protocol.WriteToArrayAsync(hubMessage); diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index c8b83fe4a2..5105d4ac6c 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.SignalR private long _nextInvocationId = 0; private readonly ConnectionList _connections = new ConnectionList(); - public override Task AddGroupAsync(Connection connection, string groupName) + public override Task AddGroupAsync(ConnectionContext connection, string groupName) { var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet()); @@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - public override Task RemoveGroupAsync(Connection connection, string groupName) + public override Task RemoveGroupAsync(ConnectionContext connection, string groupName) { var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); @@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.SignalR return InvokeAllWhere(methodName, args, c => true); } - private Task InvokeAllWhere(string methodName, object[] args, Func include) + private Task InvokeAllWhere(string methodName, object[] args, Func include) { var tasks = new List(_connections.Count); var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); @@ -94,19 +94,19 @@ namespace Microsoft.AspNetCore.SignalR }); } - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { _connections.Add(connection); return Task.CompletedTask; } - public override Task OnDisconnectedAsync(Connection connection) + public override Task OnDisconnectedAsync(ConnectionContext connection) { _connections.Remove(connection); return Task.CompletedTask; } - private async Task WriteAsync(Connection connection, HubMessage hubMessage) + private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage) { var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); var payload = await protocol.WriteToArrayAsync(hubMessage); diff --git a/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs b/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs index 50df41dc6a..f5e7b50cfa 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs @@ -8,12 +8,12 @@ namespace Microsoft.AspNetCore.SignalR { public class HubCallerContext { - public HubCallerContext(Connection connection) + public HubCallerContext(ConnectionContext connection) { Connection = connection; } - public Connection Connection { get; } + public ConnectionContext Connection { get; } public ClaimsPrincipal User => Connection.User; diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 82af05346f..65c3f14e7f 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -22,10 +22,9 @@ namespace Microsoft.AspNetCore.SignalR public HubEndPoint(HubLifetimeManager lifetimeManager, IHubProtocolResolver protocolResolver, IHubContext hubContext, - IOptions>> endPointOptions, ILogger> logger, IServiceScopeFactory serviceScopeFactory) - : base(lifetimeManager, protocolResolver, hubContext, endPointOptions, logger, serviceScopeFactory) + : base(lifetimeManager, protocolResolver, hubContext, logger, serviceScopeFactory) { } } @@ -43,7 +42,6 @@ namespace Microsoft.AspNetCore.SignalR public HubEndPoint(HubLifetimeManager lifetimeManager, IHubProtocolResolver protocolResolver, IHubContext hubContext, - IOptions>> endPointOptions, ILogger> logger, IServiceScopeFactory serviceScopeFactory) { @@ -56,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR DiscoverHubMethods(); } - public override async Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(ConnectionContext connection) { try { @@ -74,7 +72,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task RunHubAsync(Connection connection) + private async Task RunHubAsync(ConnectionContext connection) { await HubOnConnectedAsync(connection); @@ -92,7 +90,7 @@ namespace Microsoft.AspNetCore.SignalR await HubOnDisconnectedAsync(connection, null); } - private async Task HubOnConnectedAsync(Connection connection) + private async Task HubOnConnectedAsync(ConnectionContext connection) { try { @@ -118,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task HubOnDisconnectedAsync(Connection connection, Exception exception) + private async Task HubOnDisconnectedAsync(ConnectionContext connection, Exception exception) { try { @@ -144,7 +142,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task DispatchMessagesAsync(Connection connection) + private async Task DispatchMessagesAsync(ConnectionContext connection) { // We use these for error handling. Since we dispatch multiple hub invocations // in parallel, we need a way to communicate failure back to the main processing loop. The @@ -190,7 +188,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task ProcessInvocation(Connection connection, + private async Task ProcessInvocation(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage, CancellationTokenSource dispatcherCancellation, @@ -212,7 +210,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task Execute(Connection connection, IHubProtocol protocol, InvocationMessage invocationMessage) + private async Task Execute(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage) { HubMethodDescriptor descriptor; if (!_methods.TryGetValue(invocationMessage.Target, out descriptor)) @@ -228,7 +226,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task SendMessageAsync(Connection connection, IHubProtocol protocol, HubMessage hubMessage) + private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage) { var payload = await protocol.WriteToArrayAsync(hubMessage); var message = new Message(payload, protocol.MessageType, endOfMessage: true); @@ -246,7 +244,7 @@ namespace Microsoft.AspNetCore.SignalR throw new OperationCanceledException("Outbound channel was closed while trying to write hub message"); } - private async Task Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationMessage invocationMessage) + private async Task Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, InvocationMessage invocationMessage) { var methodExecutor = descriptor.MethodExecutor; @@ -295,7 +293,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private void InitializeHub(THub hub, Connection connection) + private void InitializeHub(THub hub, ConnectionContext connection) { hub.Clients = _hubContext.Clients; hub.Context = new HubCallerContext(connection); diff --git a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs index 51a2b4d4f7..2e6f34a01d 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs @@ -8,9 +8,9 @@ namespace Microsoft.AspNetCore.SignalR { public abstract class HubLifetimeManager { - public abstract Task OnConnectedAsync(Connection connection); + public abstract Task OnConnectedAsync(ConnectionContext connection); - public abstract Task OnDisconnectedAsync(Connection connection); + public abstract Task OnDisconnectedAsync(ConnectionContext connection); public abstract Task InvokeAllAsync(string methodName, object[] args); @@ -20,9 +20,9 @@ namespace Microsoft.AspNetCore.SignalR public abstract Task InvokeUserAsync(string userId, string methodName, object[] args); - public abstract Task AddGroupAsync(Connection connection, string groupName); + public abstract Task AddGroupAsync(ConnectionContext connection, string groupName); - public abstract Task RemoveGroupAsync(Connection connection, string groupName); + public abstract Task RemoveGroupAsync(ConnectionContext connection, string groupName); } } diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs index 2d06b8bcee..c4e8957522 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs @@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public class DefaultHubProtocolResolver : IHubProtocolResolver { - public IHubProtocol GetProtocol(Connection connection) + public IHubProtocol GetProtocol(ConnectionContext connection) { // TODO: Allow customization of this serializer! return new JsonHubProtocol(new JsonSerializer()); diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs index c9627d0d59..9d2c1f1bb8 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/IHubProtocolResolver.cs @@ -8,6 +8,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public interface IHubProtocolResolver { - IHubProtocol GetProtocol(Connection connection); + IHubProtocol GetProtocol(ConnectionContext connection); } } diff --git a/src/Microsoft.AspNetCore.SignalR/Proxies.cs b/src/Microsoft.AspNetCore.SignalR/Proxies.cs index dfed4a80ff..f0b8423580 100644 --- a/src/Microsoft.AspNetCore.SignalR/Proxies.cs +++ b/src/Microsoft.AspNetCore.SignalR/Proxies.cs @@ -75,10 +75,10 @@ namespace Microsoft.AspNetCore.SignalR public class GroupManager : IGroupManager { - private readonly Connection _connection; + private readonly ConnectionContext _connection; private readonly HubLifetimeManager _lifetimeManager; - public GroupManager(Connection connection, HubLifetimeManager lifetimeManager) + public GroupManager(ConnectionContext connection, HubLifetimeManager lifetimeManager) { _connection = connection; _lifetimeManager = lifetimeManager; diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRAppBuilderExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRAppBuilderExtensions.cs index 1ded6ec809..00cca9cdbd 100644 --- a/src/Microsoft.AspNetCore.SignalR/SignalRAppBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRAppBuilderExtensions.cs @@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Builder public void MapHub(string path) where THub : Hub { - _routes.MapEndpoint>(path); + _routes.MapEndPoint>(path); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs new file mode 100644 index 0000000000..7300980331 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs @@ -0,0 +1,33 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Security.Claims; +using System.Text; +using System.Threading; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Sockets +{ + public abstract class ConnectionContext : IDisposable + { + public abstract string ConnectionId { get; } + + public abstract IFeatureCollection Features { get; } + + public abstract ClaimsPrincipal User { get; set; } + + // REVIEW: Should this be changed to items + public abstract ConnectionMetadata Metadata { get; } + + // TEMPORARY + public abstract IChannelConnection Transport { get; set; } + + // TEMPORARY + public void Dispose() + { + Transport?.Dispose(); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs similarity index 100% rename from src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs rename to src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs new file mode 100644 index 0000000000..b353216d37 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/DefaultConnectionContext.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Security.Claims; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Sockets +{ + public class DefaultConnectionContext : ConnectionContext + { + public DefaultConnectionContext(string id, IChannelConnection transport) + { + Transport = transport; + ConnectionId = id; + } + + public override string ConnectionId { get; } + + public override IFeatureCollection Features { get; } = new FeatureCollection(); + + public override ClaimsPrincipal User { get; set; } + + public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); + + public override IChannelConnection Transport { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/EndPoint.cs similarity index 90% rename from src/Microsoft.AspNetCore.Sockets/EndPoint.cs rename to src/Microsoft.AspNetCore.Sockets.Abstractions/EndPoint.cs index a975839e34..9108e43d83 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/EndPoint.cs @@ -14,8 +14,8 @@ namespace Microsoft.AspNetCore.Sockets /// /// Called when a new connection is accepted to the endpoint /// - /// The new + /// The new /// A that represents the connection lifetime. When the task completes, the connection is complete. - public abstract Task OnConnectedAsync(Connection connection); + public abstract Task OnConnectedAsync(ConnectionContext connection); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ISocketBuilder.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ISocketBuilder.cs new file mode 100644 index 0000000000..f7d2f98513 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ISocketBuilder.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Sockets +{ + public interface ISocketBuilder + { + IServiceProvider ApplicationServices { get; } + + ISocketBuilder Use(Func middleware); + + SocketDelegate Build(); + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj new file mode 100644 index 0000000000..ea17b5e2ea --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Microsoft.AspNetCore.Sockets.Abstractions.csproj @@ -0,0 +1,23 @@ + + + + + + Components for providing real-time bi-directional communication across the Web. + netcoreapp2.0 + $(NoWarn);CS1591 + true + aspnetcore;signalr + false + + + + + + + + + + + + diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilder.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilder.cs new file mode 100644 index 0000000000..8dfc21b848 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilder.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets +{ + public class SocketBuilder : ISocketBuilder + { + private readonly IList> _components = new List>(); + + public IServiceProvider ApplicationServices { get; } + + public SocketBuilder(IServiceProvider applicationServices) + { + ApplicationServices = applicationServices; + } + + public ISocketBuilder Use(Func middleware) + { + _components.Add(middleware); + return this; + } + + public SocketDelegate Build() + { + SocketDelegate app = features => + { + return Task.CompletedTask; + }; + + foreach (var component in _components.Reverse()) + { + app = component(app); + } + + return app; + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilderExtensions.cs new file mode 100644 index 0000000000..82635ec4da --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketBuilderExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets +{ + public static class SocketBuilderExtensions + { + public static ISocketBuilder Use(this ISocketBuilder socketBuilder, Func, Task> middleware) + { + return socketBuilder.Use(next => + { + return context => + { + Func simpleNext = () => next(context); + return middleware(context, simpleNext); + }; + }); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketDelegate.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketDelegate.cs new file mode 100644 index 0000000000..45af07cd77 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/SocketDelegate.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets +{ + public delegate Task SocketDelegate(ConnectionContext connection); +} diff --git a/src/Microsoft.AspNetCore.Sockets/Connection.cs b/src/Microsoft.AspNetCore.Sockets/Connection.cs deleted file mode 100644 index 1508a95eb4..0000000000 --- a/src/Microsoft.AspNetCore.Sockets/Connection.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Security.Claims; - -namespace Microsoft.AspNetCore.Sockets -{ - public class Connection : IDisposable - { - public string ConnectionId { get; } - - public ClaimsPrincipal User { get; set; } - public ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); - - public IChannelConnection Transport { get; } - - public Connection(string id, IChannelConnection transport) - { - Transport = transport; - ConnectionId = id; - } - - public void Dispose() - { - Transport.Dispose(); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs index 5f5efaadb8..323c0e442e 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs @@ -8,15 +8,15 @@ using System.Collections.Generic; namespace Microsoft.AspNetCore.Sockets { - public class ConnectionList : IReadOnlyCollection + public class ConnectionList : IReadOnlyCollection { - private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); + private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); - public Connection this[string connectionId] + public ConnectionContext this[string connectionId] { get { - Connection connection; + ConnectionContext connection; if (_connections.TryGetValue(connectionId, out connection)) { return connection; @@ -27,18 +27,18 @@ namespace Microsoft.AspNetCore.Sockets public int Count => _connections.Count; - public void Add(Connection connection) + public void Add(ConnectionContext connection) { _connections.TryAdd(connection.ConnectionId, connection); } - public void Remove(Connection connection) + public void Remove(ConnectionContext connection) { - Connection dummy; + ConnectionContext dummy; _connections.TryRemove(connection.ConnectionId, out dummy); } - public IEnumerator GetEnumerator() + public IEnumerator GetEnumerator() { foreach (var item in _connections) { diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 20b3d1c866..201fb02c8c 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -57,7 +57,7 @@ namespace Microsoft.AspNetCore.Sockets var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); var state = new ConnectionState( - new Connection(id, applicationSide), + new DefaultConnectionContext(id, applicationSide), transportSide); _connections.TryAdd(id, state); diff --git a/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs index c09c6118fd..dfe59a108c 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/EndPointDependencyInjectionExtensions.cs @@ -14,15 +14,5 @@ namespace Microsoft.Extensions.DependencyInjection return services; } - - public static IServiceCollection AddEndPoint(this IServiceCollection services, - Action> setupAction) where TEndPoint : EndPoint - { - services.AddEndPoint(); - - services.Configure(setupAction); - - return services; - } } } diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index bd38c6c0ee..55ce652fb7 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -14,9 +14,7 @@ using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.AspNetCore.Sockets.Transports; using Microsoft.AspNetCore.WebSockets.Internal; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Sockets @@ -34,10 +32,8 @@ namespace Microsoft.AspNetCore.Sockets _logger = _loggerFactory.CreateLogger(); } - public async Task ExecuteAsync(HttpContext context) where TEndPoint : EndPoint + public async Task ExecuteAsync(HttpContext context, HttpSocketOptions options, SocketDelegate socketDelegate) { - var options = context.RequestServices.GetRequiredService>>().Value; - // TODO: Authorize attribute on EndPoint if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationPolicyNames)) { return; @@ -56,10 +52,7 @@ namespace Microsoft.AspNetCore.Sockets else if (HttpMethods.IsGet(context.Request.Method)) { // GET /{path} - - // Get the end point mapped to this http connection - var endpoint = (EndPoint)context.RequestServices.GetRequiredService(); - await ExecuteEndpointAsync(context, endpoint, options); + await ExecuteEndpointAsync(context, socketDelegate, options); } else { @@ -67,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets } } - private async Task ExecuteEndpointAsync(HttpContext context, EndPoint endpoint, EndPointOptions options) where TEndPoint : EndPoint + private async Task ExecuteEndpointAsync(HttpContext context, SocketDelegate socketDelegate, HttpSocketOptions options) { var supportedTransports = options.Transports; @@ -94,7 +87,7 @@ namespace Microsoft.AspNetCore.Sockets // We only need to provide the Input channel since writing to the application is handled through /send. var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory); - await DoPersistentConnection(endpoint, sse, context, state); + await DoPersistentConnection(socketDelegate, sse, context, state); } else if (context.Features.Get()?.IsWebSocketRequest == true) { @@ -114,7 +107,7 @@ namespace Microsoft.AspNetCore.Sockets var ws = new WebSocketsTransport(options.WebSockets, state.Application, _loggerFactory); - await DoPersistentConnection(endpoint, ws, context, state); + await DoPersistentConnection(socketDelegate, ws, context, state); } else { @@ -183,7 +176,7 @@ namespace Microsoft.AspNetCore.Sockets state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling; - state.ApplicationTask = ExecuteApplication(endpoint, state.Connection); + state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection); } else { @@ -275,7 +268,7 @@ namespace Microsoft.AspNetCore.Sockets return state; } - private async Task DoPersistentConnection(EndPoint endpoint, + private async Task DoPersistentConnection(SocketDelegate socketDelegate, IHttpTransport transport, HttpContext context, ConnectionState state) @@ -310,7 +303,7 @@ namespace Microsoft.AspNetCore.Sockets state.RequestId = context.TraceIdentifier; // Call into the end point passing the connection - state.ApplicationTask = ExecuteApplication(endpoint, state.Connection); + state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection); // Start the transport state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted); @@ -326,17 +319,17 @@ namespace Microsoft.AspNetCore.Sockets await _manager.DisposeAndRemoveAsync(state); } - private async Task ExecuteApplication(EndPoint endpoint, Connection connection) + private async Task ExecuteApplication(SocketDelegate socketDelegate, ConnectionContext connection) { // Jump onto the thread pool thread so blocking user code doesn't block the setup of the // connection and transport await AwaitableThreadPool.Yield(); // Running this in an async method turns sync exceptions into async ones - await endpoint.OnConnectedAsync(connection); + await socketDelegate(connection); } - private Task ProcessNegotiate(HttpContext context, EndPointOptions options) where TEndPoint : EndPoint + private Task ProcessNegotiate(HttpContext context, HttpSocketOptions options) { // Set the allowed headers for this resource context.Response.Headers.AppendCommaSeparatedValues("Allow", "GET", "POST", "OPTIONS"); diff --git a/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs index a9f1b87888..732b0050af 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpDispatcherAppBuilderExtensions.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; @@ -36,9 +35,23 @@ namespace Microsoft.AspNetCore.Builder _dispatcher = dispatcher; } - public void MapEndpoint(string path) where TEndPoint : EndPoint + public void MapSocket(string path, Action socketConfig) => + MapSocket(path, new HttpSocketOptions(), socketConfig); + + public void MapSocket(string path, HttpSocketOptions options, Action socketConfig) { - _routes.MapRoute(path, _dispatcher.ExecuteAsync); + var socketBuilder = new SocketBuilder(_routes.ServiceProvider); + socketConfig(socketBuilder); + var socket = socketBuilder.Build(); + _routes.MapRoute(path, c => _dispatcher.ExecuteAsync(c, options, socket)); + } + + public void MapEndPoint(string path) where TEndPoint : EndPoint + { + MapSocket(path, builder => + { + builder.UseEndPoint(); + }); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/HttpSocketBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets/HttpSocketBuilderExtensions.cs new file mode 100644 index 0000000000..9aa946ebe3 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/HttpSocketBuilderExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Sockets +{ + public static class HttpSocketBuilderExtensions + { + public static ISocketBuilder UseEndPoint(this ISocketBuilder socketBuilder) where TEndPoint : EndPoint + { + // This is a terminal middleware, so there's no need to use the 'next' parameter + return socketBuilder.Use((connection, _) => + { + var endpoint = socketBuilder.ApplicationServices.GetRequiredService(); + return endpoint.OnConnectedAsync(connection); + }); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs b/src/Microsoft.AspNetCore.Sockets/HttpSocketOptions.cs similarity index 87% rename from src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs rename to src/Microsoft.AspNetCore.Sockets/HttpSocketOptions.cs index ec939ff762..5b676f648c 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPointOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpSocketOptions.cs @@ -5,7 +5,7 @@ using System.Collections.Generic; namespace Microsoft.AspNetCore.Sockets { - public class EndPointOptions where TEndPoint : EndPoint + public class HttpSocketOptions { public IList AuthorizationPolicyNames { get; } = new List(); diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index c73423b87f..7c4db1a773 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal // on the same task private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); - public Connection Connection { get; set; } + public ConnectionContext Connection { get; set; } public IChannelConnection Application { get; } public CancellationTokenSource Cancellation { get; set; } @@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal public DateTime LastSeenUtc { get; set; } public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive; - public ConnectionState(Connection connection, IChannelConnection application) + public ConnectionState(ConnectionContext connection, IChannelConnection application) { Connection = connection; Application = application; diff --git a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj index 9f37687e10..4cceabdecd 100644 --- a/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj +++ b/src/Microsoft.AspNetCore.Sockets/Microsoft.AspNetCore.Sockets.csproj @@ -12,6 +12,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs index 8b733dd154..2575514352 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EchoEndPoint.cs @@ -8,7 +8,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public class EchoEndPoint : EndPoint { - public async override Task OnConnectedAsync(Connection connection) + public async override Task OnConnectedAsync(ConnectionContext connection) { await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync()); } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 91b51ba380..bf8536a2e8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var mockLifetimeManager = new Mock>(); mockLifetimeManager - .Setup(m => m.OnConnectedAsync(It.IsAny())) + .Setup(m => m.OnConnectedAsync(It.IsAny())) .Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed.")); var mockHubActivator = new Mock>(); @@ -60,8 +60,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests client.Dispose(); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); // No hubs should be created since the connection is terminated mockHubActivator.Verify(m => m.Create(), Times.Never); mockHubActivator.Verify(m => m.Release(It.IsAny()), Times.Never); @@ -87,8 +87,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var exception = await Assert.ThrowsAsync(async () => await endPointTask); Assert.Equal("Hub OnConnected failed.", exception.Message); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); } } @@ -111,8 +111,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var exception = await Assert.ThrowsAsync(async () => await endPointTask); Assert.Equal("Hub OnDisconnected failed.", exception.Message); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny()), Times.Once); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/ServerFixture.cs b/test/Microsoft.AspNetCore.SignalR.Tests/ServerFixture.cs index 60dd69183d..5fa3d2d629 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/ServerFixture.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/ServerFixture.cs @@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public void Configure(IApplicationBuilder app, IHostingEnvironment env) { - app.UseSockets(options => options.MapEndpoint("echo")); + app.UseSockets(options => options.MapEndPoint("echo")); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 1139ff5295..e9edcdc268 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests private IHubProtocol _protocol; private CancellationTokenSource _cts; - public Connection Connection; + public ConnectionContext Connection; public IChannelConnection Application { get; } public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; @@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests Application = ChannelConnection.Create(input: applicationToTransport, output: transportToApplication); var transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); - Connection = new Connection(Guid.NewGuid().ToString(), transport); + Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport); Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) })); Connection.Metadata["ConnectedTask"] = new TaskCompletionSource(); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 4176b8f32a..9e635bf156 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -37,12 +37,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddEndPoint(); services.AddOptions(); - context.RequestServices = services.BuildServiceProvider(); var ms = new MemoryStream(); context.Request.Path = "/foo"; context.Request.Method = "OPTIONS"; context.Response.Body = ms; - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); var id = Encoding.UTF8.GetString(ms.ToArray()); @@ -68,7 +70,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddEndPoint(); services.AddOptions(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); @@ -77,7 +78,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Query = qs; SetTransport(context, transportType); - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); await strm.FlushAsync(); @@ -100,7 +104,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddEndPoint(); services.AddOptions(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "POST"; var values = new Dictionary(); @@ -108,7 +111,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests var qs = new QueryCollection(values); context.Request.Query = qs; - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); await strm.FlushAsync(); @@ -130,13 +136,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddOptions(); services.AddEndPoint(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; SetTransport(context, transportType); - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); await strm.FlushAsync(); @@ -156,11 +164,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddOptions(); services.AddEndPoint(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "POST"; - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); await strm.FlushAsync(); @@ -180,14 +190,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddOptions(); services.AddEndPoint(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "POST"; context.Request.QueryString = new QueryString($"?id={connectionState.Connection.ConnectionId}"); context.Request.ContentType = "text/plain"; context.Response.Body = strm; - await dispatcher.ExecuteAsync(context); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); await strm.FlushAsync(); @@ -237,10 +249,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, TransportType.ServerSentEvents); - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); @@ -256,11 +273,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var state = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, TransportType.ServerSentEvents); - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); @@ -276,10 +297,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests var state = manager.CreateConnection(); var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + var context = MakeRequest("/foo", state); - var context = MakeRequest("/foo", state); - - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); @@ -296,9 +321,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); @@ -315,10 +345,18 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, TransportType.WebSockets); - var task = dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1); + + var task = dispatcher.ExecuteAsync(context, options, app); await task.OrTimeout(); } @@ -333,15 +371,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context1 = MakeRequest("/foo", state); - var context2 = MakeRequest("/foo", state); + var context1 = MakeRequest("/foo", state); + var context2 = MakeRequest("/foo", state); SetTransport(context1, transportType); SetTransport(context2, transportType); - var request1 = dispatcher.ExecuteAsync(context1); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var request1 = dispatcher.ExecuteAsync(context1, options, app); - await dispatcher.ExecuteAsync(context2); + await dispatcher.ExecuteAsync(context2, options, app); Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode); @@ -369,11 +413,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context1 = MakeRequest("/foo", state); - var context2 = MakeRequest("/foo", state); + var context1 = MakeRequest("/foo", state); + var context2 = MakeRequest("/foo", state); - var request1 = dispatcher.ExecuteAsync(context1); - var request2 = dispatcher.ExecuteAsync(context2); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var request1 = dispatcher.ExecuteAsync(context1, options, app); + var request2 = dispatcher.ExecuteAsync(context2, options, app); await request1; @@ -398,10 +448,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, transportType); - await dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + await dispatcher.ExecuteAsync(context, options, app); + Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); } @@ -414,9 +471,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); - var task = dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -439,10 +502,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); SetTransport(context, TransportType.ServerSentEvents); - var task = dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -465,9 +534,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state); + var context = MakeRequest("/foo", state); - var task = dispatcher.ExecuteAsync(context); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + var task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -490,10 +565,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context1 = MakeRequest("/foo", state); - var task1 = dispatcher.ExecuteAsync(context1); - var context2 = MakeRequest("/foo", state); - var task2 = dispatcher.ExecuteAsync(context2); + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + + var context1 = MakeRequest("/foo", state); + var task1 = dispatcher.ExecuteAsync(context1, options, app); + var context2 = MakeRequest("/foo", state); + var task2 = dispatcher.ExecuteAsync(context2, options, app); // Task 1 should finish when request 2 arrives await task1.OrTimeout(); @@ -555,19 +637,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); + o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -576,8 +655,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + // would hang if EndPoint was running - await dispatcher.ExecuteAsync(context).OrTimeout(); + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); } @@ -591,22 +676,19 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => + o.AddPolicy("test", policy => { policy.RequireClaim(ClaimTypes.NameIdentifier); }); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -616,10 +698,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + // "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - var endPointTask = dispatcher.ExecuteAsync(context); + var endPointTask = dispatcher.ExecuteAsync(context, options, app); await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); await endPointTask.OrTimeout(); @@ -637,21 +725,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - options.AuthorizationPolicyNames.Add("secondPolicy"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); - options.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress)); + o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); + o.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress)); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -661,18 +745,25 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + options.AuthorizationPolicyNames.Add("secondPolicy"); + // partialy "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); // would hang if EndPoint was running - await dispatcher.ExecuteAsync(context).OrTimeout(); + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); // fully "authorize" user context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") })); - var endPointTask = dispatcher.ExecuteAsync(context); + var endPointTask = dispatcher.ExecuteAsync(context, options, app); await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); await endPointTask.OrTimeout(); @@ -689,23 +780,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => + o.AddPolicy("test", policy => { policy.RequireClaim(ClaimTypes.NameIdentifier); policy.AddAuthenticationSchemes("Default"); }); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -715,10 +803,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + // "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - var endPointTask = dispatcher.ExecuteAsync(context); + var endPointTask = dispatcher.ExecuteAsync(context, options, app); await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); await endPointTask.OrTimeout(); @@ -736,23 +830,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => + services.AddEndPoint(); + services.AddAuthorization(o => { - options.AuthorizationPolicyNames.Add("test"); - }); - services.AddAuthorization(options => - { - options.AddPolicy("test", policy => + o.AddPolicy("test", policy => { policy.RequireClaim(ClaimTypes.NameIdentifier); policy.AddAuthenticationSchemes("Default"); }); }); services.AddLogging(); - - context.RequestServices = services.BuildServiceProvider(); + var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; + context.RequestServices = sp; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); @@ -762,11 +853,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false); context.Features.Set(authFeature); + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + // "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); // would block if EndPoint was executed - await dispatcher.ExecuteAsync(context).OrTimeout(); + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); } @@ -829,20 +926,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Response.Body = strm; var services = new ServiceCollection(); services.AddOptions(); - services.AddEndPoint(options => - { - options.Transports = supportedTransports; - }); - + services.AddEndPoint(); SetTransport(context, transportType); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; - await dispatcher.ExecuteAsync(context); + + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.Transports = supportedTransports; + + await dispatcher.ExecuteAsync(context, options, app); Assert.Equal(status, context.Response.StatusCode); await strm.FlushAsync(); @@ -861,10 +960,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = MakeRequest("/foo", state, format); + var context = MakeRequest("/foo", state, format); context.Request.Method = "POST"; context.Request.ContentType = contentType; - var endPoint = context.RequestServices.GetRequiredService(); + + var services = new ServiceCollection(); + services.AddEndPoint(); + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); var buffer = contentType == BinaryContentType ? Convert.FromBase64String(encoded) : @@ -872,7 +976,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var messages = new List(); using (context.Request.Body = new MemoryStream(buffer, writable: false)) { - await dispatcher.ExecuteAsync(context).OrTimeout(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app).OrTimeout(); } while (state.Connection.Transport.Input.TryRead(out var message)) @@ -883,18 +987,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests return messages; } - private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null) where TEndPoint : EndPoint + private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null) { var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddEndPoint(o => - { - // Make the close timeout less than the default for OrTimeout() test helper - o.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1); - }); - - services.AddOptions(); - context.RequestServices = services.BuildServiceProvider(); context.Request.Path = path; context.Request.Method = "GET"; var values = new Dictionary(); @@ -942,7 +1037,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class NerverEndingEndPoint : EndPoint { - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { var tcs = new TaskCompletionSource(); return tcs.Task; @@ -951,7 +1046,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class BlockingEndPoint : EndPoint { - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { connection.Transport.Input.WaitToReadAsync().Wait(); return Task.CompletedTask; @@ -960,7 +1055,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class SynchronusExceptionEndPoint : EndPoint { - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { throw new InvalidOperationException(); } @@ -968,7 +1063,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class ImmediatelyCompleteEndPoint : EndPoint { - public override Task OnConnectedAsync(Connection connection) + public override Task OnConnectedAsync(ConnectionContext connection) { return Task.CompletedTask; } @@ -976,7 +1071,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests public class TestEndPoint : EndPoint { - public override async Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(ConnectionContext connection) { while (await connection.Transport.Input.WaitToReadAsync()) {