add "messaging" endpoints and transports

* Need a separate set of primitives to handle messaging
* Using Channels (not Pipelines!) to provide the data flow for messaging
* All transports are now "message" based transports
* Added an adaptor to convert message-based transports to serve
streaming endpoints
This commit is contained in:
Andrew Stanton-Nurse 2016-12-07 14:09:09 -08:00 committed by Andrew Stanton-Nurse
parent fa0219f75e
commit d281cb72ea
85 changed files with 1603 additions and 813 deletions

View File

@ -1,3 +1,4 @@
image: Visual Studio 2015
init:
- git config --global core.autocrlf true
branches:
@ -10,4 +11,4 @@ build_script:
- build.cmd --quiet verify
clone_depth: 1
test: off
deploy: off
deploy: off

View File

@ -6,4 +6,4 @@
"sdk": {
"version": "1.0.0-preview2-1-003180"
}
}
}

View File

@ -18,7 +18,7 @@ namespace ChatSample.Hubs
{
if (!Context.User.Identity.IsAuthenticated)
{
Context.Connection.Channel.Dispose();
Context.Connection.Transport.Dispose();
}
return Task.CompletedTask;

View File

@ -90,4 +90,4 @@
"prepublish": [ "bower install" ],
"postpublish": [ "dotnet publish-iis --publish-folder %publish:OutputPath% --framework %publish:FullTargetFramework%" ]
}
}
}

View File

@ -1,4 +1,6 @@

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using Microsoft.Extensions.DependencyInjection;

View File

@ -1,4 +1,7 @@
using System.IO;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Threading.Tasks;
namespace SocialWeather

View File

@ -1,4 +1,7 @@
using System.IO;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Threading.Tasks;
using Newtonsoft.Json;

View File

@ -1,6 +1,10 @@
using System;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
@ -9,19 +13,19 @@ namespace SocialWeather
public class PersistentConnectionLifeTimeManager
{
private readonly FormatterResolver _formatterResolver;
private readonly ConnectionList _connectionList = new ConnectionList();
private readonly ConnectionList<StreamingConnection> _connectionList = new ConnectionList<StreamingConnection>();
public PersistentConnectionLifeTimeManager(FormatterResolver formatterResolver)
{
_formatterResolver = formatterResolver;
}
public void OnConnectedAsync(Connection connection)
public void OnConnectedAsync(StreamingConnection connection)
{
_connectionList.Add(connection);
}
public void OnDisconnectedAsync(Connection connection)
public void OnDisconnectedAsync(StreamingConnection connection)
{
_connectionList.Remove(connection);
}
@ -31,7 +35,7 @@ namespace SocialWeather
foreach (var connection in _connectionList)
{
var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>("formatType"));
await formatter.WriteAsync(data, connection.Channel.GetStream());
await formatter.WriteAsync(data, connection.Transport.GetStream());
}
}
@ -50,7 +54,7 @@ namespace SocialWeather
throw new NotImplementedException();
}
public void AddGroupAsync(Connection connection, string groupName)
public void AddGroupAsync(StreamingConnection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
lock (groups)
@ -59,7 +63,7 @@ namespace SocialWeather
}
}
public void RemoveGroupAsync(Connection connection, string groupName)
public void RemoveGroupAsync(StreamingConnection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
if (groups != null)

View File

@ -1,4 +1,7 @@
using System;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Text;
using System.Threading.Tasks;

View File

@ -1,4 +1,7 @@
using System;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

View File

@ -1,4 +1,7 @@
using System;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Threading.Tasks;
using Google.Protobuf;

View File

@ -1,11 +1,14 @@
using System.IO.Pipelines;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Logging;
namespace SocialWeather
{
public class SocialWeatherEndPoint : EndPoint
public class SocialWeatherEndPoint : StreamingEndPoint
{
private readonly PersistentConnectionLifeTimeManager _lifetimeManager;
private readonly FormatterResolver _formatterResolver;
@ -19,16 +22,16 @@ namespace SocialWeather
_logger = logger;
}
public async override Task OnConnectedAsync(Connection connection)
public async override Task OnConnectedAsync(StreamingConnection connection)
{
_lifetimeManager.OnConnectedAsync(connection);
await ProcessRequests(connection);
_lifetimeManager.OnDisconnectedAsync(connection);
}
public async Task ProcessRequests(Connection connection)
public async Task ProcessRequests(StreamingConnection connection)
{
var stream = connection.Channel.GetStream();
var stream = connection.Transport.GetStream();
var formatter = _formatterResolver.GetFormatter<WeatherReport>(
connection.Metadata.Get<string>("formatType"));

View File

@ -1,4 +1,7 @@
using Microsoft.AspNetCore.Builder;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

View File

@ -1,4 +1,7 @@
using Newtonsoft.Json;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
namespace SocialWeather

View File

@ -39,4 +39,4 @@
"dotnet publish-iis --publish-folder %publish:OutputPath% --framework %publish:FullTargetFramework%"
]
}
}
}

View File

@ -1,19 +1,21 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
namespace SocketsSample
{
public class ChatEndPoint : EndPoint
public class ChatEndPoint : StreamingEndPoint
{
public ConnectionList Connections { get; } = new ConnectionList();
public ConnectionList<StreamingConnection> Connections { get; } = new ConnectionList<StreamingConnection>();
public override async Task OnConnectedAsync(Connection connection)
public override async Task OnConnectedAsync(StreamingConnection connection)
{
Connections.Add(connection);
@ -21,7 +23,7 @@ namespace SocketsSample
while (true)
{
var result = await connection.Channel.Input.ReadAsync();
var result = await connection.Transport.Input.ReadAsync();
var input = result.Buffer;
try
{
@ -35,7 +37,7 @@ namespace SocketsSample
}
finally
{
connection.Channel.Input.Advance(input.End);
connection.Transport.Input.Advance(input.End);
}
}
@ -55,7 +57,7 @@ namespace SocketsSample
foreach (var c in Connections)
{
tasks.Add(c.Channel.Output.WriteAsync(payload));
tasks.Add(c.Transport.Output.WriteAsync(payload));
}
return Task.WhenAll(tasks);

View File

@ -0,0 +1,68 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets;
namespace SocketsSample.EndPoints
{
public class MessagesEndPoint : MessagingEndPoint
{
public ConnectionList<MessagingConnection> Connections { get; } = new ConnectionList<MessagingConnection>();
public override async Task OnConnectedAsync(MessagingConnection connection)
{
Connections.Add(connection);
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
try
{
while (true)
{
using (var message = await connection.Transport.Input.ReadAsync())
{
// We can avoid the copy here but we'll deal with that later
await Broadcast(message.Payload.Buffer, message.MessageFormat, message.EndOfMessage);
}
}
}
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
}
finally
{
Connections.Remove(connection);
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
}
}
private Task Broadcast(string text)
{
return Broadcast(ReadableBuffer.Create(Encoding.UTF8.GetBytes(text)), Format.Text, endOfMessage: true);
}
private Task Broadcast(ReadableBuffer payload, Format format, bool endOfMessage)
{
var tasks = new List<Task>(Connections.Count);
foreach (var c in Connections)
{
tasks.Add(c.Transport.Output.WriteAsync(new Message(
payload.Preserve(),
format,
endOfMessage)));
}
return Task.WhenAll(tasks);
}
}
}

View File

@ -5,6 +5,7 @@ using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using SocketsSample.EndPoints;
using SocketsSample.Hubs;
using SocketsSample.Protobuf;
@ -29,6 +30,7 @@ namespace SocketsSample
// .AddRedis();
services.AddSingleton<ChatEndPoint>();
services.AddSingleton<MessagesEndPoint>();
services.AddSingleton<ProtobufSerializer>();
}
@ -52,6 +54,7 @@ namespace SocketsSample
app.UseSockets(routes =>
{
routes.MapEndpoint<ChatEndPoint>("/chat");
routes.MapEndpoint<MessagesEndPoint>("/msgs");
});
}
}

View File

@ -1,10 +1,6 @@
{
"dependencies": {
"Microsoft.AspNetCore.SignalR.Redis": "1.0.0-*",
"Microsoft.NETCore.App": {
"version": "1.2.0-*",
"type": "platform"
},
"Newtonsoft.Json": "9.0.1",
"Microsoft.AspNetCore.Diagnostics": "1.2.0-*",
"Microsoft.AspNetCore.StaticFiles": "1.2.0-*",
@ -18,8 +14,15 @@
},
"frameworks": {
"netcoreapp1.1": {
"imports": "portable-net45+win8+wp8+wpa81"
}
"imports": "portable-net45+win8+wp8+wpa81",
"dependencies": {
"Microsoft.NETCore.App": {
"version": "1.2.0-*",
"type": "platform"
}
}
},
"net46": {}
},
"buildOptions": {
"emitEntryPoint": true,
@ -45,4 +48,4 @@
"dotnet publish-iis --publish-folder %publish:OutputPath% --framework %publish:FullTargetFramework%"
]
}
}
}

View File

@ -6,14 +6,21 @@
</head>
<body>
<h1>ASP.NET Sockets</h1>
<h2>Streaming</h2>
<ul>
<li><a href="sse.html">Server Sent Events</a></li>
<li><a href="polling.html">Long polling</a></li>
<li><a href="ws.html">Web Sockets</a></li>
<li><a href="sse.html#/chat">Server Sent Events</a></li>
<li><a href="polling.html#/chat">Long polling</a></li>
<li><a href="ws.html#/chat">Web Sockets</a></li>
</ul>
<h2>Messaging</h2>
<ul>
<li><a href="sse.html#/msgs">Server Sent Events</a></li>
<li><a href="polling.html#/msgs">Long polling</a></li>
<li><a href="ws.html#/msgs">Web Sockets</a></li>
</ul>
<h1>ASP.NET SignalR</h1>
<ul>
<li><a href="hubs.html">Hubs</a></li>
</ul>
</body>
</html>
</html>

View File

@ -75,7 +75,8 @@
}
document.addEventListener('DOMContentLoaded', () => {
var sock = new socket('/chat');
var url = location.hash || '#/chat';
var sock = new socket(url.substring(1));
sock.onopen = function () {
console.log('Opened!');
@ -105,4 +106,4 @@
<ul id="messages"></ul>
</body>
</html>
</html>

View File

@ -74,7 +74,8 @@
document.addEventListener('DOMContentLoaded', () => {
var sock = new socket('/chat');
var url = location.hash || '#/chat';
var sock = new socket(url.substring(1));
sock.onopen = function () {
console.log('Opened!');
@ -105,4 +106,4 @@
<ul id="messages"></ul>
</body>
</html>
</html>

View File

@ -5,7 +5,8 @@
<title></title>
<script>
document.addEventListener('DOMContentLoaded', () => {
var ws = new WebSocket(`ws://${document.location.host}/chat/ws`);
var url = (location.hash || '#/chat').substring(1);
var ws = new WebSocket(`ws://${document.location.host}${url}/ws`);
ws.onopen = function () {
console.log('Opened!');
@ -41,4 +42,4 @@
</ul>
</body>
</html>
</html>

View File

@ -5,7 +5,7 @@
"dependencies": {
"Microsoft.NETCore.App": {
"version": "1.1.0-*",
"version": "1.2.0-*",
"type": "platform"
},
"System.Net.WebSockets.Client": "4.3.0-*"

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
@ -25,11 +26,18 @@ namespace Microsoft.AspNetCore.SignalR.Client
private readonly HubBinder _binder;
private readonly CancellationTokenSource _readerCts = new CancellationTokenSource();
private readonly ConcurrentDictionary<string, InvocationRequest> _pendingCalls = new ConcurrentDictionary<string, InvocationRequest>();
private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
// We need to ensure pending calls added after a connection failure don't hang. Right now the easiest thing to do is lock.
private readonly object _pendingCallsLock = new object();
private readonly Dictionary<string, InvocationRequest> _pendingCalls = new Dictionary<string, InvocationRequest>();
private readonly ConcurrentDictionary<string, InvocationHandler> _handlers = new ConcurrentDictionary<string, InvocationHandler>();
private int _nextId = 0;
public Task Completion { get; }
private HubConnection(Connection connection, IInvocationAdapter adapter, ILogger logger)
{
_binder = new HubBinder(this);
@ -39,8 +47,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger = logger;
_reader = ReceiveMessages(_readerCts.Token);
_connection.Output.Writing.ContinueWith(
t => CompletePendingCalls(t.IsFaulted ? t.Exception.InnerException : null));
Completion = _connection.Output.Writing.ContinueWith(
t => Shutdown(t)).Unwrap();
}
// TODO: Client return values/tasks?
@ -73,10 +81,15 @@ namespace Microsoft.AspNetCore.SignalR.Client
// I just want an excuse to use 'irq' as a variable name...
_logger.LogDebug("Registering Invocation ID '{0}' for tracking", descriptor.Id);
var irq = new InvocationRequest(cancellationToken, returnType);
var addedSuccessfully = _pendingCalls.TryAdd(descriptor.Id, irq);
// This should always be true since we monotonically increase ids.
Debug.Assert(addedSuccessfully, "Id already in use?");
lock (_pendingCallsLock)
{
if (_connectionActive.IsCancellationRequested)
{
throw new InvalidOperationException("Connection has been terminated");
}
_pendingCalls.Add(descriptor.Id, irq);
}
// Trace the invocation, but only if that logging level is enabled (because building the args list is a bit slow)
if (_logger.IsEnabled(LogLevel.Trace))
@ -117,45 +130,69 @@ namespace Microsoft.AspNetCore.SignalR.Client
await Task.Yield();
_logger.LogTrace("Beginning receive loop");
while (!cancellationToken.IsCancellationRequested)
try
{
// This is a little odd... we want to remove the InvocationRequest once and only once so we pull it out in the callback,
// and stash it here because we know the callback will have finished before the end of the await.
var message = await _adapter.ReadMessageAsync(_stream, _binder, cancellationToken);
while (!cancellationToken.IsCancellationRequested)
{
// This is a little odd... we want to remove the InvocationRequest once and only once so we pull it out in the callback,
// and stash it here because we know the callback will have finished before the end of the await.
var message = await _adapter.ReadMessageAsync(_stream, _binder, cancellationToken);
var invocationDescriptor = message as InvocationDescriptor;
if (invocationDescriptor != null)
{
DispatchInvocation(invocationDescriptor, cancellationToken);
}
else
{
var invocationResultDescriptor = message as InvocationResultDescriptor;
if (invocationResultDescriptor != null)
var invocationDescriptor = message as InvocationDescriptor;
if (invocationDescriptor != null)
{
DispatchInvocationResult(invocationResultDescriptor, cancellationToken);
DispatchInvocation(invocationDescriptor, cancellationToken);
}
else
{
var invocationResultDescriptor = message as InvocationResultDescriptor;
if (invocationResultDescriptor != null)
{
InvocationRequest irq;
lock (_pendingCallsLock)
{
_connectionActive.Token.ThrowIfCancellationRequested();
irq = _pendingCalls[invocationResultDescriptor.Id];
_pendingCalls.Remove(invocationResultDescriptor.Id);
}
DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken);
}
}
}
}
_logger.LogTrace("Ending receive loop");
finally
{
_logger.LogTrace("Ending receive loop");
}
}
private void CompletePendingCalls(Exception e)
private Task Shutdown(Task completion)
{
_logger.LogTrace("Completing pending calls");
foreach (var call in _pendingCalls.Values)
_logger.LogTrace("Shutting down connection");
if (completion.IsFaulted)
{
if (e == null)
{
call.Completion.TrySetCanceled();
}
else
{
call.Completion.TrySetException(e);
}
_logger.LogError("Connection is shutting down due to an error: {0}", completion.Exception.InnerException);
}
_pendingCalls.Clear();
lock (_pendingCallsLock)
{
_connectionActive.Cancel();
foreach (var call in _pendingCalls.Values)
{
if (!completion.IsFaulted)
{
call.Completion.TrySetCanceled();
}
else
{
call.Completion.TrySetException(completion.Exception.InnerException);
}
}
_pendingCalls.Clear();
}
// Return the completion anyway
return completion;
}
private void DispatchInvocation(InvocationDescriptor invocationDescriptor, CancellationToken cancellationToken)
@ -172,12 +209,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
handler.Handler(invocationDescriptor.Arguments);
}
private void DispatchInvocationResult(InvocationResultDescriptor result, CancellationToken cancellationToken)
private void DispatchInvocationResult(InvocationResultDescriptor result, InvocationRequest irq, CancellationToken cancellationToken)
{
InvocationRequest irq;
var successfullyRemoved = _pendingCalls.TryRemove(result.Id, out irq);
Debug.Assert(successfullyRemoved, $"Invocation request {result.Id} was removed from the pending calls dictionary!");
_logger.LogInformation("Received Result for Invocation #{0}", result.Id);
if (cancellationToken.IsCancellationRequested)

View File

@ -28,7 +28,7 @@
},
"frameworks": {
"netstandard1.3": {},
"net451": {}
"netstandard1.3": {
}
}
}

View File

@ -31,7 +31,6 @@
},
"frameworks": {
"netstandard1.3": {},
"net451": {}
"netstandard1.3": {}
}
}

View File

@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable
{
private readonly ConnectionList _connections = new ConnectionList();
private readonly ConnectionList<StreamingConnection> _connections = new ConnectionList<StreamingConnection>();
// TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
private readonly InvocationAdapterRegistry _registry;
@ -51,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
foreach (var connection in _connections)
{
tasks.Add(connection.Channel.Output.WriteAsync((byte[])data));
tasks.Add(connection.Transport.Output.WriteAsync((byte[])data));
}
previousBroadcastTask = Task.WhenAll(tasks);
@ -116,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(StreamingConnection connection)
{
var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet<string>());
var connectionTask = TaskCache.CompletedTask;
@ -133,7 +133,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
await previousConnectionTask;
previousConnectionTask = connection.Channel.Output.WriteAsync((byte[])data);
previousConnectionTask = connection.Transport.Output.WriteAsync((byte[])data);
});
@ -149,14 +149,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
await previousUserTask;
previousUserTask = connection.Channel.Output.WriteAsync((byte[])data);
previousUserTask = connection.Transport.Output.WriteAsync((byte[])data);
});
}
return Task.WhenAll(connectionTask, userTask);
}
public override Task OnDisconnectedAsync(Connection connection)
public override Task OnDisconnectedAsync(StreamingConnection connection)
{
_connections.Remove(connection);
@ -186,7 +186,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return Task.WhenAll(tasks);
}
public override async Task AddGroupAsync(Connection connection, string groupName)
public override async Task AddGroupAsync(StreamingConnection connection, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
@ -220,9 +220,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
await previousTask;
var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections)
foreach (var groupConnection in group.Connections.Cast<StreamingConnection>())
{
tasks.Add(groupConnection.Channel.Output.WriteAsync((byte[])data));
tasks.Add(groupConnection.Transport.Output.WriteAsync((byte[])data));
}
previousTask = Task.WhenAll(tasks);
@ -234,7 +234,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public override async Task RemoveGroupAsync(Connection connection, string groupName)
public override async Task RemoveGroupAsync(StreamingConnection connection, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
@ -300,7 +300,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
private class GroupData
{
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public ConnectionList Connections = new ConnectionList();
public ConnectionList<StreamingConnection> Connections = new ConnectionList<StreamingConnection>();
}
}
}

View File

@ -32,6 +32,10 @@
},
"frameworks": {
"netstandard1.6": {},
"net451": {}
"net46": {
"dependencies": {
"System.Security.Claims": "4.0.1"
}
}
}
}
}

View File

@ -12,7 +12,7 @@ namespace Microsoft.AspNetCore.SignalR
{
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
{
private readonly ConnectionList _connections = new ConnectionList();
private readonly ConnectionList<StreamingConnection> _connections = new ConnectionList<StreamingConnection>();
private readonly InvocationAdapterRegistry _registry;
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR
_registry = registry;
}
public override Task AddGroupAsync(Connection connection, string groupName)
public override Task AddGroupAsync(StreamingConnection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR
return TaskCache.CompletedTask;
}
public override Task RemoveGroupAsync(Connection connection, string groupName)
public override Task RemoveGroupAsync(StreamingConnection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.SignalR
return InvokeAllWhere(methodName, args, c => true);
}
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
private Task InvokeAllWhere(string methodName, object[] args, Func<StreamingConnection, bool> include)
{
var tasks = new List<Task>(_connections.Count);
var message = new InvocationDescriptor
@ -68,7 +68,7 @@ namespace Microsoft.AspNetCore.SignalR
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
tasks.Add(invocationAdapter.WriteMessageAsync(message, connection.Channel.GetStream()));
tasks.Add(invocationAdapter.WriteMessageAsync(message, connection.Transport.GetStream()));
}
return Task.WhenAll(tasks);
@ -86,7 +86,7 @@ namespace Microsoft.AspNetCore.SignalR
Arguments = args
};
return invocationAdapter.WriteMessageAsync(message, connection.Channel.GetStream());
return invocationAdapter.WriteMessageAsync(message, connection.Transport.GetStream());
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
@ -106,13 +106,13 @@ namespace Microsoft.AspNetCore.SignalR
});
}
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(StreamingConnection connection)
{
_connections.Add(connection);
return TaskCache.CompletedTask;
}
public override Task OnDisconnectedAsync(Connection connection)
public override Task OnDisconnectedAsync(StreamingConnection connection)
{
_connections.Remove(connection);
return TaskCache.CompletedTask;

View File

@ -8,12 +8,12 @@ namespace Microsoft.AspNetCore.SignalR
{
public class HubCallerContext
{
public HubCallerContext(Connection connection)
public HubCallerContext(StreamingConnection connection)
{
Connection = connection;
}
public Connection Connection { get; }
public StreamingConnection Connection { get; }
public ClaimsPrincipal User => Connection.User;

View File

@ -25,10 +25,10 @@ namespace Microsoft.AspNetCore.SignalR
}
}
public class HubEndPoint<THub, TClient> : EndPoint, IInvocationBinder where THub : Hub<TClient>
public class HubEndPoint<THub, TClient> : StreamingEndPoint, IInvocationBinder where THub : Hub<TClient>
{
private readonly Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks
= new Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks
= new Dictionary<string, Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>();
private readonly HubLifetimeManager<THub> _lifetimeManager;
@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR
DiscoverHubMethods();
}
public override async Task OnConnectedAsync(Connection connection)
public override async Task OnConnectedAsync(StreamingConnection connection)
{
// TODO: Dispatch from the caller
await Task.Yield();
@ -68,7 +68,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task RunHubAsync(Connection connection)
private async Task RunHubAsync(StreamingConnection connection)
{
await HubOnConnectedAsync(connection);
@ -86,7 +86,7 @@ namespace Microsoft.AspNetCore.SignalR
await HubOnDisconnectedAsync(connection, null);
}
private async Task HubOnConnectedAsync(Connection connection)
private async Task HubOnConnectedAsync(StreamingConnection connection)
{
try
{
@ -108,11 +108,13 @@ namespace Microsoft.AspNetCore.SignalR
catch (Exception ex)
{
_logger.LogError(0, ex, "Error when invoking OnConnectedAsync on hub.");
connection.Transport.Input.Complete(ex);
connection.Transport.Output.Complete(ex);
throw;
}
}
private async Task HubOnDisconnectedAsync(Connection connection, Exception exception)
private async Task HubOnDisconnectedAsync(StreamingConnection connection, Exception exception)
{
try
{
@ -138,9 +140,9 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task DispatchMessagesAsync(Connection connection)
private async Task DispatchMessagesAsync(StreamingConnection connection)
{
var stream = connection.Channel.GetStream();
var stream = connection.Transport.GetStream();
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
while (true)
@ -160,7 +162,7 @@ namespace Microsoft.AspNetCore.SignalR
}
InvocationResultDescriptor result;
Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback;
Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback;
if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback))
{
result = await callback(connection, invocationDescriptor);
@ -181,7 +183,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private void InitializeHub(THub hub, Connection connection)
private void InitializeHub(THub hub, StreamingConnection connection)
{
hub.Clients = _hubContext.Clients;
hub.Context = new HubCallerContext(connection);

View File

@ -8,9 +8,9 @@ namespace Microsoft.AspNetCore.SignalR
{
public abstract class HubLifetimeManager<THub>
{
public abstract Task OnConnectedAsync(Connection connection);
public abstract Task OnConnectedAsync(StreamingConnection connection);
public abstract Task OnDisconnectedAsync(Connection connection);
public abstract Task OnDisconnectedAsync(StreamingConnection connection);
public abstract Task InvokeAllAsync(string methodName, object[] args);
@ -20,9 +20,9 @@ namespace Microsoft.AspNetCore.SignalR
public abstract Task InvokeUserAsync(string userId, string methodName, object[] args);
public abstract Task AddGroupAsync(Connection connection, string groupName);
public abstract Task AddGroupAsync(StreamingConnection connection, string groupName);
public abstract Task RemoveGroupAsync(Connection connection, string groupName);
public abstract Task RemoveGroupAsync(StreamingConnection connection, string groupName);
}
}

View File

@ -75,10 +75,10 @@ namespace Microsoft.AspNetCore.SignalR
public class GroupManager<THub> : IGroupManager
{
private readonly Connection _connection;
private readonly StreamingConnection _connection;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public GroupManager(Connection connection, HubLifetimeManager<THub> lifetimeManager)
public GroupManager(StreamingConnection connection, HubLifetimeManager<THub> lifetimeManager)
{
_connection = connection;
_lifetimeManager = lifetimeManager;

View File

@ -32,7 +32,7 @@
"Newtonsoft.Json": "9.0.1"
},
"frameworks": {
"netstandard1.3": {},
"net451": {}
"netstandard1.3": {
}
}
}
}

View File

@ -54,8 +54,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
});
// Start sending and polling
_sender = SendMessages(Utils.AppendPath(url, "send"), _senderCts.Token);
_poller = Poll(Utils.AppendPath(url, "poll"), _pollCts.Token);
_sender = SendMessages(Utils.AppendPath(url, "send"), _senderCts.Token);
Running = Task.WhenAll(_sender, _poller);
return TaskCache.CompletedTask;

View File

@ -32,7 +32,6 @@
},
"frameworks": {
"netstandard1.3": {},
"net451": {}
"netstandard1.3": {}
}
}

View File

@ -1,16 +1,26 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
using System;
using System.Security.Claims;
namespace Microsoft.AspNetCore.Sockets
{
public class Connection
public abstract class Connection : IDisposable
{
public string ConnectionId { get; set; }
public abstract ConnectionMode Mode { get; }
public string ConnectionId { get; }
public ClaimsPrincipal User { get; set; }
public IPipelineConnection Channel { get; set; }
public ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
protected Connection(string id)
{
ConnectionId = id;
}
public virtual void Dispose()
{
}
}
}

View File

@ -8,15 +8,15 @@ using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionList : IReadOnlyCollection<Connection>
public class ConnectionList<T> : IReadOnlyCollection<T> where T: Connection
{
private readonly ConcurrentDictionary<string, Connection> _connections = new ConcurrentDictionary<string, Connection>();
private readonly ConcurrentDictionary<string, T> _connections = new ConcurrentDictionary<string, T>();
public Connection this[string connectionId]
public T this[string connectionId]
{
get
{
Connection connection;
T connection;
if (_connections.TryGetValue(connectionId, out connection))
{
return connection;
@ -27,18 +27,18 @@ namespace Microsoft.AspNetCore.Sockets
public int Count => _connections.Count;
public void Add(Connection connection)
public void Add(T connection)
{
_connections.TryAdd(connection.ConnectionId, connection);
}
public void Remove(Connection connection)
public void Remove(T connection)
{
Connection dummy;
T dummy;
_connections.TryRemove(connection.ConnectionId, out dummy);
}
public IEnumerator<Connection> GetEnumerator()
public IEnumerator<T> GetEnumerator()
{
foreach (var item in _connections)
{

View File

@ -6,16 +6,20 @@ using System.Collections.Concurrent;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets.Internal;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionManager
{
private ConcurrentDictionary<string, ConnectionState> _connections = new ConcurrentDictionary<string, ConnectionState>();
private Timer _timer;
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new ConcurrentDictionary<string, ConnectionState>();
private readonly Timer _timer;
private readonly PipelineFactory _pipelineFactory;
public ConnectionManager()
public ConnectionManager(PipelineFactory pipelineFactory)
{
_pipelineFactory = pipelineFactory;
_timer = new Timer(Scan, this, 0, 1000);
}
@ -24,39 +28,8 @@ namespace Microsoft.AspNetCore.Sockets
return _connections.TryGetValue(id, out state);
}
public ConnectionState ReserveConnection()
{
string id = MakeNewConnectionId();
// REVIEW: Should we create state for this?
var state = _connections.GetOrAdd(id, connectionId => new ConnectionState());
// Mark it as a reservation
state.Connection = new Connection
{
ConnectionId = id
};
return state;
}
public ConnectionState AddNewConnection(IPipelineConnection connection)
{
string id = MakeNewConnectionId();
var state = new ConnectionState
{
Connection = new Connection
{
Channel = connection,
ConnectionId = id
},
LastSeen = DateTimeOffset.UtcNow,
Active = true
};
_connections.TryAdd(id, state);
return state;
}
public ConnectionState CreateConnection(ConnectionMode mode) =>
mode == ConnectionMode.Streaming ? CreateStreamingConnection() : CreateMessagingConnection();
public void RemoveConnection(string id)
{
@ -82,7 +55,7 @@ namespace Microsoft.AspNetCore.Sockets
// Scan the registered connections looking for ones that have timed out
foreach (var c in _connections)
{
if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeen).TotalSeconds > 5)
if (!c.Value.Active && (DateTimeOffset.UtcNow - c.Value.LastSeenUtc).TotalSeconds > 5)
{
ConnectionState s;
if (_connections.TryRemove(c.Key, out s))
@ -114,10 +87,46 @@ namespace Microsoft.AspNetCore.Sockets
}
else
{
s.Connection.Channel.Dispose();
s.Dispose();
}
}
}
}
private ConnectionState CreateMessagingConnection()
{
var id = MakeNewConnectionId();
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var state = new MessagingConnectionState(
new MessagingConnection(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
}
private ConnectionState CreateStreamingConnection()
{
var id = MakeNewConnectionId();
var transportToApplication = _pipelineFactory.Create();
var applicationToTransport = _pipelineFactory.Create();
var transportSide = new PipelineConnection(applicationToTransport, transportToApplication);
var applicationSide = new PipelineConnection(transportToApplication, applicationToTransport);
var state = new StreamingConnectionState(
new StreamingConnection(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
}
}
}

View File

@ -0,0 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.Sockets
{
public enum ConnectionMode
{
Streaming,
Messaging
}
}

View File

@ -1,17 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionState
{
public Connection Connection { get; set; }
// These are used for long polling mostly
public Action Close { get; set; }
public DateTimeOffset LastSeen { get; set; }
public bool Active { get; set; } = true;
}
}

View File

@ -8,8 +8,17 @@ namespace Microsoft.AspNetCore.Sockets
/// <summary>
/// Represents an end point that multiple connections connect to. For HTTP, endpoints are URLs, for non HTTP it can be a TCP listener (or similar)
/// </summary>
// REVIEW: This doesn't have any members any more... marker interface? Still even necessary?
public abstract class EndPoint
{
/// <summary>
/// Gets the connection mode supported by this endpoint.
/// </summary>
/// <remarks>
/// This maps directly to whichever of <see cref="MessagingEndPoint"/> or <see cref="StreamingEndPoint"/> the end point subclasses.
/// </remarks>
public abstract ConnectionMode Mode { get; }
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>

View File

@ -1,11 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public enum Format

View File

@ -1,34 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets
{
public class HttpConnection : IPipelineConnection
{
public HttpConnection(PipelineFactory factory)
{
Input = factory.Create();
Output = factory.Create();
}
IPipelineReader IPipelineConnection.Input => Input;
IPipelineWriter IPipelineConnection.Output => Output;
public PipelineReaderWriter Input { get; }
public PipelineReaderWriter Output { get; }
public void Dispose()
{
Input.CompleteReader();
Input.CompleteWriter();
Output.CompleteReader();
Output.CompleteWriter();
}
}
}

View File

@ -2,10 +2,13 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
@ -17,19 +20,24 @@ namespace Microsoft.AspNetCore.Sockets
private readonly ConnectionManager _manager;
private readonly PipelineFactory _pipelineFactory;
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
public HttpConnectionDispatcher(ConnectionManager manager, PipelineFactory factory, ILoggerFactory loggerFactory)
{
_manager = manager;
_pipelineFactory = factory;
_loggerFactory = loggerFactory;
_logger = _loggerFactory.CreateLogger<HttpConnectionDispatcher>();
}
public async Task ExecuteAsync<TEndPoint>(string path, HttpContext context) where TEndPoint : EndPoint
{
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
if (context.Request.Path.StartsWithSegments(path + "/getid"))
{
await ProcessGetId(context);
await ProcessGetId(context, endpoint.Mode);
}
else if (context.Request.Path.StartsWithSegments(path + "/send"))
{
@ -37,163 +45,176 @@ namespace Microsoft.AspNetCore.Sockets
}
else
{
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
// Server sent events transport
if (context.Request.Path.StartsWithSegments(path + "/sse"))
{
// Get the connection state for the current http context
var state = GetOrCreateConnection(context);
state.Connection.User = context.User;
state.Connection.Metadata["transport"] = "sse";
state.Connection.Metadata.Format = format;
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
var formatType = (string)context.Request.Query["formatType"];
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
var sse = new ServerSentEvents(state.Connection);
await DoPersistentConnection(endpoint, sse, context, state.Connection);
_manager.RemoveConnection(state.Connection.ConnectionId);
}
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
{
// Get the connection state for the current http context
var state = GetOrCreateConnection(context);
state.Connection.User = context.User;
state.Connection.Metadata["transport"] = "websockets";
state.Connection.Metadata.Format = format;
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
var formatType = (string)context.Request.Query["formatType"];
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
var ws = new WebSockets(state.Connection, format, _loggerFactory);
await DoPersistentConnection(endpoint, ws, context, state.Connection);
_manager.RemoveConnection(state.Connection.ConnectionId);
}
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
{
bool isNewConnection;
var state = GetOrCreateConnection(context, out isNewConnection);
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
var formatType = (string)context.Request.Query["formatType"];
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
// Mark the connection as active
state.Active = true;
RegisterLongPollingDisconnect(context, state.Connection);
var longPolling = new LongPolling(state.Connection);
// Start the transport
var transportTask = longPolling.ProcessRequestAsync(context);
Task endpointTask = null;
// Raise OnConnected for new connections only since polls happen all the time
if (isNewConnection)
{
state.Connection.Metadata["transport"] = "poll";
state.Connection.Metadata.Format = format;
state.Connection.User = context.User;
// REVIEW: This is super gross, this all needs to be cleaned up...
state.Close = async () =>
{
try
{
await endpointTask;
}
catch
{
// possibly invoked on a ThreadPool thread
}
state.Connection.Channel.Dispose();
};
endpointTask = endpoint.OnConnectedAsync(state.Connection);
state.Connection.Metadata["endpoint"] = endpointTask;
}
else
{
// Get the endpoint task from connection state
endpointTask = state.Connection.Metadata.Get<Task>("endpoint");
}
var resultTask = await Task.WhenAny(endpointTask, transportTask);
if (resultTask == endpointTask)
{
// Notify the long polling transport to end
if (endpointTask.IsFaulted)
{
state.Connection.Channel.Input.Complete(endpointTask.Exception.InnerException);
state.Connection.Channel.Output.Complete(endpointTask.Exception.InnerException);
}
state.Connection.Channel.Dispose();
await transportTask;
}
// Mark the connection as inactive
state.LastSeen = DateTimeOffset.UtcNow;
state.Active = false;
}
await ExecuteEndpointAsync(path, context, endpoint);
}
}
private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoint endpoint)
{
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var state = GetOrCreateConnection(context, endpoint.Mode);
// Adapt the connection to a message-based transport if necessary, since all the HTTP transports are message-based.
var application = GetMessagingChannel(state, format);
// Server sent events transport
if (context.Request.Path.StartsWithSegments(path + "/sse"))
{
InitializePersistentConnection(state, "sse", context, endpoint, format);
// We only need to provide the Input channel since writing to the application is handled through /send.
var sse = new ServerSentEventsTransport(application.Input, _loggerFactory);
await DoPersistentConnection(endpoint, sse, context, state);
_manager.RemoveConnection(state.Connection.ConnectionId);
}
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
{
InitializePersistentConnection(state, "websockets", context, endpoint, format);
var ws = new WebSocketsTransport(application, _loggerFactory);
await DoPersistentConnection(endpoint, ws, context, state);
_manager.RemoveConnection(state.Connection.ConnectionId);
}
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
{
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
var formatType = (string)context.Request.Query["formatType"];
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
// Mark the connection as active
state.Active = true;
var longPolling = new LongPollingTransport(application.Input, _loggerFactory);
RegisterLongPollingDisconnect(context, longPolling);
// Start the transport
var transportTask = longPolling.ProcessRequestAsync(context);
// Raise OnConnected for new connections only since polls happen all the time
var endpointTask = state.Connection.Metadata.Get<Task>("endpoint");
if (endpointTask == null)
{
_logger.LogDebug("Establishing new Long Polling connection: {0}", state.Connection.ConnectionId);
// This will re-initialize formatType metadata, but meh...
InitializePersistentConnection(state, "poll", context, endpoint, format);
// REVIEW: This is super gross, this all needs to be cleaned up...
state.Close = async () =>
{
try
{
await endpointTask;
}
catch
{
// possibly invoked on a ThreadPool thread
}
state.Connection.Dispose();
};
endpointTask = endpoint.OnConnectedAsync(state.Connection);
state.Connection.Metadata["endpoint"] = endpointTask;
}
else
{
_logger.LogDebug("Resuming existing Long Polling connection: {0}", state.Connection.ConnectionId);
}
var resultTask = await Task.WhenAny(endpointTask, transportTask);
if (resultTask == endpointTask)
{
// Notify the long polling transport to end
if (endpointTask.IsFaulted)
{
state.TerminateTransport(endpointTask.Exception.InnerException);
}
state.Connection.Dispose();
await transportTask;
}
// Mark the connection as inactive
state.LastSeenUtc = DateTime.UtcNow;
state.Active = false;
}
}
private static IChannelConnection<Message> GetMessagingChannel(ConnectionState state, Format format)
{
if (state.Connection.Mode == ConnectionMode.Messaging)
{
return ((MessagingConnectionState)state).Application;
}
else
{
// We need to build an adapter
return new FramingChannel(((StreamingConnectionState)state).Application, format);
}
}
private ConnectionState InitializePersistentConnection(ConnectionState state, string transport, HttpContext context, EndPoint endpoint, Format format)
{
state.Connection.User = context.User;
state.Connection.Metadata["transport"] = transport;
state.Connection.Metadata.Format = format;
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
var formatType = (string)context.Request.Query["formatType"];
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
return state;
}
private static async Task DoPersistentConnection(EndPoint endpoint,
IHttpTransport transport,
HttpContext context,
Connection connection)
ConnectionState state)
{
// Register this transport for disconnect
RegisterDisconnect(context, connection);
RegisterDisconnect(context, state);
// Start the transport
var transportTask = transport.ProcessRequestAsync(context);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnectedAsync(connection);
var endpointTask = endpoint.OnConnectedAsync(state.Connection);
// Wait for any of them to end
await Task.WhenAny(endpointTask, transportTask);
// Kill the channel
connection.Channel.Dispose();
state.Dispose();
// Wait for both
await Task.WhenAll(endpointTask, transportTask);
}
private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)
private static void RegisterLongPollingDisconnect(HttpContext context, LongPollingTransport transport)
{
// For long polling, we need to end the transport but not the overall connection so we write 0 bytes
context.RequestAborted.Register(state => ((HttpConnection)state).Output.WriteAsync(Span<byte>.Empty), connection.Channel);
context.RequestAborted.Register(state => ((LongPollingTransport)state).Cancel(), transport);
}
private static void RegisterDisconnect(HttpContext context, Connection connection)
private static void RegisterDisconnect(HttpContext context, ConnectionState connectionState)
{
// We just kill the output writing as a signal to the transport that it is done
context.RequestAborted.Register(state => ((HttpConnection)state).Output.CompleteWriter(), connection.Channel);
context.RequestAborted.Register(state => ((ConnectionState)state).Dispose(), connectionState);
}
private Task ProcessGetId(HttpContext context)
private Task ProcessGetId(HttpContext context, ConnectionMode mode)
{
// Reserve an id for this connection
var state = _manager.ReserveConnection();
// Establish the connection
var state = _manager.CreateConnection(mode);
// Get the bytes for the connection id
var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId);
@ -203,7 +224,7 @@ namespace Microsoft.AspNetCore.Sockets
return context.Response.Body.WriteAsync(connectionIdBuffer, 0, connectionIdBuffer.Length);
}
private Task ProcessSend(HttpContext context)
private async Task ProcessSend(HttpContext context)
{
var connectionId = context.Request.Query["id"];
if (StringValues.IsNullOrEmpty(connectionId))
@ -214,58 +235,54 @@ namespace Microsoft.AspNetCore.Sockets
ConnectionState state;
if (_manager.TryGetConnection(connectionId, out state))
{
// If we received an HTTP POST for the connection id and it's not an HttpChannel then fail.
// You can't write to a TCP channel directly from here.
var httpChannel = state.Connection.Channel as HttpConnection;
if (httpChannel == null)
if (state.Connection.Mode == ConnectionMode.Streaming)
{
throw new InvalidOperationException("No channel");
var streamingState = (StreamingConnectionState)state;
await context.Request.Body.CopyToAsync(streamingState.Application.Output);
}
else
{
// Collect the message and write it to the channel
// TODO: Need to use some kind of pooled memory here.
byte[] buffer;
using (var strm = new MemoryStream())
{
await context.Request.Body.CopyToAsync(strm);
await strm.FlushAsync();
buffer = strm.ToArray();
}
return context.Request.Body.CopyToAsync(httpChannel.Input);
}
throw new InvalidOperationException("Unknown connection id");
}
private ConnectionState GetOrCreateConnection(HttpContext context)
{
bool isNewConnection;
return GetOrCreateConnection(context, out isNewConnection);
}
private ConnectionState GetOrCreateConnection(HttpContext context, out bool isNewConnection)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
isNewConnection = false;
// There's no connection id so this is a branch new connection
if (StringValues.IsNullOrEmpty(connectionId))
{
isNewConnection = true;
var channel = new HttpConnection(_pipelineFactory);
connectionState = _manager.AddNewConnection(channel);
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var message = new Message(
ReadableBuffer.Create(buffer).Preserve(),
format,
endOfMessage: true);
await ((MessagingConnectionState)state).Application.Output.WriteAsync(message);
}
}
else
{
// REVIEW: Fail if not reserved? Reused an existing connection id?
throw new InvalidOperationException("Unknown connection id");
}
}
// There's a connection id
if (!_manager.TryGetConnection(connectionId, out connectionState))
{
throw new InvalidOperationException("Unknown connection id");
}
private ConnectionState GetOrCreateConnection(HttpContext context, ConnectionMode mode)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
// Reserved connection, we need to provide a channel
if (connectionState.Connection.Channel == null)
{
isNewConnection = true;
connectionState.Connection.Channel = new HttpConnection(_pipelineFactory);
connectionState.Active = true;
connectionState.LastSeen = DateTimeOffset.UtcNow;
}
// There's no connection id so this is a brand new connection
if (StringValues.IsNullOrEmpty(connectionId))
{
connectionState = _manager.CreateConnection(mode);
}
else if (!_manager.TryGetConnection(connectionId, out connectionState))
{
throw new InvalidOperationException("Unknown connection id");
}
return connectionState;

View File

@ -0,0 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks.Channels;
namespace Microsoft.AspNetCore.Sockets
{
// REVIEW: These should probably move to Channels. Why not use IChannel? Because I think it's better to be clear that this is providing
// access to two separate channels, the read end for one and the write end for the other.
public interface IChannelConnection<T> : IDisposable
{
IReadableChannel<T> Input { get; }
IWritableChannel<T> Output { get; }
}
}

View File

@ -0,0 +1,30 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class ChannelConnection<T> : IChannelConnection<T>
{
public IReadableChannel<T> Input { get; }
public IWritableChannel<T> Output { get; }
public ChannelConnection(IReadableChannel<T> input, IWritableChannel<T> output)
{
Input = input;
Output = output;
}
public void Dispose()
{
Output.Complete();
(Input as IDisposable)?.Dispose();
(Output as IDisposable)?.Dispose();
}
}
}

View File

@ -0,0 +1,28 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public abstract class ConnectionState : IDisposable
{
public Connection Connection { get; set; }
public ConnectionMode Mode => Connection.Mode;
// These are used for long polling mostly
public Action Close { get; set; }
public DateTime LastSeenUtc { get; set; }
public bool Active { get; set; } = true;
protected ConnectionState(Connection connection)
{
Connection = connection;
LastSeenUtc = DateTime.UtcNow;
}
public abstract void Dispose();
public abstract void TerminateTransport(Exception innerException);
}
}

View File

@ -0,0 +1,119 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
namespace Microsoft.AspNetCore.Sockets.Internal
{
/// <summary>
/// Creates a <see cref="IChannelConnection{Message}"/> out of a <see cref="IPipelineConnection"/> by framing data
/// read out of the Pipeline and flattening out frames to write them to the Pipeline when received.
/// </summary>
public class FramingChannel : IChannelConnection<Message>, IReadableChannel<Message>, IWritableChannel<Message>
{
private readonly IPipelineConnection _connection;
private readonly TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>();
private readonly Format _format;
Task IReadableChannel<Message>.Completion => _tcs.Task;
public IReadableChannel<Message> Input => this;
public IWritableChannel<Message> Output => this;
public FramingChannel(IPipelineConnection connection, Format format)
{
_connection = connection;
_format = format;
}
ValueTask<Message> IReadableChannel<Message>.ReadAsync(CancellationToken cancellationToken)
{
var awaiter = _connection.Input.ReadAsync();
if (awaiter.IsCompleted)
{
return new ValueTask<Message>(ReadSync(awaiter.GetResult(), cancellationToken));
}
else
{
return new ValueTask<Message>(AwaitReadAsync(awaiter, cancellationToken));
}
}
bool IReadableChannel<Message>.TryRead(out Message item)
{
// We need to think about how we do this. There's no way to check if there is data available in a Pipeline... though maybe there should be
// We could ReadAsync and check IsCompleted, but then we'd also need to stash that Awaitable for later since we can't call ReadAsync a second time...
// CancelPendingReads could help here.
item = default(Message);
return false;
}
Task<bool> IReadableChannel<Message>.WaitToReadAsync(CancellationToken cancellationToken)
{
// See above for TryRead. Same problems here.
throw new NotSupportedException();
}
Task IWritableChannel<Message>.WriteAsync(Message item, CancellationToken cancellationToken)
{
// Just dump the message on to the pipeline
var buffer = _connection.Output.Alloc();
buffer.Append(item.Payload.Buffer);
return buffer.FlushAsync();
}
Task<bool> IWritableChannel<Message>.WaitToWriteAsync(CancellationToken cancellationToken)
{
// We need to think about how we do this. We don't have a wait to synchronously check for back-pressure in the Pipeline.
throw new NotSupportedException();
}
bool IWritableChannel<Message>.TryWrite(Message item)
{
// We need to think about how we do this. We don't have a wait to synchronously check for back-pressure in the Pipeline.
return false;
}
bool IWritableChannel<Message>.TryComplete(Exception error)
{
_connection.Output.Complete(error);
return true;
}
private async Task<Message> AwaitReadAsync(ReadableBufferAwaitable awaiter, CancellationToken cancellationToken)
{
// Just await and then call ReadSync
var result = await awaiter;
return ReadSync(result, cancellationToken);
}
private Message ReadSync(ReadResult result, CancellationToken cancellationToken)
{
var buffer = result.Buffer;
// Preserve the buffer and advance the pipeline past it
var preserved = buffer.Preserve();
_connection.Input.Advance(buffer.End);
var msg = new Message(preserved, _format, endOfMessage: true);
if (result.IsCompleted)
{
// Complete the task
_tcs.TrySetResult(null);
}
return msg;
}
public void Dispose()
{
_tcs.TrySetResult(null);
_connection.Dispose();
}
}
}

View File

@ -0,0 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class MessagingConnectionState : ConnectionState
{
public new MessagingConnection Connection => (MessagingConnection)base.Connection;
public IChannelConnection<Message> Application { get; }
public MessagingConnectionState(MessagingConnection connection, IChannelConnection<Message> application) : base(connection)
{
Application = application;
}
public override void Dispose()
{
Connection.Dispose();
Application.Dispose();
}
public override void TerminateTransport(Exception innerException)
{
Connection.Transport.Output.TryComplete(innerException);
}
}
}

View File

@ -0,0 +1,25 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class PipelineConnection : IPipelineConnection
{
public IPipelineReader Input { get; }
public IPipelineWriter Output { get; }
public PipelineConnection(IPipelineReader input, IPipelineWriter output)
{
Input = input;
Output = output;
}
public void Dispose()
{
Input.Complete();
Output.Complete();
}
}
}

View File

@ -0,0 +1,31 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Internal
{
public class StreamingConnectionState : ConnectionState
{
public new StreamingConnection Connection => (StreamingConnection)base.Connection;
public IPipelineConnection Application { get; }
public StreamingConnectionState(StreamingConnection connection, IPipelineConnection application) : base(connection)
{
Application = application;
}
public override void Dispose()
{
Connection.Dispose();
Application.Dispose();
}
public override void TerminateTransport(Exception innerException)
{
Connection.Transport.Output.Complete(innerException);
Connection.Transport.Input.Complete(innerException);
}
}
}

View File

@ -1,48 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Sockets
{
public class LongPolling : IHttpTransport
{
private readonly HttpConnection _channel;
private readonly Connection _connection;
public LongPolling(Connection connection)
{
_connection = connection;
_channel = (HttpConnection)connection.Channel;
}
public async Task ProcessRequestAsync(HttpContext context)
{
var result = await _channel.Output.ReadAsync();
var buffer = result.Buffer;
if (buffer.IsEmpty && result.IsCompleted)
{
// Client should stop if it receives a 204
context.Response.StatusCode = 204;
return;
}
if (!buffer.IsEmpty)
{
try
{
context.Response.ContentLength = buffer.Length;
await buffer.CopyToAsync(context.Response.Body);
}
finally
{
_channel.Output.Advance(buffer.End);
}
}
}
}
}

View File

@ -0,0 +1,27 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets
{
public struct Message : IDisposable
{
public bool EndOfMessage { get; }
public Format MessageFormat { get; }
public PreservedBuffer Payload { get; }
public Message(PreservedBuffer payload, Format messageFormat, bool endOfMessage)
{
MessageFormat = messageFormat;
EndOfMessage = endOfMessage;
Payload = payload;
}
public void Dispose()
{
Payload.Dispose();
}
}
}

View File

@ -0,0 +1,23 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Sockets
{
public class MessagingConnection : Connection
{
public override ConnectionMode Mode => ConnectionMode.Messaging;
public IChannelConnection<Message> Transport { get; }
public MessagingConnection(string id, IChannelConnection<Message> transport) : base(id)
{
Transport = transport;
}
public override void Dispose()
{
Transport.Dispose();
}
}
}

View File

@ -0,0 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public abstract class MessagingEndPoint : EndPoint
{
public override ConnectionMode Mode => ConnectionMode.Messaging;
public override Task OnConnectedAsync(Connection connection)
{
if (connection.Mode != Mode)
{
throw new InvalidOperationException($"Connection mode does not match endpoint mode. Connection mode is '{connection.Mode}', endpoint mode is '{Mode}'");
}
return OnConnectedAsync((MessagingConnection)connection);
}
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>
/// <param name="connection">The new <see cref="MessagingConnection"/></param>
/// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns>
public abstract Task OnConnectedAsync(MessagingConnection connection);
}
}

View File

@ -1,65 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Sockets
{
public class ServerSentEvents : IHttpTransport
{
private readonly HttpConnection _channel;
private readonly Connection _connection;
public ServerSentEvents(Connection connection)
{
_connection = connection;
_channel = (HttpConnection)connection.Channel;
}
public async Task ProcessRequestAsync(HttpContext context)
{
context.Response.ContentType = "text/event-stream";
context.Response.Headers["Cache-Control"] = "no-cache";
context.Response.Headers["Content-Encoding"] = "identity";
await context.Response.Body.FlushAsync();
while (true)
{
var result = await _channel.Output.ReadAsync();
var buffer = result.Buffer;
if (buffer.IsEmpty && result.IsCompleted)
{
break;
}
await Send(context, buffer);
_channel.Output.Advance(buffer.End);
}
}
private async Task Send(HttpContext context, ReadableBuffer data)
{
// TODO: Pooled buffers
// 8 = 6(data: ) + 2 (\n\n)
var buffer = new byte[8 + data.Length];
var at = 0;
buffer[at++] = (byte)'d';
buffer[at++] = (byte)'a';
buffer[at++] = (byte)'t';
buffer[at++] = (byte)'a';
buffer[at++] = (byte)':';
buffer[at++] = (byte)' ';
data.CopyTo(new Span<byte>(buffer, at, data.Length));
at += data.Length;
buffer[at++] = (byte)'\n';
buffer[at++] = (byte)'\n';
await context.Response.Body.WriteAsync(buffer, 0, at);
await context.Response.Body.FlushAsync();
}
}
}

View File

@ -0,0 +1,24 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets
{
public class StreamingConnection : Connection
{
public override ConnectionMode Mode => ConnectionMode.Streaming;
public IPipelineConnection Transport { get; set; }
public StreamingConnection(string id, IPipelineConnection transport) : base(id)
{
Transport = transport;
}
public override void Dispose()
{
Transport.Dispose();
}
}
}

View File

@ -0,0 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public abstract class StreamingEndPoint : EndPoint
{
public override ConnectionMode Mode => ConnectionMode.Streaming;
public override Task OnConnectedAsync(Connection connection)
{
if(connection.Mode != Mode)
{
throw new InvalidOperationException($"Connection mode does not match endpoint mode. Connection mode is '{connection.Mode}', endpoint mode is '{Mode}'");
}
return OnConnectedAsync((StreamingConnection)connection);
}
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>
/// <param name="connection">The new <see cref="StreamingConnection"/></param>
/// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns>
public abstract Task OnConnectedAsync(StreamingConnection connection);
}
}

View File

@ -1,13 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Sockets
namespace Microsoft.AspNetCore.Sockets.Transports
{
public interface IHttpTransport
{

View File

@ -0,0 +1,67 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Sockets.Transports
{
public class LongPollingTransport : IHttpTransport
{
private readonly IReadableChannel<Message> _connection;
private CancellationTokenSource _cancellationSource = new CancellationTokenSource();
private readonly ILogger _logger;
public LongPollingTransport(IReadableChannel<Message> connection, ILoggerFactory loggerFactory)
{
_connection = connection;
_logger = loggerFactory.CreateLogger<LongPollingTransport>();
}
public async Task ProcessRequestAsync(HttpContext context)
{
if (_connection.Completion.IsCompleted)
{
// Client should stop if it receives a 204
_logger.LogInformation("Terminating Long Polling connection by sending 204 response.");
context.Response.StatusCode = 204;
return;
}
try
{
using (var message = await _connection.ReadAsync(_cancellationSource.Token))
{
_logger.LogDebug("Writing {0} byte message to response", message.Payload.Buffer.Length);
context.Response.ContentLength = message.Payload.Buffer.Length;
await message.Payload.Buffer.CopyToAsync(context.Response.Body);
}
}
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// The Channel was closed, while we were waiting to read. That's fine, just means we're done.
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
}
catch (OperationCanceledException)
{
// Suppress the exception
_logger.LogDebug("Client disconnected from Long Polling endpoint.");
}
catch (Exception ex)
{
_logger.LogError("Error reading next message from Application: {0}", ex);
throw;
}
}
public void Cancel()
{
_cancellationSource.Cancel();
}
}
}

View File

@ -0,0 +1,68 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Sockets.Transports
{
public class ServerSentEventsTransport : IHttpTransport
{
private readonly IReadableChannel<Message> _application;
private readonly ILogger _logger;
public ServerSentEventsTransport(IReadableChannel<Message> application, ILoggerFactory loggerFactory)
{
_application = application;
_logger = loggerFactory.CreateLogger<ServerSentEventsTransport>();
}
public async Task ProcessRequestAsync(HttpContext context)
{
context.Response.ContentType = "text/event-stream";
context.Response.Headers["Cache-Control"] = "no-cache";
context.Response.Headers["Content-Encoding"] = "identity";
await context.Response.Body.FlushAsync();
try
{
while (true)
{
using (var message = await _application.ReadAsync())
{
await Send(context, message);
}
}
}
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
}
}
private async Task Send(HttpContext context, Message message)
{
// TODO: Pooled buffers
// 8 = 6(data: ) + 2 (\n\n)
_logger.LogDebug("Sending {0} byte message to Server-Sent Events client", message.Payload.Buffer.Length);
var buffer = new byte[8 + message.Payload.Buffer.Length];
var at = 0;
buffer[at++] = (byte)'d';
buffer[at++] = (byte)'a';
buffer[at++] = (byte)'t';
buffer[at++] = (byte)'a';
buffer[at++] = (byte)':';
buffer[at++] = (byte)' ';
message.Payload.Buffer.CopyTo(new Span<byte>(buffer, at, message.Payload.Buffer.Length));
at += message.Payload.Buffer.Length;
buffer[at++] = (byte)'\n';
buffer[at++] = (byte)'\n';
await context.Response.Body.WriteAsync(buffer, 0, at);
await context.Response.Body.FlushAsync();
}
}
}

View File

@ -3,26 +3,28 @@
using System;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.WebSockets.Internal;
namespace Microsoft.AspNetCore.Sockets
namespace Microsoft.AspNetCore.Sockets.Transports
{
public class WebSockets : IHttpTransport
public class WebSocketsTransport : IHttpTransport
{
private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5);
private static readonly WebSocketAcceptContext EmptyContext = new WebSocketAcceptContext();
private readonly HttpConnection _channel;
private readonly WebSocketOpcode _opcode;
private readonly ILogger _logger;
private WebSocketOpcode _lastOpcode = WebSocketOpcode.Continuation;
private bool _lastFrameIncomplete = false;
public WebSockets(Connection connection, Format format, ILoggerFactory loggerFactory)
private readonly ILogger _logger;
private readonly IChannelConnection<Message> _connection;
public WebSocketsTransport(IChannelConnection<Message> connection, ILoggerFactory loggerFactory)
{
if (connection == null)
{
@ -33,9 +35,8 @@ namespace Microsoft.AspNetCore.Sockets
throw new ArgumentNullException(nameof(loggerFactory));
}
_channel = (HttpConnection)connection.Channel;
_opcode = format == Format.Binary ? WebSocketOpcode.Binary : WebSocketOpcode.Text;
_logger = loggerFactory.CreateLogger<WebSockets>();
_connection = connection;
_logger = loggerFactory.CreateLogger<WebSocketsTransport>();
}
public async Task ProcessRequestAsync(HttpContext context)
@ -59,7 +60,7 @@ namespace Microsoft.AspNetCore.Sockets
public async Task ProcessSocketAsync(IWebSocketConnection socket)
{
// Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync.
var receiving = socket.ExecuteAsync((frame, state) => ((WebSockets)state).HandleFrame(frame), this);
var receiving = socket.ExecuteAsync((frame, state) => ((WebSocketsTransport)state).HandleFrame(frame), this);
var sending = StartSending(socket);
// Wait for something to shut down.
@ -83,7 +84,7 @@ namespace Microsoft.AspNetCore.Sockets
// Shutting down because we received a close frame from the client.
// Complete the input writer so that the application knows there won't be any more input.
_logger.LogDebug("Client closed connection with status code '{0}' ({1}). Signaling end-of-input to application", receiving.Result.Status, receiving.Result.Description);
_channel.Input.CompleteWriter();
_connection.Output.TryComplete();
// Wait for the application to finish sending.
_logger.LogDebug("Waiting for the application to finish sending data");
@ -100,12 +101,15 @@ namespace Microsoft.AspNetCore.Sockets
_logger.LogDebug(!failed ? "Application finished sending. Sending close frame." : "Application failed during sending. Sending InternalServerError close frame");
await socket.CloseAsync(!failed ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.InternalServerError);
// Now trigger the exception from the application, if there was one.
sending.GetAwaiter().GetResult();
_logger.LogDebug("Waiting for the client to close the socket");
// Wait for the client to close.
// TODO: Consider timing out here and cancelling the receive loop.
await receiving;
_channel.Input.CompleteWriter();
_connection.Output.TryComplete();
}
}
@ -119,14 +123,22 @@ namespace Microsoft.AspNetCore.Sockets
LogFrame("Receiving", frame);
// Allocate space from the input channel
var outputBuffer = _channel.Input.Alloc();
// Determine the effective opcode based on the continuation.
var effectiveOpcode = frame.Opcode;
if (frame.Opcode == WebSocketOpcode.Continuation)
{
effectiveOpcode = _lastOpcode;
}
else
{
_lastOpcode = frame.Opcode;
}
// Append this buffer to the input channel
_logger.LogDebug($"Appending {frame.Payload.Length} bytes to Connection channel");
outputBuffer.Append(frame.Payload);
// Create a Message for the frame
var message = new Message(frame.Payload.Preserve(), effectiveOpcode == WebSocketOpcode.Binary ? Format.Binary : Format.Text, frame.EndOfMessage);
return outputBuffer.FlushAsync();
// Write the message to the channel
return _connection.Output.WriteAsync(message);
}
private void LogFrame(string action, WebSocketFrame frame)
@ -140,43 +152,43 @@ namespace Microsoft.AspNetCore.Sockets
private async Task StartSending(IWebSocketConnection ws)
{
try
while (!_connection.Input.Completion.IsCompleted)
{
while (true)
// Get a frame from the application
try
{
var result = await _channel.Output.ReadAsync();
var buffer = result.Buffer;
try
using (var message = await _connection.Input.ReadAsync())
{
if (buffer.IsEmpty && result.IsCompleted)
if (message.Payload.Buffer.Length > 0)
{
break;
}
try
{
var opcode = message.MessageFormat == Format.Binary ?
WebSocketOpcode.Binary :
WebSocketOpcode.Text;
// Send the buffer in a frame
var frame = new WebSocketFrame(
endOfMessage: true,
opcode: _opcode,
payload: buffer);
LogFrame("Sending", frame);
await ws.SendAsync(frame);
}
catch (Exception ex)
{
_logger.LogError("Error writing frame to output: {0}", ex);
break;
}
finally
{
_channel.Output.Advance(buffer.End);
var frame = new WebSocketFrame(
endOfMessage: message.EndOfMessage,
opcode: _lastFrameIncomplete ? WebSocketOpcode.Continuation : opcode,
payload: message.Payload.Buffer);
_lastFrameIncomplete = !message.EndOfMessage;
LogFrame("Sending", frame);
await ws.SendAsync(frame);
}
catch (Exception ex)
{
_logger.LogError("Error writing frame to output: {0}", ex);
break;
}
}
}
}
}
finally
{
// No longer reading from the channel
_channel.Output.CompleteReader();
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
}
}
}
}

View File

@ -21,6 +21,11 @@
},
"dependencies": {
"System.IO.Pipelines": "0.1.0-*",
"System.Threading.Tasks.Channels": "0.1.0-*",
"System.Security.Claims": "4.4.0-*",
"System.Reflection.TypeExtensions": "4.4.0-*",
"Microsoft.AspNetCore.Hosting.Abstractions": "1.2.0-*",
"Microsoft.AspNetCore.Routing": "1.2.0-*",
"Microsoft.AspNetCore.WebSockets.Internal": "0.1.0-*",
@ -31,7 +36,6 @@
"NETStandard.Library": "1.6.2-*"
},
"frameworks": {
"netstandard1.3": {},
"net451": {}
"netstandard1.3": {}
}
}
}

View File

@ -28,7 +28,6 @@
"NETStandard.Library": "1.6.2-*"
},
"frameworks": {
"net451": {},
"netstandard1.3": {}
}
}

View File

@ -30,7 +30,6 @@
"NETStandard.Library": "1.6.2-*"
},
"frameworks": {
"netstandard1.3": {},
"net451": {}
"netstandard1.3": {}
}
}

View File

@ -17,6 +17,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
public class HubConnectionTests : IDisposable
{
private readonly TestServer _testServer;
private static readonly bool _verbose = string.Equals(Environment.GetEnvironmentVariable("SIGNALR_TEST_VERBOSE"), "1");
public HubConnectionTests()
{
@ -25,6 +26,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
services.AddSignalR();
})
.ConfigureLogging(loggerFactory =>
{
if (_verbose)
{
loggerFactory.AddConsole();
}
})
.Configure(app =>
{
app.UseSignalR(routes =>
@ -38,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
[Fact]
public async Task CheckFixedMessage()
{
var loggerFactory = new LoggerFactory();
var loggerFactory = CreateLogger();
using (var httpClient = _testServer.CreateClient())
using (var pipelineFactory = new PipelineFactory())
@ -48,6 +56,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
//TODO: Get rid of this. This is to prevent "No channel" failures due to sends occuring before the first poll.
await Task.Delay(500);
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("HelloWorld");
Assert.Equal("Hello World!", result);
@ -58,7 +68,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
[Fact]
public async Task CanSendAndReceiveMessage()
{
var loggerFactory = new LoggerFactory();
var loggerFactory = CreateLogger();
const string originalMessage = "SignalR";
using (var httpClient = _testServer.CreateClient())
@ -69,6 +79,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
//TODO: Get rid of this. This is to prevent "No channel" failures due to sends occuring before the first poll.
await Task.Delay(500);
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("Echo", originalMessage);
Assert.Equal(originalMessage, result);
@ -79,7 +91,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
[Fact]
public async Task CanInvokeClientMethodFromServer()
{
var loggerFactory = new LoggerFactory();
var loggerFactory = CreateLogger();
const string originalMessage = "SignalR";
using (var httpClient = _testServer.CreateClient())
@ -96,6 +108,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
//TODO: Get rid of this. This is to prevent "No channel" failures due to sends occuring before the first poll.
await Task.Delay(500);
EnsureConnectionEstablished(connection);
await connection.Invoke<Task>("CallEcho", originalMessage);
var completed = await Task.WhenAny(Task.Delay(2000), tcs.Task);
Assert.True(completed == tcs.Task, "Receive timed out!");
@ -107,7 +121,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
[Fact]
public async Task ServerClosesConnectionIfHubMethodCannotBeResolved()
{
var loggerFactory = new LoggerFactory();
var loggerFactory = CreateLogger();
using (var httpClient = _testServer.CreateClient())
using (var pipelineFactory = new PipelineFactory())
@ -118,6 +132,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
//TODO: Get rid of this. This is to prevent "No channel" failures due to sends occuring before the first poll.
await Task.Delay(500);
EnsureConnectionEstablished(connection);
var ex = await Assert.ThrowsAnyAsync<InvalidOperationException>(
async () => await connection.Invoke<Task>("!@#$%"));
@ -126,11 +142,27 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
private static void EnsureConnectionEstablished(HubConnection connection)
{
if (connection.Completion.IsCompleted)
{
connection.Completion.GetAwaiter().GetResult();
}
}
public void Dispose()
{
_testServer.Dispose();
}
private static LoggerFactory CreateLogger()
{
var loggerFactory = new LoggerFactory();
loggerFactory.AddConsole(_verbose ? LogLevel.Trace : LogLevel.Error);
return loggerFactory;
}
public class TestHub : Hub
{
public string HelloWorld()

View File

@ -1,18 +1,20 @@
{
{
"buildOptions": {
"warningsAsErrors": true
},
"dependencies": {
"dotnet-test-xunit": "2.2.0-*",
"Microsoft.AspNetCore.Http": "1.2.0-*",
"Microsoft.AspNetCore.Sockets": "0.1.0-*",
"Microsoft.AspNetCore.Diagnostics": "1.2.0-*",
"Microsoft.AspNetCore.Hosting": "1.2.0-*",
"Microsoft.AspNetCore.Http": "1.2.0-*",
"Microsoft.AspNetCore.Server.Kestrel": "1.2.0-*",
"Microsoft.AspNetCore.SignalR": "1.0.0-*",
"Microsoft.AspNetCore.SignalR.Client": "1.0.0-*",
"xunit": "2.2.0-*",
"Microsoft.AspNetCore.TestHost": "1.2.0-*"
"Microsoft.AspNetCore.Sockets": "0.1.0-*",
"Microsoft.AspNetCore.TestHost": "1.2.0-*",
"Microsoft.Extensions.Logging.Console": "1.2.0-*",
"xunit": "2.2.0-*"
},
"frameworks": {
"netcoreapp1.1": {
@ -23,7 +25,7 @@
}
}
},
"net451": {}
"net46": {}
},
"testRunner": "xunit"
}

View File

@ -8,11 +8,11 @@ using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR.Test.Server
{
public class EchoEndPoint : EndPoint
public class EchoEndPoint : StreamingEndPoint
{
public async override Task OnConnectedAsync(Connection connection)
public async override Task OnConnectedAsync(StreamingConnection connection)
{
await connection.Channel.Input.CopyToAsync(connection.Channel.Output);
await connection.Transport.Input.CopyToAsync(connection.Transport.Output);
}
}
}

View File

@ -1,9 +1,5 @@
{
"dependencies": {
"Microsoft.NETCore.App": {
"version": "1.2.0-*",
"type": "platform"
},
"Microsoft.AspNetCore.Diagnostics": "1.2.0-*",
"Microsoft.AspNetCore.Server.IISIntegration": "1.2.0-*",
"Microsoft.AspNetCore.Server.Kestrel": "1.2.0-*",
@ -15,7 +11,15 @@
"Microsoft.AspNetCore.Server.IISIntegration.Tools": "1.0.0-preview2-final"
},
"frameworks": {
"netcoreapp1.1": {}
"netcoreapp1.1": {
"dependencies": {
"Microsoft.NETCore.App": {
"version": "1.2.0-*",
"type": "platform"
}
}
},
"net46": { }
},
"buildOptions": {
"emitEntryPoint": true,
@ -33,9 +37,11 @@
]
},
"scripts": {
"precompile": [ "npm install",
"precompile": [
"npm install",
"npm run gulp -- --gulpfile %project:Directory%/gulpfile.js copy-jasmine",
"npm run gulp -- --gulpfile %project:Directory%/../../src/Microsoft.AspNetCore.SignalR.Client.TS/gulpfile.js bundle-client --bundleOutDir %project:Directory%/wwwroot/lib/signalr-client/" ],
"npm run gulp -- --gulpfile %project:Directory%/../../src/Microsoft.AspNetCore.SignalR.Client.TS/gulpfile.js bundle-client --bundleOutDir %project:Directory%/wwwroot/lib/signalr-client/"
],
"postpublish": [ "dotnet publish-iis --publish-folder %publish:OutputPath% --framework %publish:FullTargetFramework%" ]
}
}
}

View File

@ -7,6 +7,7 @@ using System.Security.Claims;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Moq;
using Xunit;
@ -26,10 +27,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
await connectionWrapper.HttpConnection.Input.ReadingStarted;
await connectionWrapper.ApplicationStartedReading;
// kill the connection
connectionWrapper.Connection.Channel.Dispose();
connectionWrapper.ConnectionState.Dispose();
await endPointTask;
@ -55,13 +56,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
await connectionWrapper.HttpConnection.Input.ReadingStarted;
await connectionWrapper.ApplicationStartedReading;
var buffer = connectionWrapper.HttpConnection.Input.Alloc();
var buffer = connectionWrapper.ConnectionState.Application.Output.Alloc();
buffer.Write(Encoding.UTF8.GetBytes("0xdeadbeef"));
await buffer.FlushAsync();
connectionWrapper.Connection.Channel.Dispose();
connectionWrapper.Dispose();
// InvalidCastException because the payload is not a JObject
// which is expected by the formatter
@ -76,7 +77,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>();
mockLifetimeManager
.Setup(m => m.OnConnectedAsync(It.IsAny<Connection>()))
.Setup(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()))
.Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed."));
var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>();
@ -95,10 +96,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
async () => await endPoint.OnConnectedAsync(connectionWrapper.Connection));
Assert.Equal("Lifetime manager OnConnectedAsync failed.", exception.Message);
connectionWrapper.Connection.Channel.Dispose();
connectionWrapper.ConnectionState.Dispose();
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
// No hubs should be created since the connection is terminated
mockHubActivator.Verify(m => m.Create(), Times.Never);
mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never);
@ -119,13 +120,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var connectionWrapper = new ConnectionWrapper())
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
connectionWrapper.Connection.Channel.Dispose();
connectionWrapper.ConnectionState.Dispose();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnConnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
}
}
@ -143,13 +144,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var connectionWrapper = new ConnectionWrapper())
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
connectionWrapper.Connection.Channel.Dispose();
connectionWrapper.ConnectionState.Dispose();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask);
Assert.Equal("Hub OnDisconnected failed.", exception.Message);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once);
}
}
@ -222,27 +223,28 @@ namespace Microsoft.AspNetCore.SignalR.Tests
private class ConnectionWrapper : IDisposable
{
private PipelineFactory _factory;
private HttpConnection _httpConnection;
public Connection Connection;
public HttpConnection HttpConnection => (HttpConnection)Connection.Channel;
public StreamingConnectionState ConnectionState;
public StreamingConnection Connection => ConnectionState.Connection;
// Still kinda gross...
public Task ApplicationStartedReading => ((PipelineReaderWriter)Connection.Transport.Input).ReadingStarted;
public ConnectionWrapper(string format = "json")
{
_factory = new PipelineFactory();
_httpConnection = new HttpConnection(_factory);
var connectionManager = new ConnectionManager();
var connectionManager = new ConnectionManager(_factory);
Connection = connectionManager.AddNewConnection(_httpConnection).Connection;
Connection.Metadata["formatType"] = format;
Connection.User = new ClaimsPrincipal(new ClaimsIdentity());
ConnectionState = (StreamingConnectionState)connectionManager.CreateConnection(ConnectionMode.Streaming);
ConnectionState.Connection.Metadata["formatType"] = format;
ConnectionState.Connection.User = new ClaimsPrincipal(new ClaimsIdentity());
}
public void Dispose()
{
Connection.Channel.Dispose();
_httpConnection.Dispose();
ConnectionState.Dispose();
_factory.Dispose();
}
}

View File

@ -22,7 +22,11 @@
}
}
},
"net451": { }
"net46": {
"dependencies": {
"System.Security.Claims": "4.0.1"
}
}
},
"testRunner": "xunit"
}

View File

@ -1,4 +1,7 @@
using System.Reflection;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

View File

@ -16,7 +16,7 @@
}
}
},
"net451": {}
"net46": {}
},
"testRunner": "xunit"
}

View File

@ -3,6 +3,7 @@
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Internal;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
@ -10,49 +11,56 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public class ConnectionManagerTests
{
[Fact]
public void ReservedConnectionsHaveConnectionId()
public void NewConnectionsHaveConnectionId()
{
var connectionManager = new ConnectionManager();
var state = connectionManager.ReserveConnection();
using (var factory = new PipelineFactory())
{
var connectionManager = new ConnectionManager(factory);
var state = connectionManager.CreateConnection(ConnectionMode.Streaming);
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.True(state.Active);
Assert.Null(state.Close);
Assert.Null(state.Connection.Channel);
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.True(state.Active);
Assert.Null(state.Close);
Assert.NotNull(((StreamingConnectionState)state).Connection.Transport);
}
}
[Fact]
public void ReservedConnectionsCanBeRetrieved()
public void NewConnectionsCanBeRetrieved()
{
var connectionManager = new ConnectionManager();
var state = connectionManager.ReserveConnection();
using (var factory = new PipelineFactory())
{
var connectionManager = new ConnectionManager(factory);
var state = connectionManager.CreateConnection(ConnectionMode.Streaming);
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
}
}
[Fact]
public void AddNewConnection()
{
using (var factory = new PipelineFactory())
using (var connection = new HttpConnection(factory))
{
var connectionManager = new ConnectionManager();
var state = connectionManager.AddNewConnection(connection);
var connectionManager = new ConnectionManager(factory);
var state = connectionManager.CreateConnection(ConnectionMode.Streaming);
var transport = ((StreamingConnectionState)state).Connection.Transport;
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(state.Connection.Channel);
Assert.NotNull(transport);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.Same(connection, newState.Connection.Channel);
Assert.Same(transport, ((StreamingConnectionState)newState).Connection.Transport);
}
}
@ -60,19 +68,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public void RemoveConnection()
{
using (var factory = new PipelineFactory())
using (var connection = new HttpConnection(factory))
{
var connectionManager = new ConnectionManager();
var state = connectionManager.AddNewConnection(connection);
var connectionManager = new ConnectionManager(factory);
var state = connectionManager.CreateConnection(ConnectionMode.Streaming);
var transport = ((StreamingConnectionState)state).Connection.Transport;
Assert.NotNull(state.Connection);
Assert.NotNull(state.Connection.ConnectionId);
Assert.NotNull(state.Connection.Channel);
Assert.NotNull(transport);
ConnectionState newState;
Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
Assert.Same(newState, state);
Assert.Same(connection, newState.Connection.Channel);
Assert.Same(transport, ((StreamingConnectionState)newState).Connection.Transport);
connectionManager.RemoveConnection(state.Connection.ConnectionId);
Assert.False(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState));
@ -83,14 +92,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public async Task CloseConnectionsEndsAllPendingConnections()
{
using (var factory = new PipelineFactory())
using (var connection = new HttpConnection(factory))
{
var connectionManager = new ConnectionManager();
var state = connectionManager.AddNewConnection(connection);
var connectionManager = new ConnectionManager(factory);
var state = (StreamingConnectionState)connectionManager.CreateConnection(ConnectionMode.Streaming);
var task = Task.Run(async () =>
{
var result = await connection.Input.ReadAsync();
var result = await state.Connection.Transport.Input.ReadAsync();
Assert.True(result.IsCompleted);
});

View File

@ -5,13 +5,12 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting.Internal;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using Xunit;
@ -23,11 +22,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task GetIdReservesConnectionIdAndReturnsIt()
{
var manager = new ConnectionManager();
using (var factory = new PipelineFactory())
{
var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null);
var manager = new ConnectionManager(factory);
var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
var ms = new MemoryStream();
context.Request.Path = "/getid";
context.Response.Body = ms;
@ -41,36 +43,40 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
[Fact]
public async Task SendingToReservedConnectionsThatHaveNotConnectedThrows()
{
var manager = new ConnectionManager();
var state = manager.ReserveConnection();
// REVIEW: No longer relevant since we establish the connection right away.
//[Fact]
//public async Task SendingToReservedConnectionsThatHaveNotConnectedThrows()
//{
// using (var factory = new PipelineFactory())
// {
// var manager = new ConnectionManager(factory);
// var state = manager.ReserveConnection();
using (var factory = new PipelineFactory())
{
var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null);
var context = new DefaultHttpContext();
context.Request.Path = "/send";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
});
}
}
// var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null);
// var context = new DefaultHttpContext();
// context.Request.Path = "/send";
// var values = new Dictionary<string, StringValues>();
// values["id"] = state.Connection.ConnectionId;
// var qs = new QueryCollection(values);
// context.Request.Query = qs;
// await Assert.ThrowsAsync<InvalidOperationException>(async () =>
// {
// await dispatcher.ExecuteAsync<TestEndPoint>("", context);
// });
// }
//}
[Fact]
public async Task SendingToUnknownConnectionIdThrows()
{
var manager = new ConnectionManager();
using (var factory = new PipelineFactory())
{
var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null);
var manager = new ConnectionManager(factory);
var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
var values = new Dictionary<string, StringValues>();
values["id"] = "unknown";
@ -86,11 +92,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task SendingWithoutConnectionIdThrows()
{
var manager = new ConnectionManager();
using (var factory = new PipelineFactory())
{
var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null);
var manager = new ConnectionManager(factory);
var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
@ -100,9 +109,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
public class TestEndPoint : EndPoint
public class TestEndPoint : StreamingEndPoint
{
public override Task OnConnectedAsync(Connection connection)
public override Task OnConnectedAsync(StreamingConnection connection)
{
throw new NotImplementedException();
}

View File

@ -8,7 +8,10 @@ using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.Logging;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
@ -18,45 +21,37 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task Set204StatusCodeWhenChannelComplete()
{
using (var factory = new PipelineFactory())
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var channel = new HttpConnection(factory);
connection.Channel = channel;
var context = new DefaultHttpContext();
var poll = new LongPolling(connection);
var channel = Channel.Create<Message>();
var context = new DefaultHttpContext();
var poll = new LongPollingTransport(channel, new LoggerFactory());
channel.Output.CompleteWriter();
Assert.True(channel.TryComplete());
await poll.ProcessRequestAsync(context);
await poll.ProcessRequestAsync(context);
Assert.Equal(204, context.Response.StatusCode);
}
Assert.Equal(204, context.Response.StatusCode);
}
[Fact]
public async Task NoFramingAddedWhenDataSent()
public async Task FrameSentAsSingleResponse()
{
using (var factory = new PipelineFactory())
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var channel = new HttpConnection(factory);
connection.Channel = channel;
var context = new DefaultHttpContext();
var ms = new MemoryStream();
context.Response.Body = ms;
var poll = new LongPolling(connection);
var channel = Channel.Create<Message>();
var context = new DefaultHttpContext();
var poll = new LongPollingTransport(channel, new LoggerFactory());
var ms = new MemoryStream();
context.Response.Body = ms;
await channel.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello World"));
await channel.WriteAsync(new Message(
ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(),
Format.Text,
endOfMessage: true));
channel.Output.CompleteWriter();
Assert.True(channel.TryComplete());
await poll.ProcessRequestAsync(context);
await poll.ProcessRequestAsync(context);
Assert.Equal("Hello World", Encoding.UTF8.GetString(ms.ToArray()));
}
Assert.Equal(200, context.Response.StatusCode);
Assert.Equal("Hello World", Encoding.UTF8.GetString(ms.ToArray()));
}
}
}

View File

@ -8,7 +8,10 @@ using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.Logging;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
@ -18,47 +21,38 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task SSESetsContentType()
{
using (var factory = new PipelineFactory())
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var sse = new ServerSentEvents(connection);
var context = new DefaultHttpContext();
var channel = Channel.Create<Message>();
var context = new DefaultHttpContext();
var sse = new ServerSentEventsTransport(channel, new LoggerFactory());
httpConnection.Output.CompleteWriter();
Assert.True(channel.TryComplete());
await sse.ProcessRequestAsync(context);
await sse.ProcessRequestAsync(context);
Assert.Equal("text/event-stream", context.Response.ContentType);
Assert.Equal("no-cache", context.Response.Headers["Cache-Control"]);
}
Assert.Equal("text/event-stream", context.Response.ContentType);
Assert.Equal("no-cache", context.Response.Headers["Cache-Control"]);
}
[Fact]
public async Task SSEAddsAppropriateFraming()
{
using (var factory = new PipelineFactory())
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var sse = new ServerSentEvents(connection);
var context = new DefaultHttpContext();
var ms = new MemoryStream();
context.Response.Body = ms;
var channel = Channel.Create<Message>();
var context = new DefaultHttpContext();
var sse = new ServerSentEventsTransport(channel, new LoggerFactory());
var ms = new MemoryStream();
context.Response.Body = ms;
await httpConnection.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello World"));
await channel.WriteAsync(new Message(
ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(),
Format.Text,
endOfMessage: true));
httpConnection.Output.CompleteWriter();
Assert.True(channel.TryComplete());
await sse.ProcessRequestAsync(context);
await sse.ProcessRequestAsync(context);
var expected = "data: Hello World\n\n";
Assert.Equal(expected, Encoding.UTF8.GetString(ms.ToArray()));
}
var expected = "data: Hello World\n\n";
Assert.Equal(expected, Encoding.UTF8.GetString(ms.ToArray()));
}
}
}

View File

@ -5,6 +5,9 @@ using System;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.WebSockets.Internal;
using Microsoft.Extensions.WebSockets.Internal.Tests;
@ -14,17 +17,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
public class WebSocketsTests
{
[Fact]
public async Task ReceivedFramesAreWrittenToPipeline()
[Theory]
[InlineData(Format.Text, WebSocketOpcode.Text)]
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task ReceivedFramesAreWrittenToChannel(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, Format.Text, new LoggerFactory());
var ws = new WebSocketsTransport(transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
@ -35,13 +42,18 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// Send a frame, then close
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
opcode: opcode,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
// Capture everything out of the input channel and then complete the writer (to do our end of the close)
var buffer = (await connection.Channel.Input.ReadToEndAsync()).ToArray();
httpConnection.Output.CompleteWriter();
using (var message = await applicationSide.Input.ReadAsync())
{
Assert.True(message.EndOfMessage);
Assert.Equal(format, message.MessageFormat);
Assert.Equal("Hello", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray()));
}
Assert.True(applicationSide.Output.TryComplete());
// The transport should finish now
await transport;
@ -49,8 +61,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// The connection should close after this, which means the client will get a close frame.
var clientSummary = await client;
// Read from the connection pipeline
Assert.Equal("Hello", Encoding.UTF8.GetString(buffer));
Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status);
}
}
@ -58,16 +68,125 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Theory]
[InlineData(Format.Text, WebSocketOpcode.Text)]
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task DataWrittenToOutputPipelineAreSentAsFrames(Format format, WebSocketOpcode expectedOpcode)
public async Task MultiFrameMessagesArePropagatedToTheChannel(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, format, new LoggerFactory());
var ws = new WebSocketsTransport(transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Send a frame, then close
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: false,
opcode: opcode,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Continuation,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("World"))));
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
using (var message1 = await applicationSide.Input.ReadAsync())
{
Assert.False(message1.EndOfMessage);
Assert.Equal(format, message1.MessageFormat);
Assert.Equal("Hello", Encoding.UTF8.GetString(message1.Payload.Buffer.ToArray()));
}
using (var message2 = await applicationSide.Input.ReadAsync())
{
Assert.True(message2.EndOfMessage);
Assert.Equal(format, message2.MessageFormat);
Assert.Equal("World", Encoding.UTF8.GetString(message2.Payload.Buffer.ToArray()));
}
Assert.True(applicationSide.Output.TryComplete());
// The transport should finish now
await transport;
// The connection should close after this, which means the client will get a close frame.
var clientSummary = await client;
Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status);
}
}
[Theory]
[InlineData(Format.Text, WebSocketOpcode.Text)]
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task IncompleteMessagesAreWrittenAsMultiFrameWebSocketMessages(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var ws = new WebSocketsTransport(transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Write multi-frame message to the output channel, and then complete it
await applicationSide.Output.WriteAsync(new Message(
ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello")).Preserve(),
format,
endOfMessage: false));
await applicationSide.Output.WriteAsync(new Message(
ReadableBuffer.Create(Encoding.UTF8.GetBytes("World")).Preserve(),
format,
endOfMessage: true));
Assert.True(applicationSide.Output.TryComplete());
// The client should finish now, as should the server
var clientSummary = await client;
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
await transport;
Assert.Equal(2, clientSummary.Received.Count);
Assert.False(clientSummary.Received[0].EndOfMessage);
Assert.Equal(opcode, clientSummary.Received[0].Opcode);
Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Payload.ToArray()));
Assert.True(clientSummary.Received[1].EndOfMessage);
Assert.Equal(WebSocketOpcode.Continuation, clientSummary.Received[1].Opcode);
Assert.Equal("World", Encoding.UTF8.GetString(clientSummary.Received[1].Payload.ToArray()));
}
}
[Theory]
[InlineData(Format.Text, WebSocketOpcode.Text)]
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task DataWrittenToOutputPipelineAreSentAsFrames(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var ws = new WebSocketsTransport(transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
@ -76,8 +195,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Write to the output channel, and then complete it
await httpConnection.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello"));
httpConnection.Output.CompleteWriter();
await applicationSide.Output.WriteAsync(new Message(
ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello")).Preserve(),
format,
endOfMessage: true));
Assert.True(applicationSide.Output.TryComplete());
// The client should finish now, as should the server
var clientSummary = await client;
@ -86,22 +208,26 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(1, clientSummary.Received.Count);
Assert.True(clientSummary.Received[0].EndOfMessage);
Assert.Equal(expectedOpcode, clientSummary.Received[0].Opcode);
Assert.Equal(opcode, clientSummary.Received[0].Opcode);
Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Payload.ToArray()));
}
}
[Fact]
public async Task FrameReceivedAfterServerCloseSent()
[Theory]
[InlineData(Format.Text, WebSocketOpcode.Text)]
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task FrameReceivedAfterServerCloseSent(Format format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, Format.Binary, new LoggerFactory());
var ws = new WebSocketsTransport(transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
@ -110,19 +236,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Close the output and wait for the close frame
httpConnection.Output.CompleteWriter();
Assert.True(applicationSide.Output.TryComplete());
await client;
// Send another frame. Then close
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
opcode: opcode,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
// Read that frame from the input
var result = (await httpConnection.Input.ReadToEndAsync()).ToArray();
Assert.Equal("Hello", Encoding.UTF8.GetString(result));
using (var message = await applicationSide.Input.ReadAsync())
{
Assert.True(message.EndOfMessage);
Assert.Equal(format, message.MessageFormat);
Assert.Equal("Hello", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray()));
}
await transport;
}
@ -131,14 +261,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task TransportFailsWhenClientDisconnectsAbnormally()
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, Format.Binary, new LoggerFactory());
var ws = new WebSocketsTransport(transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
@ -157,14 +289,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails()
{
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, Format.Binary, new LoggerFactory());
var ws = new WebSocketsTransport(transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
@ -173,7 +307,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Fail in the app
httpConnection.Output.CompleteWriter(new InvalidOperationException());
Assert.True(applicationSide.Output.TryComplete(new InvalidOperationException()));
var clientSummary = await client;
Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.Status);

View File

@ -24,7 +24,7 @@
}
}
},
"net451": {}
"net46": {}
},
"testRunner": "xunit"
}
}

View File

@ -108,7 +108,7 @@ namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
var result = deployer.Deploy();
result.HostShutdownToken.ThrowIfCancellationRequested();
#if NET451
#if NET46
System.Net.ServicePointManager.ServerCertificateValidationCallback = (_, __, ___, ____) => true;
var client = new HttpClient();
#else

View File

@ -4,7 +4,7 @@
},
"dependencies": {
"dotnet-test-xunit": "2.2.0-*",
"Microsoft.AspNetCore.Server.IntegrationTesting": "0.2.0-*",
"Microsoft.AspNetCore.Server.IntegrationTesting": "0.3.0-*",
"Microsoft.AspNetCore.Testing": "1.2.0-*",
"Microsoft.Extensions.Logging": "1.2.0-*",
"Microsoft.Extensions.Logging.Console": "1.2.0-*",
@ -22,6 +22,6 @@
}
}
},
"net451": {}
"net46": {}
}
}
}

View File

@ -21,6 +21,6 @@
}
}
},
"net451": {}
"net46": {}
}
}
}

View File

@ -7,17 +7,21 @@
"Microsoft.AspNetCore.WebSockets.Internal": "0.1.0-*",
"Microsoft.Extensions.Configuration": "1.2.0-*",
"Microsoft.Extensions.Configuration.CommandLine": "1.2.0-*",
"Microsoft.Extensions.Logging.Console": "1.2.0-*",
"Microsoft.NETCore.App": {
"version": "1.2.0-*",
"type": "platform"
}
"Microsoft.Extensions.Logging.Console": "1.2.0-*"
},
"tools": {
"Microsoft.AspNetCore.Server.IISIntegration.Tools": "1.0.0-*"
},
"frameworks": {
"netcoreapp1.1": {}
"netcoreapp1.1": {
"dependencies": {
"Microsoft.NETCore.App": {
"version": "1.2.0-*",
"type": "platform"
}
}
},
"net46": { }
},
"buildOptions": {
"warningsAsErrors": true
@ -38,4 +42,4 @@
"dotnet publish-iis --publish-folder %publish:OutputPath% --framework %publish:FullTargetFramework%"
]
}
}
}