Progress towards splitting the layers (#473)

* Progress towards splitting the layers
- This is based on the work anurse did in anurse/endpoint-middleware-spike to
introduce a connection middleware pipeline that mimics much of our http
pipeline. The intent is that this layer will be generic enough to build both
SignalR and Kestrel on top of but we're not there yet. This change makes incremental
progress towards splitting apart sockets and http so that we can add the tcp transport
without breaking everything all at once.
- Created Microsoft.AspNetCore.Sockets.Abstractions where the primitives for
sockets live. That includes, ConnectionContext (formerly Connection), EndPoint,
ISocketBuilder, SocketDelegate, etc.
- ConnectionContext isn't in it's final form as yet, it still very closely mirrors
the original Connection object we had so that tests continue to pass.
- The HttpConnectionDispatcher doesn't know about EndPoint anymore, it just cares
about invoking the SocketDelegate.
- EndPointOptions has been removed as part of this change as it coupled http specific configuration
to the end point type. There's a new HttpSocketOptions that needs to be passed into MapSocket calls.
- Updated the tests to deal with the API changes.
This commit is contained in:
David Fowler 2017-05-23 02:43:32 -07:00 committed by GitHub
parent e68a1b294f
commit 9d9a52119e
45 changed files with 535 additions and 263 deletions

View File

@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.26411.1
VisualStudioVersion = 15.0.26510.0
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
EndProject
@ -76,6 +76,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.Signal
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Common.Tests", "test\Microsoft.AspNetCore.SignalR.Common.Tests\Microsoft.AspNetCore.SignalR.Common.Tests.csproj", "{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Sockets.Abstractions", "src\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj", "{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@ -186,6 +188,10 @@ Global
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E}.Release|Any CPU.Build.0 = Release|Any CPU
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@ -218,5 +224,6 @@ Global
{6CEC3DC2-5B01-45A8-8F0D-8531315DA90B} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{96771B3F-4D18-41A7-A75B-FF38E76AAC89} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{75E342F6-5445-4E7E-9143-6D9AE62C2B1E} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{F2E4FBD6-9AEA-4A82-BAC9-3FAACA677DF8} = {DA69F624-5398-4884-87E4-B816698CDE65}
EndGlobalSection
EndGlobal

View File

@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server
{
public class EchoEndPoint : EndPoint
{
public async override Task OnConnectedAsync(Connection connection)
public async override Task OnConnectedAsync(ConnectionContext connection)
{
await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync());
}

View File

@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server
}
app.UseFileServer();
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("echo"));
app.UseSockets(options => options.MapEndPoint<EchoEndPoint>("echo"));
app.UseSignalR(routes =>
{
routes.MapHub<TestHub>("testhub");

View File

@ -11,8 +11,8 @@ namespace ChatSample
public interface IUserTracker<out THub>
{
Task<IEnumerable<UserDetails>> UsersOnline();
Task AddUser(Connection connection, UserDetails userDetails);
Task RemoveUser(Connection connection);
Task AddUser(ConnectionContext connection, UserDetails userDetails);
Task RemoveUser(ConnectionContext connection);
event Action<UserDetails[]> UsersJoined;
event Action<UserDetails[]> UsersLeft;

View File

@ -9,8 +9,8 @@ namespace ChatSample
{
public class InMemoryUserTracker<THub> : IUserTracker<THub>
{
private readonly ConcurrentDictionary<Connection, UserDetails> _usersOnline
= new ConcurrentDictionary<Connection, UserDetails>();
private readonly ConcurrentDictionary<ConnectionContext, UserDetails> _usersOnline
= new ConcurrentDictionary<ConnectionContext, UserDetails>();
public event Action<UserDetails[]> UsersJoined;
public event Action<UserDetails[]> UsersLeft;
@ -18,7 +18,7 @@ namespace ChatSample
public Task<IEnumerable<UserDetails>> UsersOnline()
=> Task.FromResult(_usersOnline.Values.AsEnumerable());
public Task AddUser(Connection connection, UserDetails userDetails)
public Task AddUser(ConnectionContext connection, UserDetails userDetails)
{
_usersOnline.TryAdd(connection, userDetails);
UsersJoined(new[] { userDetails });
@ -26,7 +26,7 @@ namespace ChatSample
return Task.CompletedTask;
}
public Task RemoveUser(Connection connection)
public Task RemoveUser(ConnectionContext connection)
{
if (_usersOnline.TryRemove(connection, out var userDetails))
{

View File

@ -57,14 +57,14 @@ namespace ChatSample
_wrappedHubLifetimeManager = serviceProvider.GetRequiredService<THubLifetimeManager>();
}
public override async Task OnConnectedAsync(Connection connection)
public override async Task OnConnectedAsync(ConnectionContext connection)
{
await _wrappedHubLifetimeManager.OnConnectedAsync(connection);
_connections.Add(connection);
await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, connection.User.Identity.Name));
}
public override async Task OnDisconnectedAsync(Connection connection)
public override async Task OnDisconnectedAsync(ConnectionContext connection)
{
await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection);
_connections.Remove(connection);
@ -157,12 +157,12 @@ namespace ChatSample
return _wrappedHubLifetimeManager.InvokeUserAsync(userId, methodName, args);
}
public override Task AddGroupAsync(Connection connection, string groupName)
public override Task AddGroupAsync(ConnectionContext connection, string groupName)
{
return _wrappedHubLifetimeManager.AddGroupAsync(connection, groupName);
}
public override Task RemoveGroupAsync(Connection connection, string groupName)
public override Task RemoveGroupAsync(ConnectionContext connection, string groupName)
{
return _wrappedHubLifetimeManager.RemoveGroupAsync(connection, groupName);
}

View File

@ -129,7 +129,7 @@ namespace ChatSample
}
}
public async Task AddUser(Connection connection, UserDetails userDetails)
public async Task AddUser(ConnectionContext connection, UserDetails userDetails)
{
var key = GetUserRedisKey(connection);
var user = SerializeUser(connection);
@ -156,7 +156,7 @@ namespace ChatSample
}
}
public async Task RemoveUser(Connection connection)
public async Task RemoveUser(ConnectionContext connection)
{
await _userSyncSempaphore.WaitAsync();
try
@ -180,7 +180,7 @@ namespace ChatSample
}
}
private static string GetUserRedisKey(Connection connection) => $"user:{connection.ConnectionId}";
private static string GetUserRedisKey(ConnectionContext connection) => $"user:{connection.ConnectionId}";
private static void Scan(object state)
{
@ -319,7 +319,7 @@ namespace ChatSample
}
}
private static string SerializeUser(Connection connection) =>
private static string SerializeUser(ConnectionContext connection) =>
$"{{ \"ConnectionID\": \"{connection.ConnectionId}\", \"Name\": \"{connection.User.Identity.Name}\" }}";
private static UserDetails DeserializerUser(string userJson) =>

View File

@ -20,13 +20,13 @@ namespace SocialWeather
_formatterResolver = formatterResolver;
}
public void OnConnectedAsync(Connection connection)
public void OnConnectedAsync(ConnectionContext connection)
{
connection.Metadata[ConnectionMetadataNames.Format] = "json";
_connectionList.Add(connection);
}
public void OnDisconnectedAsync(Connection connection)
public void OnDisconnectedAsync(ConnectionContext connection)
{
_connectionList.Remove(connection);
}
@ -64,7 +64,7 @@ namespace SocialWeather
throw new NotImplementedException();
}
public void AddGroupAsync(Connection connection, string groupName)
public void AddGroupAsync(ConnectionContext connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
lock (groups)
@ -73,7 +73,7 @@ namespace SocialWeather
}
}
public void RemoveGroupAsync(Connection connection, string groupName)
public void RemoveGroupAsync(ConnectionContext connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
if (groups != null)

View File

@ -23,14 +23,14 @@ namespace SocialWeather
_logger = logger;
}
public async override Task OnConnectedAsync(Connection connection)
public async override Task OnConnectedAsync(ConnectionContext connection)
{
_lifetimeManager.OnConnectedAsync(connection);
await ProcessRequests(connection);
_lifetimeManager.OnDisconnectedAsync(connection);
}
public async Task ProcessRequests(Connection connection)
public async Task ProcessRequests(ConnectionContext connection)
{
var formatter = _formatterResolver.GetFormatter<WeatherReport>(
connection.Metadata.Get<string>("formatType"));

View File

@ -32,7 +32,7 @@ namespace SocialWeather
app.UseDeveloperExceptionPage();
}
app.UseSockets(o => { o.MapEndpoint<SocialWeatherEndPoint>("weather"); });
app.UseSockets(o => { o.MapEndPoint<SocialWeatherEndPoint>("weather"); });
app.UseFileServer();
var formatterResolver = app.ApplicationServices.GetRequiredService<FormatterResolver>();

View File

@ -13,7 +13,7 @@ namespace SocketsSample.EndPoints
{
public ConnectionList Connections { get; } = new ConnectionList();
public override async Task OnConnectedAsync(Connection connection)
public override async Task OnConnectedAsync(ConnectionContext connection)
{
Connections.Add(connection);

View File

@ -53,7 +53,7 @@ namespace SocketsSample
app.UseSockets(routes =>
{
routes.MapEndpoint<MessagesEndPoint>("chat");
routes.MapEndPoint<MessagesEndPoint>("chat");
});
}
}

View File

@ -128,7 +128,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
await _bus.PublishAsync(channel, payload);
}
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(ConnectionContext connection)
{
var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());
var connectionTask = Task.CompletedTask;
@ -173,7 +173,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return Task.WhenAll(connectionTask, userTask);
}
public override Task OnDisconnectedAsync(Connection connection)
public override Task OnDisconnectedAsync(ConnectionContext connection)
{
_connections.Remove(connection);
@ -204,7 +204,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return Task.WhenAll(tasks);
}
public override async Task AddGroupAsync(Connection connection, string groupName)
public override async Task AddGroupAsync(ConnectionContext connection, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
@ -255,7 +255,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public override async Task RemoveGroupAsync(Connection connection, string groupName)
public override async Task RemoveGroupAsync(ConnectionContext connection, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
@ -297,7 +297,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_redisServerConnection.Dispose();
}
private async Task WriteAsync(Connection connection, HubMessage hubMessage)
private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage)
{
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
var data = await protocol.WriteToArrayAsync(hubMessage);

View File

@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.SignalR
private long _nextInvocationId = 0;
private readonly ConnectionList _connections = new ConnectionList();
public override Task AddGroupAsync(Connection connection, string groupName)
public override Task AddGroupAsync(ConnectionContext connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.SignalR
return Task.CompletedTask;
}
public override Task RemoveGroupAsync(Connection connection, string groupName)
public override Task RemoveGroupAsync(ConnectionContext connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.SignalR
return InvokeAllWhere(methodName, args, c => true);
}
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
private Task InvokeAllWhere(string methodName, object[] args, Func<ConnectionContext, bool> include)
{
var tasks = new List<Task>(_connections.Count);
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
@ -94,19 +94,19 @@ namespace Microsoft.AspNetCore.SignalR
});
}
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(ConnectionContext connection)
{
_connections.Add(connection);
return Task.CompletedTask;
}
public override Task OnDisconnectedAsync(Connection connection)
public override Task OnDisconnectedAsync(ConnectionContext connection)
{
_connections.Remove(connection);
return Task.CompletedTask;
}
private async Task WriteAsync(Connection connection, HubMessage hubMessage)
private async Task WriteAsync(ConnectionContext connection, HubMessage hubMessage)
{
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
var payload = await protocol.WriteToArrayAsync(hubMessage);

View File

@ -8,12 +8,12 @@ namespace Microsoft.AspNetCore.SignalR
{
public class HubCallerContext
{
public HubCallerContext(Connection connection)
public HubCallerContext(ConnectionContext connection)
{
Connection = connection;
}
public Connection Connection { get; }
public ConnectionContext Connection { get; }
public ClaimsPrincipal User => Connection.User;

View File

@ -22,10 +22,9 @@ namespace Microsoft.AspNetCore.SignalR
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
IHubProtocolResolver protocolResolver,
IHubContext<THub> hubContext,
IOptions<EndPointOptions<HubEndPoint<THub, IClientProxy>>> endPointOptions,
ILogger<HubEndPoint<THub>> logger,
IServiceScopeFactory serviceScopeFactory)
: base(lifetimeManager, protocolResolver, hubContext, endPointOptions, logger, serviceScopeFactory)
: base(lifetimeManager, protocolResolver, hubContext, logger, serviceScopeFactory)
{
}
}
@ -43,7 +42,6 @@ namespace Microsoft.AspNetCore.SignalR
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
IHubProtocolResolver protocolResolver,
IHubContext<THub, TClient> hubContext,
IOptions<EndPointOptions<HubEndPoint<THub, TClient>>> endPointOptions,
ILogger<HubEndPoint<THub, TClient>> logger,
IServiceScopeFactory serviceScopeFactory)
{
@ -56,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR
DiscoverHubMethods();
}
public override async Task OnConnectedAsync(Connection connection)
public override async Task OnConnectedAsync(ConnectionContext connection)
{
try
{
@ -74,7 +72,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task RunHubAsync(Connection connection)
private async Task RunHubAsync(ConnectionContext connection)
{
await HubOnConnectedAsync(connection);
@ -92,7 +90,7 @@ namespace Microsoft.AspNetCore.SignalR
await HubOnDisconnectedAsync(connection, null);
}
private async Task HubOnConnectedAsync(Connection connection)
private async Task HubOnConnectedAsync(ConnectionContext connection)
{
try
{
@ -118,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task HubOnDisconnectedAsync(Connection connection, Exception exception)
private async Task HubOnDisconnectedAsync(ConnectionContext connection, Exception exception)
{
try
{
@ -144,7 +142,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task DispatchMessagesAsync(Connection connection)
private async Task DispatchMessagesAsync(ConnectionContext connection)
{
// We use these for error handling. Since we dispatch multiple hub invocations
// in parallel, we need a way to communicate failure back to the main processing loop. The
@ -190,7 +188,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task ProcessInvocation(Connection connection,
private async Task ProcessInvocation(ConnectionContext connection,
IHubProtocol protocol,
InvocationMessage invocationMessage,
CancellationTokenSource dispatcherCancellation,
@ -212,7 +210,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task Execute(Connection connection, IHubProtocol protocol, InvocationMessage invocationMessage)
private async Task Execute(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
{
HubMethodDescriptor descriptor;
if (!_methods.TryGetValue(invocationMessage.Target, out descriptor))
@ -228,7 +226,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task SendMessageAsync(Connection connection, IHubProtocol protocol, HubMessage hubMessage)
private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage)
{
var payload = await protocol.WriteToArrayAsync(hubMessage);
var message = new Message(payload, protocol.MessageType, endOfMessage: true);
@ -246,7 +244,7 @@ namespace Microsoft.AspNetCore.SignalR
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
}
private async Task<CompletionMessage> Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationMessage invocationMessage)
private async Task<CompletionMessage> Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, InvocationMessage invocationMessage)
{
var methodExecutor = descriptor.MethodExecutor;
@ -295,7 +293,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private void InitializeHub(THub hub, Connection connection)
private void InitializeHub(THub hub, ConnectionContext connection)
{
hub.Clients = _hubContext.Clients;
hub.Context = new HubCallerContext(connection);

View File

@ -8,9 +8,9 @@ namespace Microsoft.AspNetCore.SignalR
{
public abstract class HubLifetimeManager<THub>
{
public abstract Task OnConnectedAsync(Connection connection);
public abstract Task OnConnectedAsync(ConnectionContext connection);
public abstract Task OnDisconnectedAsync(Connection connection);
public abstract Task OnDisconnectedAsync(ConnectionContext connection);
public abstract Task InvokeAllAsync(string methodName, object[] args);
@ -20,9 +20,9 @@ namespace Microsoft.AspNetCore.SignalR
public abstract Task InvokeUserAsync(string userId, string methodName, object[] args);
public abstract Task AddGroupAsync(Connection connection, string groupName);
public abstract Task AddGroupAsync(ConnectionContext connection, string groupName);
public abstract Task RemoveGroupAsync(Connection connection, string groupName);
public abstract Task RemoveGroupAsync(ConnectionContext connection, string groupName);
}
}

View File

@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
public class DefaultHubProtocolResolver : IHubProtocolResolver
{
public IHubProtocol GetProtocol(Connection connection)
public IHubProtocol GetProtocol(ConnectionContext connection)
{
// TODO: Allow customization of this serializer!
return new JsonHubProtocol(new JsonSerializer());

View File

@ -8,6 +8,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
public interface IHubProtocolResolver
{
IHubProtocol GetProtocol(Connection connection);
IHubProtocol GetProtocol(ConnectionContext connection);
}
}

View File

@ -75,10 +75,10 @@ namespace Microsoft.AspNetCore.SignalR
public class GroupManager<THub> : IGroupManager
{
private readonly Connection _connection;
private readonly ConnectionContext _connection;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public GroupManager(Connection connection, HubLifetimeManager<THub> lifetimeManager)
public GroupManager(ConnectionContext connection, HubLifetimeManager<THub> lifetimeManager)
{
_connection = connection;
_lifetimeManager = lifetimeManager;

View File

@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Builder
public void MapHub<THub>(string path) where THub : Hub<IClientProxy>
{
_routes.MapEndpoint<HubEndPoint<THub>>(path);
_routes.MapEndPoint<HubEndPoint<THub>>(path);
}
}
}

View File

@ -0,0 +1,33 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Text;
using System.Threading;
using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.Sockets
{
public abstract class ConnectionContext : IDisposable
{
public abstract string ConnectionId { get; }
public abstract IFeatureCollection Features { get; }
public abstract ClaimsPrincipal User { get; set; }
// REVIEW: Should this be changed to items
public abstract ConnectionMetadata Metadata { get; }
// TEMPORARY
public abstract IChannelConnection<Message> Transport { get; set; }
// TEMPORARY
public void Dispose()
{
Transport?.Dispose();
}
}
}

View File

@ -0,0 +1,28 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Security.Claims;
using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.Sockets
{
public class DefaultConnectionContext : ConnectionContext
{
public DefaultConnectionContext(string id, IChannelConnection<Message> transport)
{
Transport = transport;
ConnectionId = id;
}
public override string ConnectionId { get; }
public override IFeatureCollection Features { get; } = new FeatureCollection();
public override ClaimsPrincipal User { get; set; }
public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
public override IChannelConnection<Message> Transport { get; set; }
}
}

View File

@ -14,8 +14,8 @@ namespace Microsoft.AspNetCore.Sockets
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>
/// <param name="connection">The new <see cref="Connection"/></param>
/// <param name="connection">The new <see cref="ConnectionContext"/></param>
/// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns>
public abstract Task OnConnectedAsync(Connection connection);
public abstract Task OnConnectedAsync(ConnectionContext connection);
}
}

View File

@ -0,0 +1,18 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.AspNetCore.Sockets
{
public interface ISocketBuilder
{
IServiceProvider ApplicationServices { get; }
ISocketBuilder Use(Func<SocketDelegate, SocketDelegate> middleware);
SocketDelegate Build();
}
}

View File

@ -0,0 +1,23 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<Description>Components for providing real-time bi-directional communication across the Web.</Description>
<TargetFramework>netcoreapp2.0</TargetFramework>
<NoWarn>$(NoWarn);CS1591</NoWarn>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageTags>aspnetcore;signalr</PackageTags>
<EnableApiCheck>false</EnableApiCheck>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Http.Features" Version="$(AspNetCoreVersion)" />
<PackageReference Include="System.Threading.Tasks.Channels" Version="$(CoreFxLabsVersion)" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,44 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public class SocketBuilder : ISocketBuilder
{
private readonly IList<Func<SocketDelegate, SocketDelegate>> _components = new List<Func<SocketDelegate, SocketDelegate>>();
public IServiceProvider ApplicationServices { get; }
public SocketBuilder(IServiceProvider applicationServices)
{
ApplicationServices = applicationServices;
}
public ISocketBuilder Use(Func<SocketDelegate, SocketDelegate> middleware)
{
_components.Add(middleware);
return this;
}
public SocketDelegate Build()
{
SocketDelegate app = features =>
{
return Task.CompletedTask;
};
foreach (var component in _components.Reverse())
{
app = component(app);
}
return app;
}
}
}

View File

@ -0,0 +1,23 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public static class SocketBuilderExtensions
{
public static ISocketBuilder Use(this ISocketBuilder socketBuilder, Func<ConnectionContext, Func<Task>, Task> middleware)
{
return socketBuilder.Use(next =>
{
return context =>
{
Func<Task> simpleNext = () => next(context);
return middleware(context, simpleNext);
};
});
}
}
}

View File

@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public delegate Task SocketDelegate(ConnectionContext connection);
}

View File

@ -1,29 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Security.Claims;
namespace Microsoft.AspNetCore.Sockets
{
public class Connection : IDisposable
{
public string ConnectionId { get; }
public ClaimsPrincipal User { get; set; }
public ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
public IChannelConnection<Message> Transport { get; }
public Connection(string id, IChannelConnection<Message> transport)
{
Transport = transport;
ConnectionId = id;
}
public void Dispose()
{
Transport.Dispose();
}
}
}

View File

@ -8,15 +8,15 @@ using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionList : IReadOnlyCollection<Connection>
public class ConnectionList : IReadOnlyCollection<ConnectionContext>
{
private readonly ConcurrentDictionary<string, Connection> _connections = new ConcurrentDictionary<string, Connection>();
private readonly ConcurrentDictionary<string, ConnectionContext> _connections = new ConcurrentDictionary<string, ConnectionContext>();
public Connection this[string connectionId]
public ConnectionContext this[string connectionId]
{
get
{
Connection connection;
ConnectionContext connection;
if (_connections.TryGetValue(connectionId, out connection))
{
return connection;
@ -27,18 +27,18 @@ namespace Microsoft.AspNetCore.Sockets
public int Count => _connections.Count;
public void Add(Connection connection)
public void Add(ConnectionContext connection)
{
_connections.TryAdd(connection.ConnectionId, connection);
}
public void Remove(Connection connection)
public void Remove(ConnectionContext connection)
{
Connection dummy;
ConnectionContext dummy;
_connections.TryRemove(connection.ConnectionId, out dummy);
}
public IEnumerator<Connection> GetEnumerator()
public IEnumerator<ConnectionContext> GetEnumerator()
{
foreach (var item in _connections)
{

View File

@ -57,7 +57,7 @@ namespace Microsoft.AspNetCore.Sockets
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var state = new ConnectionState(
new Connection(id, applicationSide),
new DefaultConnectionContext(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);

View File

@ -14,15 +14,5 @@ namespace Microsoft.Extensions.DependencyInjection
return services;
}
public static IServiceCollection AddEndPoint<TEndPoint>(this IServiceCollection services,
Action<EndPointOptions<TEndPoint>> setupAction) where TEndPoint : EndPoint
{
services.AddEndPoint<TEndPoint>();
services.Configure(setupAction);
return services;
}
}
}

View File

@ -14,9 +14,7 @@ using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
namespace Microsoft.AspNetCore.Sockets
@ -34,10 +32,8 @@ namespace Microsoft.AspNetCore.Sockets
_logger = _loggerFactory.CreateLogger<HttpConnectionDispatcher>();
}
public async Task ExecuteAsync<TEndPoint>(HttpContext context) where TEndPoint : EndPoint
public async Task ExecuteAsync(HttpContext context, HttpSocketOptions options, SocketDelegate socketDelegate)
{
var options = context.RequestServices.GetRequiredService<IOptions<EndPointOptions<TEndPoint>>>().Value;
// TODO: Authorize attribute on EndPoint
if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationPolicyNames))
{
return;
@ -56,10 +52,7 @@ namespace Microsoft.AspNetCore.Sockets
else if (HttpMethods.IsGet(context.Request.Method))
{
// GET /{path}
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
await ExecuteEndpointAsync(context, endpoint, options);
await ExecuteEndpointAsync(context, socketDelegate, options);
}
else
{
@ -67,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private async Task ExecuteEndpointAsync<TEndPoint>(HttpContext context, EndPoint endpoint, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
private async Task ExecuteEndpointAsync(HttpContext context, SocketDelegate socketDelegate, HttpSocketOptions options)
{
var supportedTransports = options.Transports;
@ -94,7 +87,7 @@ namespace Microsoft.AspNetCore.Sockets
// We only need to provide the Input channel since writing to the application is handled through /send.
var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory);
await DoPersistentConnection(endpoint, sse, context, state);
await DoPersistentConnection(socketDelegate, sse, context, state);
}
else if (context.Features.Get<IHttpWebSocketConnectionFeature>()?.IsWebSocketRequest == true)
{
@ -114,7 +107,7 @@ namespace Microsoft.AspNetCore.Sockets
var ws = new WebSocketsTransport(options.WebSockets, state.Application, _loggerFactory);
await DoPersistentConnection(endpoint, ws, context, state);
await DoPersistentConnection(socketDelegate, ws, context, state);
}
else
{
@ -183,7 +176,7 @@ namespace Microsoft.AspNetCore.Sockets
state.Connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling;
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection);
}
else
{
@ -275,7 +268,7 @@ namespace Microsoft.AspNetCore.Sockets
return state;
}
private async Task DoPersistentConnection(EndPoint endpoint,
private async Task DoPersistentConnection(SocketDelegate socketDelegate,
IHttpTransport transport,
HttpContext context,
ConnectionState state)
@ -310,7 +303,7 @@ namespace Microsoft.AspNetCore.Sockets
state.RequestId = context.TraceIdentifier;
// Call into the end point passing the connection
state.ApplicationTask = ExecuteApplication(endpoint, state.Connection);
state.ApplicationTask = ExecuteApplication(socketDelegate, state.Connection);
// Start the transport
state.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted);
@ -326,17 +319,17 @@ namespace Microsoft.AspNetCore.Sockets
await _manager.DisposeAndRemoveAsync(state);
}
private async Task ExecuteApplication(EndPoint endpoint, Connection connection)
private async Task ExecuteApplication(SocketDelegate socketDelegate, ConnectionContext connection)
{
// Jump onto the thread pool thread so blocking user code doesn't block the setup of the
// connection and transport
await AwaitableThreadPool.Yield();
// Running this in an async method turns sync exceptions into async ones
await endpoint.OnConnectedAsync(connection);
await socketDelegate(connection);
}
private Task ProcessNegotiate<TEndPoint>(HttpContext context, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
private Task ProcessNegotiate(HttpContext context, HttpSocketOptions options)
{
// Set the allowed headers for this resource
context.Response.Headers.AppendCommaSeparatedValues("Allow", "GET", "POST", "OPTIONS");

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
@ -36,9 +35,23 @@ namespace Microsoft.AspNetCore.Builder
_dispatcher = dispatcher;
}
public void MapEndpoint<TEndPoint>(string path) where TEndPoint : EndPoint
public void MapSocket(string path, Action<ISocketBuilder> socketConfig) =>
MapSocket(path, new HttpSocketOptions(), socketConfig);
public void MapSocket(string path, HttpSocketOptions options, Action<ISocketBuilder> socketConfig)
{
_routes.MapRoute(path, _dispatcher.ExecuteAsync<TEndPoint>);
var socketBuilder = new SocketBuilder(_routes.ServiceProvider);
socketConfig(socketBuilder);
var socket = socketBuilder.Build();
_routes.MapRoute(path, c => _dispatcher.ExecuteAsync(c, options, socket));
}
public void MapEndPoint<TEndPoint>(string path) where TEndPoint : EndPoint
{
MapSocket(path, builder =>
{
builder.UseEndPoint<TEndPoint>();
});
}
}
}

View File

@ -0,0 +1,23 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Sockets
{
public static class HttpSocketBuilderExtensions
{
public static ISocketBuilder UseEndPoint<TEndPoint>(this ISocketBuilder socketBuilder) where TEndPoint : EndPoint
{
// This is a terminal middleware, so there's no need to use the 'next' parameter
return socketBuilder.Use((connection, _) =>
{
var endpoint = socketBuilder.ApplicationServices.GetRequiredService<TEndPoint>();
return endpoint.OnConnectedAsync(connection);
});
}
}
}

View File

@ -5,7 +5,7 @@ using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets
{
public class EndPointOptions<TEndPoint> where TEndPoint : EndPoint
public class HttpSocketOptions
{
public IList<string> AuthorizationPolicyNames { get; } = new List<string>();

View File

@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
// on the same task
private TaskCompletionSource<object> _disposeTcs = new TaskCompletionSource<object>();
public Connection Connection { get; set; }
public ConnectionContext Connection { get; set; }
public IChannelConnection<Message> Application { get; }
public CancellationTokenSource Cancellation { get; set; }
@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
public DateTime LastSeenUtc { get; set; }
public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive;
public ConnectionState(Connection connection, IChannelConnection<Message> application)
public ConnectionState(ConnectionContext connection, IChannelConnection<Message> application)
{
Connection = connection;
Application = application;

View File

@ -12,6 +12,7 @@
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.WebSockets.Internal\Microsoft.AspNetCore.WebSockets.Internal.csproj" />
<PackageReference Include="Microsoft.AspNetCore.Authorization" Version="$(AspNetCoreVersion)" />

View File

@ -8,7 +8,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
public class EchoEndPoint : EndPoint
{
public async override Task OnConnectedAsync(Connection connection)
public async override Task OnConnectedAsync(ConnectionContext connection)
{
await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync());
}

View File

@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
mockLifetimeManager
.Setup(m => m.OnConnectedAsync(It.IsAny<Connection>()))
.Setup(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()))
.Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));
var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();
@ -60,8 +60,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
client.Dispose();
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
// No hubs should be created since the connection is terminated
mockHubActivator.Verify(m => m.Create(), Times.Never);
mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
@ -87,8 +87,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnConnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
}
}
@ -111,8 +111,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnDisconnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<ConnectionContext>()), Times.Once);
}
}

View File

@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public void Configure(IApplicationBuilder app, IHostingEnvironment env)
{
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("echo"));
app.UseSockets(options => options.MapEndPoint<EchoEndPoint>("echo"));
}
}

View File

@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
private IHubProtocol _protocol;
private CancellationTokenSource _cts;
public Connection Connection;
public ConnectionContext Connection;
public IChannelConnection<Message> Application { get; }
public Task Connected => Connection.Metadata.Get<TaskCompletionSource<bool>>("ConnectedTask").Task;
@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Application = ChannelConnection.Create<Message>(input: applicationToTransport, output: transportToApplication);
var transport = ChannelConnection.Create<Message>(input: transportToApplication, output: applicationToTransport);
Connection = new Connection(Guid.NewGuid().ToString(), transport);
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport);
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();

View File

@ -37,12 +37,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
var ms = new MemoryStream();
context.Request.Path = "/foo";
context.Request.Method = "OPTIONS";
context.Response.Body = ms;
await dispatcher.ExecuteAsync<TestEndPoint>(context);
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
var id = Encoding.UTF8.GetString(ms.ToArray());
@ -68,7 +70,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
@ -77,7 +78,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Query = qs;
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
await strm.FlushAsync();
@ -100,7 +104,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "POST";
var values = new Dictionary<string, StringValues>();
@ -108,7 +111,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var qs = new QueryCollection(values);
context.Request.Query = qs;
await dispatcher.ExecuteAsync<TestEndPoint>(context);
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
await strm.FlushAsync();
@ -130,13 +136,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
@ -156,11 +164,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "POST";
await dispatcher.ExecuteAsync<TestEndPoint>(context);
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
@ -180,14 +190,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "POST";
context.Request.QueryString = new QueryString($"?id={connectionState.Connection.ConnectionId}");
context.Request.ContentType = "text/plain";
context.Response.Body = strm;
await dispatcher.ExecuteAsync<TestEndPoint>(context);
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
@ -237,10 +249,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
var context = MakeRequest("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
@ -256,11 +273,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<SynchronusExceptionEndPoint>("/foo", state);
var context = MakeRequest("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<SynchronusExceptionEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<SynchronusExceptionEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
@ -276,10 +297,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest("/foo", state);
var context = MakeRequest<SynchronusExceptionEndPoint>("/foo", state);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<SynchronusExceptionEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<SynchronusExceptionEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
@ -296,9 +321,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
var context = MakeRequest("/foo", state);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
@ -315,10 +345,18 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
var context = MakeRequest("/foo", state);
SetTransport(context, TransportType.WebSockets);
var task = dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1);
var task = dispatcher.ExecuteAsync(context, options, app);
await task.OrTimeout();
}
@ -333,15 +371,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<TestEndPoint>("/foo", state);
var context2 = MakeRequest<TestEndPoint>("/foo", state);
var context1 = MakeRequest("/foo", state);
var context2 = MakeRequest("/foo", state);
SetTransport(context1, transportType);
SetTransport(context2, transportType);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
var request1 = dispatcher.ExecuteAsync(context1, options, app);
await dispatcher.ExecuteAsync<TestEndPoint>(context2);
await dispatcher.ExecuteAsync(context2, options, app);
Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode);
@ -369,11 +413,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<TestEndPoint>("/foo", state);
var context2 = MakeRequest<TestEndPoint>("/foo", state);
var context1 = MakeRequest("/foo", state);
var context2 = MakeRequest("/foo", state);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
var request2 = dispatcher.ExecuteAsync<TestEndPoint>(context2);
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
var request1 = dispatcher.ExecuteAsync(context1, options, app);
var request2 = dispatcher.ExecuteAsync(context2, options, app);
await request1;
@ -398,10 +448,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<TestEndPoint>("/foo", state);
var context = MakeRequest("/foo", state);
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
await dispatcher.ExecuteAsync(context, options, app);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
}
@ -414,9 +471,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<TestEndPoint>("/foo", state);
var context = MakeRequest("/foo", state);
var task = dispatcher.ExecuteAsync<TestEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
var task = dispatcher.ExecuteAsync(context, options, app);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -439,10 +502,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<BlockingEndPoint>("/foo", state);
var context = MakeRequest("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<BlockingEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<BlockingEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
var task = dispatcher.ExecuteAsync(context, options, app);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -465,9 +534,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<BlockingEndPoint>("/foo", state);
var context = MakeRequest("/foo", state);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>(context);
var services = new ServiceCollection();
services.AddEndPoint<BlockingEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<BlockingEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
var task = dispatcher.ExecuteAsync(context, options, app);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -490,10 +565,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<TestEndPoint>("/foo", state);
var task1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
var context2 = MakeRequest<TestEndPoint>("/foo", state);
var task2 = dispatcher.ExecuteAsync<TestEndPoint>(context2);
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
var context1 = MakeRequest("/foo", state);
var task1 = dispatcher.ExecuteAsync(context1, options, app);
var context2 = MakeRequest("/foo", state);
var task2 = dispatcher.ExecuteAsync(context2, options, app);
// Task 1 should finish when request 2 arrives
await task1.OrTimeout();
@ -555,19 +637,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>(options =>
services.AddEndPoint<TestEndPoint>();
services.AddAuthorization(o =>
{
options.AuthorizationPolicyNames.Add("test");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -576,8 +655,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.AuthorizationPolicyNames.Add("test");
// would hang if EndPoint was running
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
@ -591,22 +676,19 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>(options =>
services.AddEndPoint<TestEndPoint>();
services.AddAuthorization(o =>
{
options.AuthorizationPolicyNames.Add("test");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
});
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -616,10 +698,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.AuthorizationPolicyNames.Add("test");
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -637,21 +725,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>(options =>
services.AddEndPoint<TestEndPoint>();
services.AddAuthorization(o =>
{
options.AuthorizationPolicyNames.Add("test");
options.AuthorizationPolicyNames.Add("secondPolicy");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
options.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress));
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
o.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress));
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -661,18 +745,25 @@ namespace Microsoft.AspNetCore.Sockets.Tests
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.AuthorizationPolicyNames.Add("test");
options.AuthorizationPolicyNames.Add("secondPolicy");
// partialy "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would hang if EndPoint was running
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
// fully "authorize" user
context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") }));
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -689,23 +780,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>(options =>
services.AddEndPoint<TestEndPoint>();
services.AddAuthorization(o =>
{
options.AuthorizationPolicyNames.Add("test");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -715,10 +803,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.AuthorizationPolicyNames.Add("test");
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -736,23 +830,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>(options =>
services.AddEndPoint<TestEndPoint>();
services.AddAuthorization(o =>
{
options.AuthorizationPolicyNames.Add("test");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
@ -762,11 +853,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.AuthorizationPolicyNames.Add("test");
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would block if EndPoint was executed
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
@ -829,20 +926,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Response.Body = strm;
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<ImmediatelyCompleteEndPoint>(options =>
{
options.Transports = supportedTransports;
});
services.AddEndPoint<ImmediatelyCompleteEndPoint>();
SetTransport(context, transportType);
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<ImmediatelyCompleteEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.Transports = supportedTransports;
await dispatcher.ExecuteAsync(context, options, app);
Assert.Equal(status, context.Response.StatusCode);
await strm.FlushAsync();
@ -861,10 +960,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = MakeRequest<TestEndPoint>("/foo", state, format);
var context = MakeRequest("/foo", state, format);
context.Request.Method = "POST";
context.Request.ContentType = contentType;
var endPoint = context.RequestServices.GetRequiredService<TestEndPoint>();
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var buffer = contentType == BinaryContentType ?
Convert.FromBase64String(encoded) :
@ -872,7 +976,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var messages = new List<Message>();
using (context.Request.Body = new MemoryStream(buffer, writable: false))
{
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app).OrTimeout();
}
while (state.Connection.Transport.Input.TryRead(out var message))
@ -883,18 +987,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
return messages;
}
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state, string format = null) where TEndPoint : EndPoint
private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null)
{
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddEndPoint<TEndPoint>(o =>
{
// Make the close timeout less than the default for OrTimeout() test helper
o.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1);
});
services.AddOptions();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
@ -942,7 +1037,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public class NerverEndingEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(ConnectionContext connection)
{
var tcs = new TaskCompletionSource<object>();
return tcs.Task;
@ -951,7 +1046,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public class BlockingEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(ConnectionContext connection)
{
connection.Transport.Input.WaitToReadAsync().Wait();
return Task.CompletedTask;
@ -960,7 +1055,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public class SynchronusExceptionEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(ConnectionContext connection)
{
throw new InvalidOperationException();
}
@ -968,7 +1063,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public class ImmediatelyCompleteEndPoint : EndPoint
{
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(ConnectionContext connection)
{
return Task.CompletedTask;
}
@ -976,7 +1071,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public class TestEndPoint : EndPoint
{
public override async Task OnConnectedAsync(Connection connection)
public override async Task OnConnectedAsync(ConnectionContext connection)
{
while (await connection.Transport.Input.WaitToReadAsync())
{