Progress towards splitting the layers (#473)
* Progress towards splitting the layers - This is based on the work anurse did in anurse/endpoint-middleware-spike to introduce a connection middleware pipeline that mimics much of our http pipeline. The intent is that this layer will be generic enough to build both SignalR and Kestrel on top of but we're not there yet. This change makes incremental progress towards splitting apart sockets and http so that we can add the tcp transport without breaking everything all at once. - Created Microsoft.AspNetCore.Sockets.Abstractions where the primitives for sockets live. That includes, ConnectionContext (formerly Connection), EndPoint, ISocketBuilder, SocketDelegate, etc. - ConnectionContext isn't in it's final form as yet, it still very closely mirrors the original Connection object we had so that tests continue to pass. - The HttpConnectionDispatcher doesn't know about EndPoint anymore, it just cares about invoking the SocketDelegate. - EndPointOptions has been removed as part of this change as it coupled http specific configuration to the end point type. There's a new HttpSocketOptions that needs to be passed into MapSocket calls. - Updated the tests to deal with the API changes.
This commit is contained in:
parent
e68a1b294f
commit
9d9a52119e
|
|
@ -1,6 +1,6 @@
|
|||
Microsoft Visual Studio Solution File, Format Version 12.00
|
||||
# Visual Studio 15
|
||||
VisualStudioVersion = 15.0.26411.1
|
||||
VisualStudioVersion = 15.0.26510.0
|
||||
MinimumVisualStudioVersion = 10.0.40219.1
|
||||
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
|
||||
EndProject
|
||||
|
|
@ -76,6 +76,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.Signal
|
|||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Common.Tests", "test\Microsoft.AspNetCore.SignalR.Common.Tests\Microsoft.AspNetCore.SignalR.Common.Tests.csproj", "{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}"
|
||||
EndProject
|
||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Sockets.Abstractions", "src\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj", "{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}"
|
||||
EndProject
|
||||
Global
|
||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution
|
||||
Debug|Any CPU = Debug|Any CPU
|
||||
|
|
@ -186,6 +188,10 @@ Global
|
|||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
|
||||
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Debug|Any CPU.Build.0 = Debug|Any CPU
|
||||
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Release|Any CPU.ActiveCfg = Release|Any CPU
|
||||
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Release|Any CPU.Build.0 = Release|Any CPU
|
||||
EndGlobalSection
|
||||
GlobalSection(SolutionProperties) = preSolution
|
||||
HideSolutionNode = FALSE
|
||||
|
|
@ -218,5 +224,6 @@ Global
|
|||
{6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
|
||||
{96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
|
||||
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
|
||||
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8} = {DA69F624-5398-4884-87E4-B816698CDE65}
|
||||
EndGlobalSection
|
||||
EndGlobal
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server
|
|||
{
|
||||
public class EchoEndPoint : EndPoint
|
||||
{
|
||||
public async override Task OnConnectedAsync(Connection connection)
|
||||
public async override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server
|
|||
}
|
||||
|
||||
app.UseFileServer();
|
||||
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("echo"));
|
||||
app.UseSockets(options => options.MapEndPoint<EchoEndPoint>("echo"));
|
||||
app.UseSignalR(routes =>
|
||||
{
|
||||
routes.MapHub<TestHub>("testhub");
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ namespace ChatSample
|
|||
public interface IUserTracker<out THub>
|
||||
{
|
||||
Task<IEnumerable<UserDetails>> UsersOnline();
|
||||
Task AddUser(Connection connection, UserDetails userDetails);
|
||||
Task RemoveUser(Connection connection);
|
||||
Task AddUser(ConnectionContext connection, UserDetails userDetails);
|
||||
Task RemoveUser(ConnectionContext connection);
|
||||
|
||||
event Action<UserDetails[]> UsersJoined;
|
||||
event Action<UserDetails[]> UsersLeft;
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ namespace ChatSample
|
|||
{
|
||||
public class InMemoryUserTracker<THub> : IUserTracker<THub>
|
||||
{
|
||||
private readonly ConcurrentDictionary<Connection, UserDetails> _usersOnline
|
||||
= new ConcurrentDictionary<Connection, UserDetails>();
|
||||
private readonly ConcurrentDictionary<ConnectionContext, UserDetails> _usersOnline
|
||||
= new ConcurrentDictionary<ConnectionContext, UserDetails>();
|
||||
|
||||
public event Action<UserDetails[]> UsersJoined;
|
||||
public event Action<UserDetails[]> UsersLeft;
|
||||
|
|
@ -18,7 +18,7 @@ namespace ChatSample
|
|||
public Task<IEnumerable<UserDetails>> UsersOnline()
|
||||
=> Task.FromResult(_usersOnline.Values.AsEnumerable());
|
||||
|
||||
public Task AddUser(Connection connection, UserDetails userDetails)
|
||||
public Task AddUser(ConnectionContext connection, UserDetails userDetails)
|
||||
{
|
||||
_usersOnline.TryAdd(connection, userDetails);
|
||||
UsersJoined(new[] { userDetails });
|
||||
|
|
@ -26,7 +26,7 @@ namespace ChatSample
|
|||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task RemoveUser(Connection connection)
|
||||
public Task RemoveUser(ConnectionContext connection)
|
||||
{
|
||||
if (_usersOnline.TryRemove(connection, out var userDetails))
|
||||
{
|
||||
|
|
|
|||
|
|
@ -57,14 +57,14 @@ namespace ChatSample
|
|||
_wrappedHubLifetimeManager = serviceProvider.GetRequiredService<THubLifetimeManager>();
|
||||
}
|
||||
|
||||
public override async Task OnConnectedAsync(Connection connection)
|
||||
public override async Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
await _wrappedHubLifetimeManager.OnConnectedAsync(connection);
|
||||
_connections.Add(connection);
|
||||
await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name));
|
||||
}
|
||||
|
||||
public override async Task OnDisconnectedAsync(Connection connection)
|
||||
public override async Task OnDisconnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection);
|
||||
_connections.Remove(connection);
|
||||
|
|
@ -157,12 +157,12 @@ namespace ChatSample
|
|||
return _wrappedHubLifetimeManager.InvokeUserAsync(userId, methodName, args);
|
||||
}
|
||||
|
||||
public override Task AddGroupAsync(Connection connection, string groupName)
|
||||
public override Task AddGroupAsync(ConnectionContext connection, string groupName)
|
||||
{
|
||||
return _wrappedHubLifetimeManager.AddGroupAsync(connection, groupName);
|
||||
}
|
||||
|
||||
public override Task RemoveGroupAsync(Connection connection, string groupName)
|
||||
public override Task RemoveGroupAsync(ConnectionContext connection, string groupName)
|
||||
{
|
||||
return _wrappedHubLifetimeManager.RemoveGroupAsync(connection, groupName);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ namespace ChatSample
|
|||
}
|
||||
}
|
||||
|
||||
public async Task AddUser(Connection connection, UserDetails userDetails)
|
||||
public async Task AddUser(ConnectionContext connection, UserDetails userDetails)
|
||||
{
|
||||
var key = GetUserRedisKey(connection);
|
||||
var user = SerializeUser(connection);
|
||||
|
|
@ -156,7 +156,7 @@ namespace ChatSample
|
|||
}
|
||||
}
|
||||
|
||||
public async Task RemoveUser(Connection connection)
|
||||
public async Task RemoveUser(ConnectionContext connection)
|
||||
{
|
||||
await _userSyncSempaphore.WaitAsync();
|
||||
try
|
||||
|
|
@ -180,7 +180,7 @@ namespace ChatSample
|
|||
}
|
||||
}
|
||||
|
||||
private static string GetUserRedisKey(Connection connection) => $"user:{connection.ConnectionId}";
|
||||
private static string GetUserRedisKey(ConnectionContext connection) => $"user:{connection.ConnectionId}";
|
||||
|
||||
private static void Scan(object state)
|
||||
{
|
||||
|
|
@ -319,7 +319,7 @@ namespace ChatSample
|
|||
}
|
||||
}
|
||||
|
||||
private static string SerializeUser(Connection connection) =>
|
||||
private static string SerializeUser(ConnectionContext connection) =>
|
||||
$"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}";
|
||||
|
||||
private static UserDetails DeserializerUser(string userJson) =>
|
||||
|
|
|
|||
|
|
@ -20,13 +20,13 @@ namespace SocialWeather
|
|||
_formatterResolver = formatterResolver;
|
||||
}
|
||||
|
||||
public void OnConnectedAsync(Connection connection)
|
||||
public void OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
connection.Metadata[ConnectionMetadataNames.Format] = "json";
|
||||
_connectionList.Add(connection);
|
||||
}
|
||||
|
||||
public void OnDisconnectedAsync(Connection connection)
|
||||
public void OnDisconnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
_connectionList.Remove(connection);
|
||||
}
|
||||
|
|
@ -64,7 +64,7 @@ namespace SocialWeather
|
|||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public void AddGroupAsync(Connection connection, string groupName)
|
||||
public void AddGroupAsync(ConnectionContext connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
|
||||
lock (groups)
|
||||
|
|
@ -73,7 +73,7 @@ namespace SocialWeather
|
|||
}
|
||||
}
|
||||
|
||||
public void RemoveGroupAsync(Connection connection, string groupName)
|
||||
public void RemoveGroupAsync(ConnectionContext connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.Get<HashSet<string>>("groups");
|
||||
if (groups != null)
|
||||
|
|
|
|||
|
|
@ -23,14 +23,14 @@ namespace SocialWeather
|
|||
_logger = logger;
|
||||
}
|
||||
|
||||
public async override Task OnConnectedAsync(Connection connection)
|
||||
public async override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
_lifetimeManager.OnConnectedAsync(connection);
|
||||
await ProcessRequests(connection);
|
||||
_lifetimeManager.OnDisconnectedAsync(connection);
|
||||
}
|
||||
|
||||
public async Task ProcessRequests(Connection connection)
|
||||
public async Task ProcessRequests(ConnectionContext connection)
|
||||
{
|
||||
var formatter = _formatterResolver.GetFormatter<WeatherReport>(
|
||||
connection.Metadata.Get<string>("formatType"));
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ namespace SocialWeather
|
|||
app.UseDeveloperExceptionPage();
|
||||
}
|
||||
|
||||
app.UseSockets(o => { o.MapEndpoint<SocialWeatherEndPoint>("weather"); });
|
||||
app.UseSockets(o => { o.MapEndPoint<SocialWeatherEndPoint>("weather"); });
|
||||
app.UseFileServer();
|
||||
|
||||
var formatterResolver = app.ApplicationServices.GetRequiredService<FormatterResolver>();
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ namespace SocketsSample.EndPoints
|
|||
{
|
||||
public ConnectionList Connections { get; } = new ConnectionList();
|
||||
|
||||
public override async Task OnConnectedAsync(Connection connection)
|
||||
public override async Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
Connections.Add(connection);
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ namespace SocketsSample
|
|||
|
||||
app.UseSockets(routes =>
|
||||
{
|
||||
routes.MapEndpoint<MessagesEndPoint>("chat");
|
||||
routes.MapEndPoint<MessagesEndPoint>("chat");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
await _bus.PublishAsync(channel, payload);
|
||||
}
|
||||
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
public override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());
|
||||
var connectionTask = Task.CompletedTask;
|
||||
|
|
@ -173,7 +173,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
return Task.WhenAll(connectionTask, userTask);
|
||||
}
|
||||
|
||||
public override Task OnDisconnectedAsync(Connection connection)
|
||||
public override Task OnDisconnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
_connections.Remove(connection);
|
||||
|
||||
|
|
@ -204,7 +204,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
return Task.WhenAll(tasks);
|
||||
}
|
||||
|
||||
public override async Task AddGroupAsync(Connection connection, string groupName)
|
||||
public override async Task AddGroupAsync(ConnectionContext connection, string groupName)
|
||||
{
|
||||
var groupChannel = typeof(THub).FullName + ".group." + groupName;
|
||||
|
||||
|
|
@ -255,7 +255,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
}
|
||||
|
||||
public override async Task RemoveGroupAsync(Connection connection, string groupName)
|
||||
public override async Task RemoveGroupAsync(ConnectionContext connection, string groupName)
|
||||
{
|
||||
var groupChannel = typeof(THub).FullName + ".group." + groupName;
|
||||
|
||||
|
|
@ -297,7 +297,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
_redisServerConnection.Dispose();
|
||||
}
|
||||
|
||||
private async Task WriteAsync(Connection connection, HubMessage hubMessage)
|
||||
private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage)
|
||||
{
|
||||
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
|
||||
var data = await protocol.WriteToArrayAsync(hubMessage);
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
private long _nextInvocationId = 0;
|
||||
private readonly ConnectionList _connections = new ConnectionList();
|
||||
|
||||
public override Task AddGroupAsync(Connection connection, string groupName)
|
||||
public override Task AddGroupAsync(ConnectionContext connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public override Task RemoveGroupAsync(Connection connection, string groupName)
|
||||
public override Task RemoveGroupAsync(ConnectionContext connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
return InvokeAllWhere(methodName, args, c => true);
|
||||
}
|
||||
|
||||
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
|
||||
private Task InvokeAllWhere(string methodName, object[] args, Func<ConnectionContext, bool> include)
|
||||
{
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
|
@ -94,19 +94,19 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
});
|
||||
}
|
||||
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
public override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
_connections.Add(connection);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public override Task OnDisconnectedAsync(Connection connection)
|
||||
public override Task OnDisconnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
_connections.Remove(connection);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private async Task WriteAsync(Connection connection, HubMessage hubMessage)
|
||||
private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage)
|
||||
{
|
||||
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
|
||||
var payload = await protocol.WriteToArrayAsync(hubMessage);
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
public class HubCallerContext
|
||||
{
|
||||
public HubCallerContext(Connection connection)
|
||||
public HubCallerContext(ConnectionContext connection)
|
||||
{
|
||||
Connection = connection;
|
||||
}
|
||||
|
||||
public Connection Connection { get; }
|
||||
public ConnectionContext Connection { get; }
|
||||
|
||||
public ClaimsPrincipal User => Connection.User;
|
||||
|
||||
|
|
|
|||
|
|
@ -22,10 +22,9 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
|
||||
IHubProtocolResolver protocolResolver,
|
||||
IHubContext<THub> hubContext,
|
||||
IOptions<EndPointOptions<HubEndPoint<THub, IClientProxy>>> endPointOptions,
|
||||
ILogger<HubEndPoint<THub>> logger,
|
||||
IServiceScopeFactory serviceScopeFactory)
|
||||
: base(lifetimeManager, protocolResolver, hubContext, endPointOptions, logger, serviceScopeFactory)
|
||||
: base(lifetimeManager, protocolResolver, hubContext, logger, serviceScopeFactory)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
|
@ -43,7 +42,6 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
|
||||
IHubProtocolResolver protocolResolver,
|
||||
IHubContext<THub, TClient> hubContext,
|
||||
IOptions<EndPointOptions<HubEndPoint<THub, TClient>>> endPointOptions,
|
||||
ILogger<HubEndPoint<THub, TClient>> logger,
|
||||
IServiceScopeFactory serviceScopeFactory)
|
||||
{
|
||||
|
|
@ -56,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
DiscoverHubMethods();
|
||||
}
|
||||
|
||||
public override async Task OnConnectedAsync(Connection connection)
|
||||
public override async Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
try
|
||||
{
|
||||
|
|
@ -74,7 +72,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private async Task RunHubAsync(Connection connection)
|
||||
private async Task RunHubAsync(ConnectionContext connection)
|
||||
{
|
||||
await HubOnConnectedAsync(connection);
|
||||
|
||||
|
|
@ -92,7 +90,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
await HubOnDisconnectedAsync(connection, null);
|
||||
}
|
||||
|
||||
private async Task HubOnConnectedAsync(Connection connection)
|
||||
private async Task HubOnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
try
|
||||
{
|
||||
|
|
@ -118,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private async Task HubOnDisconnectedAsync(Connection connection, Exception exception)
|
||||
private async Task HubOnDisconnectedAsync(ConnectionContext connection, Exception exception)
|
||||
{
|
||||
try
|
||||
{
|
||||
|
|
@ -144,7 +142,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private async Task DispatchMessagesAsync(Connection connection)
|
||||
private async Task DispatchMessagesAsync(ConnectionContext connection)
|
||||
{
|
||||
// We use these for error handling. Since we dispatch multiple hub invocations
|
||||
// in parallel, we need a way to communicate failure back to the main processing loop. The
|
||||
|
|
@ -190,7 +188,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private async Task ProcessInvocation(Connection connection,
|
||||
private async Task ProcessInvocation(ConnectionContext connection,
|
||||
IHubProtocol protocol,
|
||||
InvocationMessage invocationMessage,
|
||||
CancellationTokenSource dispatcherCancellation,
|
||||
|
|
@ -212,7 +210,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private async Task Execute(Connection connection, IHubProtocol protocol, InvocationMessage invocationMessage)
|
||||
private async Task Execute(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
|
||||
{
|
||||
HubMethodDescriptor descriptor;
|
||||
if (!_methods.TryGetValue(invocationMessage.Target, out descriptor))
|
||||
|
|
@ -228,7 +226,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private async Task SendMessageAsync(Connection connection, IHubProtocol protocol, HubMessage hubMessage)
|
||||
private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage)
|
||||
{
|
||||
var payload = await protocol.WriteToArrayAsync(hubMessage);
|
||||
var message = new Message(payload, protocol.MessageType, endOfMessage: true);
|
||||
|
|
@ -246,7 +244,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
|
||||
}
|
||||
|
||||
private async Task<CompletionMessage> Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationMessage invocationMessage)
|
||||
private async Task<CompletionMessage> Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, InvocationMessage invocationMessage)
|
||||
{
|
||||
var methodExecutor = descriptor.MethodExecutor;
|
||||
|
||||
|
|
@ -295,7 +293,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private void InitializeHub(THub hub, Connection connection)
|
||||
private void InitializeHub(THub hub, ConnectionContext connection)
|
||||
{
|
||||
hub.Clients = _hubContext.Clients;
|
||||
hub.Context = new HubCallerContext(connection);
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
public abstract class HubLifetimeManager<THub>
|
||||
{
|
||||
public abstract Task OnConnectedAsync(Connection connection);
|
||||
public abstract Task OnConnectedAsync(ConnectionContext connection);
|
||||
|
||||
public abstract Task OnDisconnectedAsync(Connection connection);
|
||||
public abstract Task OnDisconnectedAsync(ConnectionContext connection);
|
||||
|
||||
public abstract Task InvokeAllAsync(string methodName, object[] args);
|
||||
|
||||
|
|
@ -20,9 +20,9 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public abstract Task InvokeUserAsync(string userId, string methodName, object[] args);
|
||||
|
||||
public abstract Task AddGroupAsync(Connection connection, string groupName);
|
||||
public abstract Task AddGroupAsync(ConnectionContext connection, string groupName);
|
||||
|
||||
public abstract Task RemoveGroupAsync(Connection connection, string groupName);
|
||||
public abstract Task RemoveGroupAsync(ConnectionContext connection, string groupName);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
{
|
||||
public class DefaultHubProtocolResolver : IHubProtocolResolver
|
||||
{
|
||||
public IHubProtocol GetProtocol(Connection connection)
|
||||
public IHubProtocol GetProtocol(ConnectionContext connection)
|
||||
{
|
||||
// TODO: Allow customization of this serializer!
|
||||
return new JsonHubProtocol(new JsonSerializer());
|
||||
|
|
|
|||
|
|
@ -8,6 +8,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
{
|
||||
public interface IHubProtocolResolver
|
||||
{
|
||||
IHubProtocol GetProtocol(Connection connection);
|
||||
IHubProtocol GetProtocol(ConnectionContext connection);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -75,10 +75,10 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public class GroupManager<THub> : IGroupManager
|
||||
{
|
||||
private readonly Connection _connection;
|
||||
private readonly ConnectionContext _connection;
|
||||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
|
||||
public GroupManager(Connection connection, HubLifetimeManager<THub> lifetimeManager)
|
||||
public GroupManager(ConnectionContext connection, HubLifetimeManager<THub> lifetimeManager)
|
||||
{
|
||||
_connection = connection;
|
||||
_lifetimeManager = lifetimeManager;
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Builder
|
|||
|
||||
public void MapHub<THub>(string path) where THub : Hub<IClientProxy>
|
||||
{
|
||||
_routes.MapEndpoint<HubEndPoint<THub>>(path);
|
||||
_routes.MapEndPoint<HubEndPoint<THub>>(path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Security.Claims;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using Microsoft.AspNetCore.Http.Features;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public abstract class ConnectionContext : IDisposable
|
||||
{
|
||||
public abstract string ConnectionId { get; }
|
||||
|
||||
public abstract IFeatureCollection Features { get; }
|
||||
|
||||
public abstract ClaimsPrincipal User { get; set; }
|
||||
|
||||
// REVIEW: Should this be changed to items
|
||||
public abstract ConnectionMetadata Metadata { get; }
|
||||
|
||||
// TEMPORARY
|
||||
public abstract IChannelConnection<Message> Transport { get; set; }
|
||||
|
||||
// TEMPORARY
|
||||
public void Dispose()
|
||||
{
|
||||
Transport?.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Security.Claims;
|
||||
using Microsoft.AspNetCore.Http.Features;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class DefaultConnectionContext : ConnectionContext
|
||||
{
|
||||
public DefaultConnectionContext(string id, IChannelConnection<Message> transport)
|
||||
{
|
||||
Transport = transport;
|
||||
ConnectionId = id;
|
||||
}
|
||||
|
||||
public override string ConnectionId { get; }
|
||||
|
||||
public override IFeatureCollection Features { get; } = new FeatureCollection();
|
||||
|
||||
public override ClaimsPrincipal User { get; set; }
|
||||
|
||||
public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
|
||||
|
||||
public override IChannelConnection<Message> Transport { get; set; }
|
||||
}
|
||||
}
|
||||
|
|
@ -14,8 +14,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
/// <summary>
|
||||
/// Called when a new connection is accepted to the endpoint
|
||||
/// </summary>
|
||||
/// <param name="connection">The new <see cref="Connection"/></param>
|
||||
/// <param name="connection">The new <see cref="ConnectionContext"/></param>
|
||||
/// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns>
|
||||
public abstract Task OnConnectedAsync(Connection connection);
|
||||
public abstract Task OnConnectedAsync(ConnectionContext connection);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public interface ISocketBuilder
|
||||
{
|
||||
IServiceProvider ApplicationServices { get; }
|
||||
|
||||
ISocketBuilder Use(Func<SocketDelegate, SocketDelegate> middleware);
|
||||
|
||||
SocketDelegate Build();
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<Import Project="..\..\build\common.props" />
|
||||
|
||||
<PropertyGroup>
|
||||
<Description>Components for providing real-time bi-directional communication across the Web.</Description>
|
||||
<TargetFramework>netcoreapp2.0</TargetFramework>
|
||||
<NoWarn>$(NoWarn);CS1591</NoWarn>
|
||||
<GenerateDocumentationFile>true</GenerateDocumentationFile>
|
||||
<PackageTags>aspnetcore;signalr</PackageTags>
|
||||
<EnableApiCheck>false</EnableApiCheck>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.AspNetCore.Http.Features" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="System.Threading.Tasks.Channels" Version="$(CoreFxLabsVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class SocketBuilder : ISocketBuilder
|
||||
{
|
||||
private readonly IList<Func<SocketDelegate, SocketDelegate>> _components = new List<Func<SocketDelegate, SocketDelegate>>();
|
||||
|
||||
public IServiceProvider ApplicationServices { get; }
|
||||
|
||||
public SocketBuilder(IServiceProvider applicationServices)
|
||||
{
|
||||
ApplicationServices = applicationServices;
|
||||
}
|
||||
|
||||
public ISocketBuilder Use(Func<SocketDelegate, SocketDelegate> middleware)
|
||||
{
|
||||
_components.Add(middleware);
|
||||
return this;
|
||||
}
|
||||
|
||||
public SocketDelegate Build()
|
||||
{
|
||||
SocketDelegate app = features =>
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
};
|
||||
|
||||
foreach (var component in _components.Reverse())
|
||||
{
|
||||
app = component(app);
|
||||
}
|
||||
|
||||
return app;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public static class SocketBuilderExtensions
|
||||
{
|
||||
public static ISocketBuilder Use(this ISocketBuilder socketBuilder, Func<ConnectionContext, Func<Task>, Task> middleware)
|
||||
{
|
||||
return socketBuilder.Use(next =>
|
||||
{
|
||||
return context =>
|
||||
{
|
||||
Func<Task> simpleNext = () => next(context);
|
||||
return middleware(context, simpleNext);
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public delegate Task SocketDelegate(ConnectionContext connection);
|
||||
}
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Security.Claims;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class Connection : IDisposable
|
||||
{
|
||||
public string ConnectionId { get; }
|
||||
|
||||
public ClaimsPrincipal User { get; set; }
|
||||
public ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
|
||||
|
||||
public IChannelConnection<Message> Transport { get; }
|
||||
|
||||
public Connection(string id, IChannelConnection<Message> transport)
|
||||
{
|
||||
Transport = transport;
|
||||
ConnectionId = id;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
Transport.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -8,15 +8,15 @@ using System.Collections.Generic;
|
|||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class ConnectionList : IReadOnlyCollection<Connection>
|
||||
public class ConnectionList : IReadOnlyCollection<ConnectionContext>
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, Connection> _connections = new ConcurrentDictionary<string, Connection>();
|
||||
private readonly ConcurrentDictionary<string, ConnectionContext> _connections = new ConcurrentDictionary<string, ConnectionContext>();
|
||||
|
||||
public Connection this[string connectionId]
|
||||
public ConnectionContext this[string connectionId]
|
||||
{
|
||||
get
|
||||
{
|
||||
Connection connection;
|
||||
ConnectionContext connection;
|
||||
if (_connections.TryGetValue(connectionId, out connection))
|
||||
{
|
||||
return connection;
|
||||
|
|
@ -27,18 +27,18 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
public int Count => _connections.Count;
|
||||
|
||||
public void Add(Connection connection)
|
||||
public void Add(ConnectionContext connection)
|
||||
{
|
||||
_connections.TryAdd(connection.ConnectionId, connection);
|
||||
}
|
||||
|
||||
public void Remove(Connection connection)
|
||||
public void Remove(ConnectionContext connection)
|
||||
{
|
||||
Connection dummy;
|
||||
ConnectionContext dummy;
|
||||
_connections.TryRemove(connection.ConnectionId, out dummy);
|
||||
}
|
||||
|
||||
public IEnumerator<Connection> GetEnumerator()
|
||||
public IEnumerator<ConnectionContext> GetEnumerator()
|
||||
{
|
||||
foreach (var item in _connections)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
|
||||
|
||||
var state = new ConnectionState(
|
||||
new Connection(id, applicationSide),
|
||||
new DefaultConnectionContext(id, applicationSide),
|
||||
transportSide);
|
||||
|
||||
_connections.TryAdd(id, state);
|
||||
|
|
|
|||
|
|
@ -14,15 +14,5 @@ namespace Microsoft.Extensions.DependencyInjection
|
|||
|
||||
return services;
|
||||
}
|
||||
|
||||
public static IServiceCollection AddEndPoint<TEndPoint>(this IServiceCollection services,
|
||||
Action<EndPointOptions<TEndPoint>> setupAction) where TEndPoint : EndPoint
|
||||
{
|
||||
services.AddEndPoint<TEndPoint>();
|
||||
|
||||
services.Configure(setupAction);
|
||||
|
||||
return services;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,9 +14,7 @@ using Microsoft.AspNetCore.Sockets.Internal;
|
|||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.AspNetCore.Sockets.Transports;
|
||||
using Microsoft.AspNetCore.WebSockets.Internal;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Microsoft.Extensions.Primitives;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
|
|
@ -34,10 +32,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
_logger = _loggerFactory.CreateLogger<HttpConnectionDispatcher>();
|
||||
}
|
||||
|
||||
public async Task ExecuteAsync<TEndPoint>(HttpContext context) where TEndPoint : EndPoint
|
||||
public async Task ExecuteAsync(HttpContext context, HttpSocketOptions options, SocketDelegate socketDelegate)
|
||||
{
|
||||
var options = context.RequestServices.GetRequiredService<IOptions<EndPointOptions<TEndPoint>>>().Value;
|
||||
// TODO: Authorize attribute on EndPoint
|
||||
if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationPolicyNames))
|
||||
{
|
||||
return;
|
||||
|
|
@ -56,10 +52,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
else if (HttpMethods.IsGet(context.Request.Method))
|
||||
{
|
||||
// GET /{path}
|
||||
|
||||
// Get the end point mapped to this http connection
|
||||
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
|
||||
await ExecuteEndpointAsync(context, endpoint, options);
|
||||
await ExecuteEndpointAsync(context, socketDelegate, options);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
@ -67,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
}
|
||||
|
||||
private async Task ExecuteEndpointAsync<TEndPoint>(HttpContext context, EndPoint endpoint, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
|
||||
private async Task ExecuteEndpointAsync(HttpContext context, SocketDelegate socketDelegate, HttpSocketOptions options)
|
||||
{
|
||||
var supportedTransports = options.Transports;
|
||||
|
||||
|
|
@ -94,7 +87,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
// We only need to provide the Input channel since writing to the application is handled through /send.
|
||||
var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory);
|
||||
|
||||
await DoPersistentConnection(endpoint, sse, context, state);
|
||||
await DoPersistentConnection(socketDelegate, sse, context, state);
|
||||
}
|
||||
else if (context.Features.Get<IHttpWebSocketConnectionFeature>()?.IsWebSocketRequest == true)
|
||||
{
|
||||
|
|
@ -114,7 +107,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
var ws = new WebSocketsTransport(options.WebSockets, state.Application, _loggerFactory);
|
||||
|
||||
await DoPersistentConnection(endpoint, ws, context, state);
|
||||
await DoPersistentConnection(socketDelegate, ws, context, state);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
@ -183,7 +176,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling;
|
||||
|
||||
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
|
||||
state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
@ -275,7 +268,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
return state;
|
||||
}
|
||||
|
||||
private async Task DoPersistentConnection(EndPoint endpoint,
|
||||
private async Task DoPersistentConnection(SocketDelegate socketDelegate,
|
||||
IHttpTransport transport,
|
||||
HttpContext context,
|
||||
ConnectionState state)
|
||||
|
|
@ -310,7 +303,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
state.RequestId = context.TraceIdentifier;
|
||||
|
||||
// Call into the end point passing the connection
|
||||
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
|
||||
state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection);
|
||||
|
||||
// Start the transport
|
||||
state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted);
|
||||
|
|
@ -326,17 +319,17 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
await _manager.DisposeAndRemoveAsync(state);
|
||||
}
|
||||
|
||||
private async Task ExecuteApplication(EndPoint endpoint, Connection connection)
|
||||
private async Task ExecuteApplication(SocketDelegate socketDelegate, ConnectionContext connection)
|
||||
{
|
||||
// Jump onto the thread pool thread so blocking user code doesn't block the setup of the
|
||||
// connection and transport
|
||||
await AwaitableThreadPool.Yield();
|
||||
|
||||
// Running this in an async method turns sync exceptions into async ones
|
||||
await endpoint.OnConnectedAsync(connection);
|
||||
await socketDelegate(connection);
|
||||
}
|
||||
|
||||
private Task ProcessNegotiate<TEndPoint>(HttpContext context, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
|
||||
private Task ProcessNegotiate(HttpContext context, HttpSocketOptions options)
|
||||
{
|
||||
// Set the allowed headers for this resource
|
||||
context.Response.Headers.AppendCommaSeparatedValues("Allow", "GET", "POST", "OPTIONS");
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Routing;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
|
@ -36,9 +35,23 @@ namespace Microsoft.AspNetCore.Builder
|
|||
_dispatcher = dispatcher;
|
||||
}
|
||||
|
||||
public void MapEndpoint<TEndPoint>(string path) where TEndPoint : EndPoint
|
||||
public void MapSocket(string path, Action<ISocketBuilder> socketConfig) =>
|
||||
MapSocket(path, new HttpSocketOptions(), socketConfig);
|
||||
|
||||
public void MapSocket(string path, HttpSocketOptions options, Action<ISocketBuilder> socketConfig)
|
||||
{
|
||||
_routes.MapRoute(path, _dispatcher.ExecuteAsync<TEndPoint>);
|
||||
var socketBuilder = new SocketBuilder(_routes.ServiceProvider);
|
||||
socketConfig(socketBuilder);
|
||||
var socket = socketBuilder.Build();
|
||||
_routes.MapRoute(path, c => _dispatcher.ExecuteAsync(c, options, socket));
|
||||
}
|
||||
|
||||
public void MapEndPoint<TEndPoint>(string path) where TEndPoint : EndPoint
|
||||
{
|
||||
MapSocket(path, builder =>
|
||||
{
|
||||
builder.UseEndPoint<TEndPoint>();
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public static class HttpSocketBuilderExtensions
|
||||
{
|
||||
public static ISocketBuilder UseEndPoint<TEndPoint>(this ISocketBuilder socketBuilder) where TEndPoint : EndPoint
|
||||
{
|
||||
// This is a terminal middleware, so there's no need to use the 'next' parameter
|
||||
return socketBuilder.Use((connection, _) =>
|
||||
{
|
||||
var endpoint = socketBuilder.ApplicationServices.GetRequiredService<TEndPoint>();
|
||||
return endpoint.OnConnectedAsync(connection);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -5,7 +5,7 @@ using System.Collections.Generic;
|
|||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class EndPointOptions<TEndPoint> where TEndPoint : EndPoint
|
||||
public class HttpSocketOptions
|
||||
{
|
||||
public IList<string> AuthorizationPolicyNames { get; } = new List<string>();
|
||||
|
||||
|
|
@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
|
|||
// on the same task
|
||||
private TaskCompletionSource<object> _disposeTcs = new TaskCompletionSource<object>();
|
||||
|
||||
public Connection Connection { get; set; }
|
||||
public ConnectionContext Connection { get; set; }
|
||||
public IChannelConnection<Message> Application { get; }
|
||||
|
||||
public CancellationTokenSource Cancellation { get; set; }
|
||||
|
|
@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
|
|||
public DateTime LastSeenUtc { get; set; }
|
||||
public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive;
|
||||
|
||||
public ConnectionState(Connection connection, IChannelConnection<Message> application)
|
||||
public ConnectionState(ConnectionContext connection, IChannelConnection<Message> application)
|
||||
{
|
||||
Connection = connection;
|
||||
Application = application;
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.WebSockets.Internal\Microsoft.AspNetCore.WebSockets.Internal.csproj" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Authorization" Version="$(AspNetCoreVersion)" />
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
public class EchoEndPoint : EndPoint
|
||||
{
|
||||
public async override Task OnConnectedAsync(Connection connection)
|
||||
public async override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
|
||||
mockLifetimeManager
|
||||
.Setup(m => m.OnConnectedAsync(It.IsAny<Connection>()))
|
||||
.Setup(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()))
|
||||
.Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));
|
||||
var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();
|
||||
|
||||
|
|
@ -60,8 +60,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
client.Dispose();
|
||||
|
||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
||||
// No hubs should be created since the connection is terminated
|
||||
mockHubActivator.Verify(m => m.Create(), Times.Never);
|
||||
mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
|
||||
|
|
@ -87,8 +87,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
|
||||
Assert.Equal("Hub OnConnected failed.", exception.Message);
|
||||
|
||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -111,8 +111,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
|
||||
Assert.Equal("Hub OnDisconnected failed.", exception.Message);
|
||||
|
||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
||||
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
public void Configure(IApplicationBuilder app, IHostingEnvironment env)
|
||||
{
|
||||
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("echo"));
|
||||
app.UseSockets(options => options.MapEndPoint<EchoEndPoint>("echo"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
private IHubProtocol _protocol;
|
||||
private CancellationTokenSource _cts;
|
||||
|
||||
public Connection Connection;
|
||||
public ConnectionContext Connection;
|
||||
public IChannelConnection<Message> Application { get; }
|
||||
public Task Connected => Connection.Metadata.Get<TaskCompletionSource<bool>>("ConnectedTask").Task;
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
Application = ChannelConnection.Create<Message>(input: applicationToTransport, output: transportToApplication);
|
||||
var transport = ChannelConnection.Create<Message>(input: transportToApplication, output: applicationToTransport);
|
||||
|
||||
Connection = new Connection(Guid.NewGuid().ToString(), transport);
|
||||
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport);
|
||||
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
|
||||
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();
|
||||
|
||||
|
|
|
|||
|
|
@ -37,12 +37,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddOptions();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
var ms = new MemoryStream();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "OPTIONS";
|
||||
context.Response.Body = ms;
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
var id = Encoding.UTF8.GetString(ms.ToArray());
|
||||
|
||||
|
|
@ -68,7 +70,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddOptions();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
|
|
@ -77,7 +78,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
context.Request.Query = qs;
|
||||
SetTransport(context, transportType);
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
|
|
@ -100,7 +104,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddOptions();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "POST";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
|
|
@ -108,7 +111,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
|
|
@ -130,13 +136,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
|
||||
SetTransport(context, transportType);
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
|
|
@ -156,11 +164,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "POST";
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
|
|
@ -180,14 +190,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "POST";
|
||||
context.Request.QueryString = new QueryString($"?id={connectionState.Connection.ConnectionId}");
|
||||
context.Request.ContentType = "text/plain";
|
||||
context.Response.Body = strm;
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
|
|
@ -237,10 +249,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
|
||||
var context = MakeRequest("/foo", state);
|
||||
SetTransport(context, TransportType.ServerSentEvents);
|
||||
|
||||
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
|
||||
|
|
@ -256,11 +273,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var state = manager.CreateConnection();
|
||||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<SynchronusExceptionEndPoint>("/foo", state);
|
||||
var context = MakeRequest("/foo", state);
|
||||
SetTransport(context, TransportType.ServerSentEvents);
|
||||
|
||||
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<SynchronusExceptionEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<SynchronusExceptionEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
|
||||
|
|
@ -276,10 +297,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var state = manager.CreateConnection();
|
||||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = MakeRequest("/foo", state);
|
||||
|
||||
var context = MakeRequest<SynchronusExceptionEndPoint>("/foo", state);
|
||||
|
||||
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<SynchronusExceptionEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<SynchronusExceptionEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
|
||||
|
||||
|
|
@ -296,9 +321,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
|
||||
var context = MakeRequest("/foo", state);
|
||||
|
||||
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
|
||||
var app = builder.Build();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
|
||||
|
||||
|
|
@ -315,10 +345,18 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
|
||||
var context = MakeRequest("/foo", state);
|
||||
SetTransport(context, TransportType.WebSockets);
|
||||
|
||||
var task = dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1);
|
||||
|
||||
var task = dispatcher.ExecuteAsync(context, options, app);
|
||||
|
||||
await task.OrTimeout();
|
||||
}
|
||||
|
|
@ -333,15 +371,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context1 = MakeRequest<TestEndPoint>("/foo", state);
|
||||
var context2 = MakeRequest<TestEndPoint>("/foo", state);
|
||||
var context1 = MakeRequest("/foo", state);
|
||||
var context2 = MakeRequest("/foo", state);
|
||||
|
||||
SetTransport(context1, transportType);
|
||||
SetTransport(context2, transportType);
|
||||
|
||||
var request1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
var request1 = dispatcher.ExecuteAsync(context1, options, app);
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context2);
|
||||
await dispatcher.ExecuteAsync(context2, options, app);
|
||||
|
||||
Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode);
|
||||
|
||||
|
|
@ -369,11 +413,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context1 = MakeRequest<TestEndPoint>("/foo", state);
|
||||
var context2 = MakeRequest<TestEndPoint>("/foo", state);
|
||||
var context1 = MakeRequest("/foo", state);
|
||||
var context2 = MakeRequest("/foo", state);
|
||||
|
||||
var request1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
|
||||
var request2 = dispatcher.ExecuteAsync<TestEndPoint>(context2);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
var request1 = dispatcher.ExecuteAsync(context1, options, app);
|
||||
var request2 = dispatcher.ExecuteAsync(context2, options, app);
|
||||
|
||||
await request1;
|
||||
|
||||
|
|
@ -398,10 +448,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<TestEndPoint>("/foo", state);
|
||||
var context = MakeRequest("/foo", state);
|
||||
SetTransport(context, transportType);
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
await dispatcher.ExecuteAsync(context, options, app);
|
||||
|
||||
|
||||
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
|
||||
}
|
||||
|
|
@ -414,9 +471,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<TestEndPoint>("/foo", state);
|
||||
var context = MakeRequest("/foo", state);
|
||||
|
||||
var task = dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
var task = dispatcher.ExecuteAsync(context, options, app);
|
||||
|
||||
var buffer = Encoding.UTF8.GetBytes("Hello World");
|
||||
|
||||
|
|
@ -439,10 +502,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<BlockingEndPoint>("/foo", state);
|
||||
var context = MakeRequest("/foo", state);
|
||||
SetTransport(context, TransportType.ServerSentEvents);
|
||||
|
||||
var task = dispatcher.ExecuteAsync<BlockingEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<BlockingEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<BlockingEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
var task = dispatcher.ExecuteAsync(context, options, app);
|
||||
|
||||
var buffer = Encoding.UTF8.GetBytes("Hello World");
|
||||
|
||||
|
|
@ -465,9 +534,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<BlockingEndPoint>("/foo", state);
|
||||
var context = MakeRequest("/foo", state);
|
||||
|
||||
var task = dispatcher.ExecuteAsync<BlockingEndPoint>(context);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<BlockingEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<BlockingEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
var task = dispatcher.ExecuteAsync(context, options, app);
|
||||
|
||||
var buffer = Encoding.UTF8.GetBytes("Hello World");
|
||||
|
||||
|
|
@ -490,10 +565,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context1 = MakeRequest<TestEndPoint>("/foo", state);
|
||||
var task1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
|
||||
var context2 = MakeRequest<TestEndPoint>("/foo", state);
|
||||
var task2 = dispatcher.ExecuteAsync<TestEndPoint>(context2);
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
|
||||
var context1 = MakeRequest("/foo", state);
|
||||
var task1 = dispatcher.ExecuteAsync(context1, options, app);
|
||||
var context2 = MakeRequest("/foo", state);
|
||||
var task2 = dispatcher.ExecuteAsync(context2, options, app);
|
||||
|
||||
// Task 1 should finish when request 2 arrives
|
||||
await task1.OrTimeout();
|
||||
|
|
@ -555,19 +637,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>(options =>
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
context.RequestServices = sp;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
|
|
@ -576,8 +655,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
|
||||
// would hang if EndPoint was running
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
|
||||
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
|
||||
}
|
||||
|
|
@ -591,22 +676,19 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>(options =>
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
o.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
});
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
context.RequestServices = sp;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
|
|
@ -616,10 +698,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
|
||||
// "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
|
||||
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
|
|
@ -637,21 +725,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>(options =>
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
options.AuthorizationPolicyNames.Add("secondPolicy");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
options.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress));
|
||||
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
o.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress));
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
context.RequestServices = sp;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
|
|
@ -661,18 +745,25 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
options.AuthorizationPolicyNames.Add("secondPolicy");
|
||||
|
||||
// partialy "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
// would hang if EndPoint was running
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
|
||||
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
|
||||
|
||||
// fully "authorize" user
|
||||
context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") }));
|
||||
|
||||
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
|
||||
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
|
|
@ -689,23 +780,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>(options =>
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
o.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
context.RequestServices = sp;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
|
|
@ -715,10 +803,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
|
||||
// "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
|
||||
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
|
||||
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
|
|
@ -736,23 +830,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>(options =>
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
o.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
context.RequestServices = sp;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
|
|
@ -762,11 +853,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
|
||||
// "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
// would block if EndPoint was executed
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
|
||||
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
|
||||
}
|
||||
|
|
@ -829,20 +926,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
context.Response.Body = strm;
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<ImmediatelyCompleteEndPoint>(options =>
|
||||
{
|
||||
options.Transports = supportedTransports;
|
||||
});
|
||||
|
||||
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
|
||||
SetTransport(context, transportType);
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
|
||||
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.Transports = supportedTransports;
|
||||
|
||||
await dispatcher.ExecuteAsync(context, options, app);
|
||||
Assert.Equal(status, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
|
||||
|
|
@ -861,10 +960,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<TestEndPoint>("/foo", state, format);
|
||||
var context = MakeRequest("/foo", state, format);
|
||||
context.Request.Method = "POST";
|
||||
context.Request.ContentType = contentType;
|
||||
var endPoint = context.RequestServices.GetRequiredService<TestEndPoint>();
|
||||
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
var builder = new SocketBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
|
||||
var buffer = contentType == BinaryContentType ?
|
||||
Convert.FromBase64String(encoded) :
|
||||
|
|
@ -872,7 +976,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var messages = new List<Message>();
|
||||
using (context.Request.Body = new MemoryStream(buffer, writable: false))
|
||||
{
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
|
||||
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app).OrTimeout();
|
||||
}
|
||||
|
||||
while (state.Connection.Transport.Input.TryRead(out var message))
|
||||
|
|
@ -883,18 +987,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
return messages;
|
||||
}
|
||||
|
||||
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state, string format = null) where TEndPoint : EndPoint
|
||||
private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null)
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddEndPoint<TEndPoint>(o =>
|
||||
{
|
||||
// Make the close timeout less than the default for OrTimeout() test helper
|
||||
o.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1);
|
||||
});
|
||||
|
||||
services.AddOptions();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = path;
|
||||
context.Request.Method = "GET";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
|
|
@ -942,7 +1037,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
public class NerverEndingEndPoint : EndPoint
|
||||
{
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
public override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<object>();
|
||||
return tcs.Task;
|
||||
|
|
@ -951,7 +1046,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
public class BlockingEndPoint : EndPoint
|
||||
{
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
public override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
connection.Transport.Input.WaitToReadAsync().Wait();
|
||||
return Task.CompletedTask;
|
||||
|
|
@ -960,7 +1055,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
public class SynchronusExceptionEndPoint : EndPoint
|
||||
{
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
public override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
throw new InvalidOperationException();
|
||||
}
|
||||
|
|
@ -968,7 +1063,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
public class ImmediatelyCompleteEndPoint : EndPoint
|
||||
{
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
public override Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
|
@ -976,7 +1071,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
|
||||
public class TestEndPoint : EndPoint
|
||||
{
|
||||
public override async Task OnConnectedAsync(Connection connection)
|
||||
public override async Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
while (await connection.Transport.Input.WaitToReadAsync())
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue