Initial changes to move to pipelines (#1424)

- Change the Sockets abstraction from Channel<byte[]> to pipelines.

#615
This commit is contained in:
David Fowler 2018-02-09 17:45:21 -08:00 committed by GitHub
parent f939a7cd53
commit 28439d1441
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 658 additions and 372 deletions

View File

@ -1,11 +1,10 @@
using System;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging.Abstractions;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
@ -26,12 +25,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
for (var i = 0; i < Connections; ++i)
{
var transportToApplication = Channel.CreateUnbounded<byte[]>(options);
var applicationToTransport = Channel.CreateUnbounded<byte[]>(options);
var application = ChannelConnection.Create<byte[]>(input: applicationToTransport, output: transportToApplication);
var transport = ChannelConnection.Create<byte[]>(input: transportToApplication, output: applicationToTransport);
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), transport, application);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport);
_hubLifetimeManager.OnConnectedAsync(new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance)).Wait();
}

View File

@ -60,6 +60,7 @@
<NewtonsoftJsonPackageVersion>10.0.1</NewtonsoftJsonPackageVersion>
<StackExchangeRedisStrongNamePackageVersion>1.2.4</StackExchangeRedisStrongNamePackageVersion>
<SystemBuffersPackageVersion>4.5.0-preview2-26130-01</SystemBuffersPackageVersion>
<SystemBuffersPrimitivesPackageVersion>0.1.0-preview2-180130-1</SystemBuffersPrimitivesPackageVersion>
<SystemIOPipelinesPackageVersion>0.1.0-preview2-180130-1</SystemIOPipelinesPackageVersion>
<SystemMemoryPackageVersion>4.5.0-preview2-26130-01</SystemMemoryPackageVersion>
<SystemNumericsVectorsPackageVersion>4.5.0-preview2-26130-01</SystemNumericsVectorsPackageVersion>

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
@ -10,7 +11,20 @@ namespace FunctionalTests
{
public async override Task OnConnectedAsync(ConnectionContext connection)
{
await connection.Transport.Writer.WriteAsync(await connection.Transport.Reader.ReadAsync());
var result = await connection.Transport.Input.ReadAsync();
var buffer = result.Buffer;
try
{
if (!buffer.IsEmpty)
{
await connection.Transport.Output.WriteAsync(buffer.ToArray());
}
}
finally
{
connection.Transport.Input.AdvanceTo(result.Buffer.End);
}
}
}
}

View File

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Features;
@ -50,7 +51,7 @@ namespace SocialWeather
var ms = new MemoryStream();
await formatter.WriteAsync(data, ms);
connection.Transport.Writer.TryWrite(ms.ToArray());
await connection.Transport.Output.WriteAsync(ms.ToArray());
}
}

View File

@ -34,15 +34,29 @@ namespace SocialWeather
var formatter = _formatterResolver.GetFormatter<WeatherReport>(
(string)connection.Metadata["format"]);
while (await connection.Transport.Reader.WaitToReadAsync())
while (true)
{
if (connection.Transport.Reader.TryRead(out var buffer))
var result = await connection.Transport.Input.ReadAsync();
var buffer = result.Buffer;
try
{
var stream = new MemoryStream();
await stream.WriteAsync(buffer, 0, buffer.Length);
stream.Position = 0;
var weatherReport = await formatter.ReadAsync(stream);
await _lifetimeManager.SendToAllAsync(weatherReport);
if (!buffer.IsEmpty)
{
var stream = new MemoryStream();
var data = buffer.ToArray();
await stream.WriteAsync(data, 0, data.Length);
stream.Position = 0;
var weatherReport = await formatter.ReadAsync(stream);
await _lifetimeManager.SendToAllAsync(weatherReport);
}
else if (result.IsCompleted)
{
break;
}
}
finally
{
connection.Transport.Input.AdvanceTo(buffer.End);
}
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
@ -20,14 +21,28 @@ namespace SocketsSample.EndPoints
try
{
while (await connection.Transport.Reader.WaitToReadAsync())
while (true)
{
if (connection.Transport.Reader.TryRead(out var buffer))
var result = await connection.Transport.Input.ReadAsync();
var buffer = result.Buffer;
try
{
// We can avoid the copy here but we'll deal with that later
var text = Encoding.UTF8.GetString(buffer);
text = $"{connection.ConnectionId}: {text}";
await Broadcast(Encoding.UTF8.GetBytes(text));
if (!buffer.IsEmpty)
{
// We can avoid the copy here but we'll deal with that later
var text = Encoding.UTF8.GetString(buffer.ToArray());
text = $"{connection.ConnectionId}: {text}";
await Broadcast(Encoding.UTF8.GetBytes(text));
}
else if (result.IsCompleted)
{
break;
}
}
finally
{
connection.Transport.Input.AdvanceTo(buffer.End);
}
}
}
@ -50,7 +65,7 @@ namespace SocketsSample.EndPoints
foreach (var c in Connections)
{
tasks.Add(c.Transport.Writer.WriteAsync(payload));
tasks.Add(c.Transport.Output.WriteAsync(payload));
}
return Task.WhenAll(tasks);

View File

@ -3,7 +3,7 @@
"windowsAuthentication": false,
"anonymousAuthentication": true,
"iisExpress": {
"applicationUrl": "http://localhost:57707/",
"applicationUrl": "http://localhost:59847/",
"sslPort": 0
}
},

View File

@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Buffers;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
@ -19,6 +21,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal
_dataEncoder = dataEncoder;
}
public bool ReadMessages(ReadOnlyBuffer<byte> buffer, IInvocationBinder binder, out IList<HubMessage> messages, out SequencePosition consumed, out SequencePosition examined)
{
// TODO: Fix this implementation to be incremental
consumed = buffer.End;
examined = consumed;
return ReadMessages(buffer.ToArray(), binder, out messages);
}
public bool ReadMessages(byte[] input, IInvocationBinder binder, out IList<HubMessage> messages)
{
var buffer = _dataEncoder.Decode(input);

View File

@ -2,6 +2,8 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections;
using System.IO;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Newtonsoft.Json;
@ -53,5 +55,27 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
}
return true;
}
public static bool TryParseMessage(ReadOnlyBuffer<byte> buffer, out NegotiationMessage negotiationMessage, out SequencePosition consumed, out SequencePosition examined)
{
var separator = buffer.PositionOf(TextMessageFormatter.RecordSeparator);
if (separator == null)
{
// Haven't seen the entire negotiate message so bail
consumed = buffer.Start;
examined = buffer.End;
negotiationMessage = null;
return false;
}
else
{
consumed = buffer.GetPosition(separator.Value, 1);
examined = consumed;
}
var memory = buffer.IsSingleSegment ? buffer.First : buffer.ToArray();
return TryParseMessage(memory.Span, out negotiationMessage);
}
}
}

View File

@ -11,6 +11,8 @@
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryPackageVersion)" />
<PackageReference Include="System.Buffers" Version="$(SystemBuffersPackageVersion)" />
<PackageReference Include="System.Buffers.Primitives" Version="$(SystemBuffersPrimitivesPackageVersion)" />
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="$(SystemRuntimeCompilerServicesUnsafePackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Options" Version="$(MicrosoftExtensionsOptionsPackageVersion)" />
</ItemGroup>

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net;
using System.Runtime.ExceptionServices;
using System.Security.Claims;
@ -15,6 +16,7 @@ using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Features;
@ -36,7 +38,6 @@ namespace Microsoft.AspNetCore.SignalR
private Task _writingTask = Task.CompletedTask;
private long _lastSendTimestamp = Stopwatch.GetTimestamp();
private byte[] _pingMessage;
public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory)
{
@ -59,7 +60,7 @@ namespace Microsoft.AspNetCore.SignalR
public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; }
public virtual ChannelReader<byte[]> Input => _connectionContext.Transport.Reader;
public virtual PipeReader Input => _connectionContext.Transport.Input;
public string UserIdentifier { get; private set; }
@ -125,37 +126,54 @@ namespace Microsoft.AspNetCore.SignalR
using (var cts = new CancellationTokenSource())
{
cts.CancelAfter(timeout);
while (await _connectionContext.Transport.Reader.WaitToReadAsync(cts.Token))
while (true)
{
while (_connectionContext.Transport.Reader.TryRead(out var buffer))
var result = await _connectionContext.Transport.Input.ReadAsync(cts.Token);
var buffer = result.Buffer;
var consumed = buffer.End;
var examined = buffer.End;
try
{
if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage))
if (!buffer.IsEmpty)
{
var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, this);
if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage, out consumed, out examined))
{
var protocol = protocolResolver.GetProtocol(negotiationMessage.Protocol, this);
var transportCapabilities = Features.Get<IConnectionTransportFeature>()?.TransportCapabilities
?? throw new InvalidOperationException("Unable to read transport capabilities.");
var transportCapabilities = Features.Get<IConnectionTransportFeature>()?.TransportCapabilities
?? throw new InvalidOperationException("Unable to read transport capabilities.");
var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0)
? (IDataEncoder)Base64Encoder
: PassThroughEncoder;
var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0)
? (IDataEncoder)Base64Encoder
: PassThroughEncoder;
var transferModeFeature = Features.Get<ITransferModeFeature>() ??
throw new InvalidOperationException("Unable to read transfer mode.");
var transferModeFeature = Features.Get<ITransferModeFeature>() ??
throw new InvalidOperationException("Unable to read transfer mode.");
transferModeFeature.TransferMode =
(protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) != 0)
? TransferMode.Binary
: TransferMode.Text;
transferModeFeature.TransferMode =
(protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) != 0)
? TransferMode.Binary
: TransferMode.Text;
ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder);
ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder);
_logger.UsingHubProtocol(protocol.Name);
_logger.UsingHubProtocol(protocol.Name);
UserIdentifier = userIdProvider.GetUserId(this);
UserIdentifier = userIdProvider.GetUserId(this);
return true;
return true;
}
}
else if (result.IsCompleted)
{
break;
}
}
finally
{
_connectionContext.Transport.Input.AdvanceTo(consumed, examined);
}
}
}
@ -186,7 +204,6 @@ namespace Microsoft.AspNetCore.SignalR
if (Features.Get<IConnectionInherentKeepAliveFeature>() == null)
{
Debug.Assert(ProtocolReaderWriter != null, "Expected the ProtocolReaderWriter to be set before StartAsync is called");
_pingMessage = ProtocolReaderWriter.WriteMessage(PingMessage.Instance);
_connectionContext.Features.Get<IConnectionHeartbeatFeature>()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this);
}
@ -197,14 +214,10 @@ namespace Microsoft.AspNetCore.SignalR
while (Output.Reader.TryRead(out var hubMessage))
{
var buffer = ProtocolReaderWriter.WriteMessage(hubMessage);
while (await _connectionContext.Transport.Writer.WaitToWriteAsync())
{
if (_connectionContext.Transport.Writer.TryWrite(buffer))
{
Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp());
break;
}
}
await _connectionContext.Transport.Output.WriteAsync(buffer);
Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp());
}
}
}
@ -221,7 +234,6 @@ namespace Microsoft.AspNetCore.SignalR
// If it is, we send a ping frame, if not, we no-op on this tick. This means that in the worst case, the
// true "ping rate" of the server could be (_hubOptions.KeepAliveInterval + HubEndPoint.KeepAliveTimerInterval),
// because if the interval elapses right after the last tick of this timer, it won't be detected until the next tick.
Debug.Assert(_pingMessage != null, "Expected the ping message to be prepared before the first heartbeat tick");
if (Stopwatch.GetTimestamp() - Interlocked.Read(ref _lastSendTimestamp) > _keepAliveDuration)
{
@ -229,7 +241,8 @@ namespace Microsoft.AspNetCore.SignalR
// If the transport channel is full, this will fail, but that's OK because
// adding a Ping message when the transport is full is unnecessary since the
// transport is still in the process of sending frames.
if (_connectionContext.Transport.Writer.TryWrite(_pingMessage))
if (Output.Writer.TryWrite(PingMessage.Instance))
{
_logger.SentPing();
}

View File

@ -171,58 +171,74 @@ namespace Microsoft.AspNetCore.SignalR
try
{
while (await connection.Input.WaitToReadAsync(connection.ConnectionAbortedToken))
while (true)
{
while (connection.Input.TryRead(out var buffer))
var result = await connection.Input.ReadAsync(connection.ConnectionAbortedToken);
var buffer = result.Buffer;
var consumed = buffer.End;
var examined = buffer.End;
try
{
if (connection.ProtocolReaderWriter.ReadMessages(buffer, this, out var hubMessages))
if (!buffer.IsEmpty)
{
foreach (var hubMessage in hubMessages)
if (connection.ProtocolReaderWriter.ReadMessages(buffer, this, out var hubMessages, out consumed, out examined))
{
switch (hubMessage)
foreach (var hubMessage in hubMessages)
{
case InvocationMessage invocationMessage:
_logger.ReceivedHubInvocation(invocationMessage);
switch (hubMessage)
{
case InvocationMessage invocationMessage:
_logger.ReceivedHubInvocation(invocationMessage);
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
_ = ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false);
break;
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
_ = ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false);
break;
case StreamInvocationMessage streamInvocationMessage:
_logger.ReceivedStreamHubInvocation(streamInvocationMessage);
case StreamInvocationMessage streamInvocationMessage:
_logger.ReceivedStreamHubInvocation(streamInvocationMessage);
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
_ = ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true);
break;
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
_ = ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true);
break;
case CancelInvocationMessage cancelInvocationMessage:
// Check if there is an associated active stream and cancel it if it exists.
// The cts will be removed when the streaming method completes executing
if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId, out var cts))
{
_logger.CancelStream(cancelInvocationMessage.InvocationId);
cts.Cancel();
}
else
{
// Stream can be canceled on the server while client is canceling stream.
_logger.UnexpectedCancel();
}
break;
case CancelInvocationMessage cancelInvocationMessage:
// Check if there is an associated active stream and cancel it if it exists.
// The cts will be removed when the streaming method completes executing
if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId, out var cts))
{
_logger.CancelStream(cancelInvocationMessage.InvocationId);
cts.Cancel();
}
else
{
// Stream can be canceled on the server while client is canceling stream.
_logger.UnexpectedCancel();
}
break;
case PingMessage _:
// We don't care about pings
break;
case PingMessage _:
// We don't care about pings
break;
// Other kind of message we weren't expecting
default:
_logger.UnsupportedMessageReceived(hubMessage.GetType().FullName);
throw new NotSupportedException($"Received unsupported message: {hubMessage}");
// Other kind of message we weren't expecting
default:
_logger.UnsupportedMessageReceived(hubMessage.GetType().FullName);
throw new NotSupportedException($"Received unsupported message: {hubMessage}");
}
}
}
}
else if (result.IsCompleted)
{
break;
}
}
finally
{
connection.Input.AdvanceTo(consumed, examined);
}
}
}

View File

@ -2,7 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Threading.Channels;
using System.IO.Pipelines;
using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.Sockets
@ -15,6 +15,6 @@ namespace Microsoft.AspNetCore.Sockets
public abstract IDictionary<object, object> Metadata { get; set; }
public abstract Channel<byte[]> Transport { get; set; }
public abstract IDuplexPipe Transport { get; set; }
}
}

View File

@ -0,0 +1,47 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace System.IO.Pipelines
{
public class DuplexPipe : IDuplexPipe
{
public DuplexPipe(PipeReader reader, PipeWriter writer)
{
Input = reader;
Output = writer;
}
public PipeReader Input { get; }
public PipeWriter Output { get; }
public void Dispose()
{
}
public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions)
{
var input = new Pipe(inputOptions);
var output = new Pipe(outputOptions);
var transportToApplication = new DuplexPipe(output.Reader, input.Writer);
var applicationToTransport = new DuplexPipe(input.Reader, output.Writer);
return new DuplexPipePair(applicationToTransport, transportToApplication);
}
// This class exists to work around issues with value tuple on .NET Framework
public struct DuplexPipePair
{
public IDuplexPipe Transport { get; private set; }
public IDuplexPipe Application { get; private set; }
public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application)
{
Transport = transport;
Application = application;
}
}
}
}

View File

@ -1,13 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Channels;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Features
{
public interface IConnectionTransportFeature
{
Channel<byte[]> Transport { get; set; }
IDuplexPipe Transport { get; set; }
TransferMode TransportCapabilities { get; set; }
}

View File

@ -9,6 +9,9 @@
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Http.Features" Version="$(MicrosoftAspNetCoreHttpFeaturesPackageVersion)" />
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(SystemIOPipelinesPackageVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryPackageVersion)" />
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="$(SystemRuntimeCompilerServicesUnsafePackageVersion)" />
<PackageReference Include="System.Threading.Tasks.Extensions" Version="$(SystemThreadingTasksExtensionsPackageVersion)" />
</ItemGroup>

View File

@ -4,6 +4,7 @@
using System;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@ -116,7 +117,7 @@ namespace Microsoft.AspNetCore.Sockets
connection.TransportCapabilities = TransferMode.Text;
// We only need to provide the Input channel since writing to the application is handled through /send.
var sse = new ServerSentEventsTransport(connection.Application.Reader, connection.ConnectionId, _loggerFactory);
var sse = new ServerSentEventsTransport(connection.Application.Input, connection.ConnectionId, _loggerFactory);
await DoPersistentConnection(socketDelegate, sse, context, connection);
}
@ -218,7 +219,7 @@ namespace Microsoft.AspNetCore.Sockets
context.Response.RegisterForDispose(timeoutSource);
context.Response.RegisterForDispose(tokenSource);
var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Reader, connection.ConnectionId, _loggerFactory);
var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, connection.ConnectionId, _loggerFactory);
// Start the transport
connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token);
@ -239,7 +240,7 @@ namespace Microsoft.AspNetCore.Sockets
if (resultTask == connection.ApplicationTask)
{
// Complete the transport (notifying it of the application error if there is one)
connection.Transport.Writer.TryComplete(connection.ApplicationTask.Exception);
connection.Transport.Output.Complete(connection.ApplicationTask.Exception);
// Wait for the transport to run
await connection.TransportTask;
@ -440,13 +441,7 @@ namespace Microsoft.AspNetCore.Sockets
}
_logger.ReceivedBytes(buffer.Length);
while (!connection.Application.Writer.TryWrite(buffer))
{
if (!await connection.Application.Writer.WaitToWriteAsync())
{
return;
}
}
await connection.Application.Output.WriteAsync(buffer);
}
private async Task<bool> EnsureConnectionStateAsync(DefaultConnectionContext connection, HttpContext context, TransportType transportType, TransportType supportedTransports, ConnectionLogScope logScope, HttpSocketOptions options)

View File

@ -16,8 +16,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal
private static readonly Action<ILogger, Exception> _pollTimedOut =
LoggerMessage.Define(LogLevel.Information, new EventId(2, nameof(PollTimedOut)), "Poll request timed out. Sending 200 response to connection.");
private static readonly Action<ILogger, int, Exception> _longPollingWritingMessage =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(3, nameof(LongPollingWritingMessage)), "Writing a {count} byte message to connection.");
private static readonly Action<ILogger, long, Exception> _longPollingWritingMessage =
LoggerMessage.Define<long>(LogLevel.Debug, new EventId(3, nameof(LongPollingWritingMessage)), "Writing a {count} byte message to connection.");
private static readonly Action<ILogger, Exception> _longPollingDisconnected =
LoggerMessage.Define(LogLevel.Debug, new EventId(4, nameof(LongPollingDisconnected)), "Client disconnected from Long Polling endpoint for connection.");
@ -41,8 +41,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal
private static readonly Action<ILogger, Exception> _resumingConnection =
LoggerMessage.Define(LogLevel.Debug, new EventId(5, nameof(ResumingConnection)), "Resuming existing connection.");
private static readonly Action<ILogger, int, Exception> _receivedBytes =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(6, nameof(ReceivedBytes)), "Received {count} bytes.");
private static readonly Action<ILogger, long, Exception> _receivedBytes =
LoggerMessage.Define<long>(LogLevel.Debug, new EventId(6, nameof(ReceivedBytes)), "Received {count} bytes.");
private static readonly Action<ILogger, TransportType, Exception> _transportNotSupported =
LoggerMessage.Define<TransportType>(LogLevel.Debug, new EventId(7, nameof(TransportNotSupported)), "{transportType} transport not supported by this endpoint type.");
@ -87,8 +87,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal
private static readonly Action<ILogger, int, Exception> _messageToApplication =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(10, nameof(MessageToApplication)), "Passing message to application. Payload size: {size}.");
private static readonly Action<ILogger, int, Exception> _sendPayload =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(11, nameof(SendPayload)), "Sending payload: {size} bytes.");
private static readonly Action<ILogger, long, Exception> _sendPayload =
LoggerMessage.Define<long>(LogLevel.Debug, new EventId(11, nameof(SendPayload)), "Sending payload: {size} bytes.");
private static readonly Action<ILogger, Exception> _errorWritingFrame =
LoggerMessage.Define(LogLevel.Error, new EventId(12, nameof(ErrorWritingFrame)), "Error writing frame.");
@ -97,8 +97,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal
LoggerMessage.Define(LogLevel.Trace, new EventId(13, nameof(SendFailed)), "Socket failed to send.");
// Category: ServerSentEventsTransport
private static readonly Action<ILogger, int, Exception> _sseWritingMessage =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(1, nameof(SSEWritingMessage)), "Writing a {count} byte message.");
private static readonly Action<ILogger, long, Exception> _sseWritingMessage =
LoggerMessage.Define<long>(LogLevel.Debug, new EventId(1, nameof(SSEWritingMessage)), "Writing a {count} byte message.");
public static void LongPolling204(this ILogger logger)
{
@ -110,7 +110,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
_pollTimedOut(logger, null);
}
public static void LongPollingWritingMessage(this ILogger logger, int count)
public static void LongPollingWritingMessage(this ILogger logger, long count)
{
_longPollingWritingMessage(logger, count, null);
}
@ -150,7 +150,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
_resumingConnection(logger, null);
}
public static void ReceivedBytes(this ILogger logger, int count)
public static void ReceivedBytes(this ILogger logger, long count)
{
_receivedBytes(logger, count, null);
}
@ -225,7 +225,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
_messageToApplication(logger, size, null);
}
public static void SendPayload(this ILogger logger, int size)
public static void SendPayload(this ILogger logger, long size)
{
_sendPayload(logger, size, null);
}
@ -240,7 +240,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
_sendFailed(logger, ex);
}
public static void SSEWritingMessage(this ILogger logger, int count)
public static void SSEWritingMessage(this ILogger logger, long count)
{
_sseWritingMessage(logger, count, null);
}

View File

@ -2,24 +2,24 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Sockets.Internal.Transports
{
public class LongPollingTransport : IHttpTransport
{
private readonly ChannelReader<byte[]> _application;
private readonly PipeReader _application;
private readonly ILogger _logger;
private readonly CancellationToken _timeoutToken;
private readonly string _connectionId;
public LongPollingTransport(CancellationToken timeoutToken, ChannelReader<byte[]> application, string connectionId, ILoggerFactory loggerFactory)
public LongPollingTransport(CancellationToken timeoutToken, PipeReader application, string connectionId, ILoggerFactory loggerFactory)
{
_timeoutToken = timeoutToken;
_application = application;
@ -31,33 +31,38 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
{
try
{
if (!await _application.WaitToReadAsync(token))
var result = await _application.ReadAsync(token);
var buffer = result.Buffer;
if (buffer.IsEmpty && result.IsCompleted)
{
await _application.Completion;
_logger.LongPolling204();
context.Response.ContentType = "text/plain";
context.Response.StatusCode = StatusCodes.Status204NoContent;
return;
}
var contentLength = 0;
var buffers = new List<byte[]>();
// We're intentionally not checking cancellation here because we need to drain messages we've got so far,
// but it's too late to emit the 204 required by being cancelled.
while (_application.TryRead(out var buffer))
{
contentLength += buffer.Length;
buffers.Add(buffer);
_logger.LongPollingWritingMessage(buffer.Length);
}
_logger.LongPollingWritingMessage(buffer.Length);
context.Response.ContentLength = contentLength;
context.Response.ContentLength = buffer.Length;
context.Response.ContentType = "application/octet-stream";
foreach (var buffer in buffers)
try
{
await context.Response.Body.WriteAsync(buffer, 0, buffer.Length);
foreach (var segment in buffer)
{
var isArray = MemoryMarshal.TryGetArray(segment, out var arraySegment);
// We're using the managed memory pool which is backed by managed buffers
Debug.Assert(isArray);
await context.Response.Body.WriteAsync(arraySegment.Array, 0, arraySegment.Count);
}
}
finally
{
_application.AdvanceTo(buffer.End);
}
}
catch (OperationCanceledException)

View File

@ -3,9 +3,9 @@
using System;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
@ -15,11 +15,11 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
{
public class ServerSentEventsTransport : IHttpTransport
{
private readonly ChannelReader<byte[]> _application;
private readonly PipeReader _application;
private readonly string _connectionId;
private readonly ILogger _logger;
public ServerSentEventsTransport(ChannelReader<byte[]> application, string connectionId, ILoggerFactory loggerFactory)
public ServerSentEventsTransport(PipeReader application, string connectionId, ILoggerFactory loggerFactory)
{
_application = application;
_connectionId = connectionId;
@ -44,21 +44,32 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
try
{
while (await _application.WaitToReadAsync(token))
while (true)
{
var ms = new MemoryStream();
while (_application.TryRead(out var buffer))
var result = await _application.ReadAsync(token);
var buffer = result.Buffer;
try
{
_logger.SSEWritingMessage(buffer.Length);
ServerSentEventsMessageFormatter.WriteMessage(buffer, ms);
if (!buffer.IsEmpty)
{
var ms = new MemoryStream();
_logger.SSEWritingMessage(buffer.Length);
// Don't create a copy using ToArray every time
ServerSentEventsMessageFormatter.WriteMessage(buffer.ToArray(), ms);
ms.Seek(0, SeekOrigin.Begin);
await ms.CopyToAsync(context.Response.Body);
}
else if (result.IsCompleted)
{
break;
}
}
finally
{
_application.AdvanceTo(buffer.End);
}
ms.Seek(0, SeekOrigin.Begin);
await ms.CopyToAsync(context.Response.Body);
}
await _application.Completion;
}
catch (OperationCanceledException)
{

View File

@ -4,9 +4,9 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
@ -17,10 +17,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
{
private readonly WebSocketOptions _options;
private readonly ILogger _logger;
private readonly Channel<byte[]> _application;
private readonly IDuplexPipe _application;
private readonly DefaultConnectionContext _connection;
public WebSocketsTransport(WebSocketOptions options, Channel<byte[]> application, DefaultConnectionContext connection, ILoggerFactory loggerFactory)
public WebSocketsTransport(WebSocketOptions options, IDuplexPipe application, DefaultConnectionContext connection, ILoggerFactory loggerFactory)
{
if (options == null)
{
@ -86,9 +86,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
_logger.WaitingForClose();
}
// We're done writing
_application.Writer.TryComplete();
await socket.CloseOutputAsync(failed ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
var resultTask = await Task.WhenAny(task, Task.Delay(_options.CloseTimeout));
@ -110,20 +107,17 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
private async Task<WebSocketReceiveResult> StartReceiving(WebSocket socket)
{
// REVIEW: This code was copied from the client, it's highly unoptimized at the moment (especially
// for server logic)
var incomingMessage = new List<ArraySegment<byte>>();
while (true)
try
{
const int bufferSize = 4096;
var totalBytes = 0;
WebSocketReceiveResult receiveResult;
do
while (true)
{
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
var memory = _application.Output.GetMemory();
// REVIEW: Use new Memory<byte> websocket APIs on .NET Core 2.1
memory.TryGetArray(out var arraySegment);
// Exceptions are handled above where the send and receive tasks are being run.
receiveResult = await socket.ReceiveAsync(buffer, CancellationToken.None);
var receiveResult = await socket.ReceiveAsync(arraySegment, CancellationToken.None);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
return receiveResult;
@ -131,54 +125,33 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
_logger.MessageReceived(receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage);
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
incomingMessage.Add(truncBuffer);
totalBytes += receiveResult.Count;
} while (!receiveResult.EndOfMessage);
_application.Output.Advance(receiveResult.Count);
// Making sure the message type is either text or binary
Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type");
// TODO: Check received message type against the _options.WebSocketMessageType
byte[] messageBuffer = null;
if (incomingMessage.Count > 1)
{
messageBuffer = new byte[totalBytes];
var offset = 0;
for (var i = 0; i < incomingMessage.Count; i++)
if (receiveResult.EndOfMessage)
{
Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count);
offset += incomingMessage[i].Count;
}
}
else
{
messageBuffer = new byte[incomingMessage[0].Count];
Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, messageBuffer, 0, incomingMessage[0].Count);
}
_logger.MessageToApplication(messageBuffer.Length);
while (await _application.Writer.WaitToWriteAsync())
{
if (_application.Writer.TryWrite(messageBuffer))
{
incomingMessage.Clear();
break;
await _application.Output.FlushAsync();
}
}
}
finally
{
// We're done writing
_application.Output.Complete();
}
}
private async Task StartSending(WebSocket ws)
{
while (await _application.Reader.WaitToReadAsync())
while (true)
{
var result = await _application.Input.ReadAsync();
var buffer = result.Buffer;
// Get a frame from the application
while (_application.Reader.TryRead(out var buffer))
try
{
if (buffer.Length > 0)
if (!buffer.IsEmpty)
{
try
{
@ -190,7 +163,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
if (WebSocketCanSend(ws))
{
await ws.SendAsync(new ArraySegment<byte>(buffer), webSocketMessageType, endOfMessage: true, cancellationToken: CancellationToken.None);
await ws.SendAsync(new ArraySegment<byte>(buffer.ToArray()), webSocketMessageType, endOfMessage: true, cancellationToken: CancellationToken.None);
}
}
catch (WebSocketException socketException) when (!WebSocketCanSend(ws))
@ -205,6 +178,14 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports
break;
}
}
else if (result.IsCompleted)
{
break;
}
}
finally
{
_application.Input.AdvanceTo(buffer.End);
}
}
}

View File

@ -6,6 +6,7 @@ using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Channels;
@ -63,13 +64,9 @@ namespace Microsoft.AspNetCore.Sockets
_logger.CreatedNewConnection(id);
var connectionTimer = SocketEventSource.Log.ConnectionStart(id);
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication);
var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport);
var connection = new DefaultConnectionContext(id, applicationSide, transportSide);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext(id, pair.Application, pair.Transport);
connection.ConnectionTimer = connectionTimer;
_connections.TryAdd(id, connection);

View File

@ -3,9 +3,9 @@
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Security.Claims;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Sockets.Features;
@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.Sockets
private TaskCompletionSource<object> _disposeTcs = new TaskCompletionSource<object>();
internal ValueStopwatch ConnectionTimer { get; set; }
public DefaultConnectionContext(string id, Channel<byte[]> transport, Channel<byte[]> application)
public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application)
{
Transport = transport;
Application = application;
@ -65,9 +65,9 @@ namespace Microsoft.AspNetCore.Sockets
public override IDictionary<object, object> Metadata { get; set; } = new ConnectionMetadata();
public Channel<byte[]> Application { get; }
public IDuplexPipe Application { get; }
public override Channel<byte[]> Transport { get; set; }
public override IDuplexPipe Transport { get; set; }
public TransferMode TransportCapabilities { get; set; }
@ -111,21 +111,21 @@ namespace Microsoft.AspNetCore.Sockets
// If the application task is faulted, propagate the error to the transport
if (ApplicationTask?.IsFaulted == true)
{
Transport.Writer.TryComplete(ApplicationTask.Exception.InnerException);
Transport.Output.Complete(ApplicationTask.Exception.InnerException);
}
else
{
Transport.Writer.TryComplete();
Transport.Output.Complete();
}
// If the transport task is faulted, propagate the error to the application
if (TransportTask?.IsFaulted == true)
{
Application.Writer.TryComplete(TransportTask.Exception.InnerException);
Application.Output.Complete(TransportTask.Exception.InnerException);
}
else
{
Application.Writer.TryComplete();
Application.Output.Complete();
}
var applicationTask = ApplicationTask ?? Task.CompletedTask;
@ -139,7 +139,18 @@ namespace Microsoft.AspNetCore.Sockets
Lock.Release();
}
await disposeTask;
try
{
await disposeTask;
}
finally
{
// REVIEW: Should we move this to the read loops?
// Complete the reading side of the pipes
Application.Input.Complete();
Transport.Input.Complete();
}
}
private async Task WaitOnTasks(Task applicationTask, Task transportTask)

View File

@ -2,18 +2,17 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Tests
{
@ -23,22 +22,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
private readonly HubProtocolReaderWriter _protocolReaderWriter;
private readonly IInvocationBinder _invocationBinder;
private CancellationTokenSource _cts;
private ChannelConnection<byte[]> _transport;
private Queue<HubMessage> _messages = new Queue<HubMessage>();
public DefaultConnectionContext Connection { get; }
public Channel<byte[]> Application { get; }
public Task Connected => ((TaskCompletionSource<bool>)Connection.Metadata["ConnectedTask"]).Task;
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false)
{
var options = new UnboundedChannelOptions { AllowSynchronousContinuations = synchronousCallbacks };
var transportToApplication = Channel.CreateUnbounded<byte[]>(options);
var applicationToTransport = Channel.CreateUnbounded<byte[]>(options);
Application = ChannelConnection.Create<byte[]>(input: applicationToTransport, output: transportToApplication);
_transport = ChannelConnection.Create<byte[]>(input: transportToApplication, output: applicationToTransport);
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application);
var options = new PipeOptions(readerScheduler: synchronousCallbacks ? PipeScheduler.Inline : null);
var pair = DuplexPipe.CreateConnectionPair(options, options);
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application);
var claimValue = Interlocked.Increment(ref _id).ToString();
var claims = new List<Claim> { new Claim(ClaimTypes.Name, claimValue) };
@ -59,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var memoryStream = new MemoryStream())
{
NegotiationProtocol.WriteMessage(new NegotiationMessage(protocol.Name), memoryStream);
Application.Writer.TryWrite(memoryStream.ToArray());
Connection.Application.Output.WriteAsync(memoryStream.ToArray());
}
}
@ -151,7 +144,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public async Task<string> SendHubMessageAsync(HubMessage message)
{
var payload = _protocolReaderWriter.WriteMessage(message);
await Application.Writer.WriteAsync(payload);
await Connection.Application.Output.WriteAsync(payload);
return message is HubInvocationMessage hubMessage ? hubMessage.InvocationId : null;
}
@ -163,9 +156,24 @@ namespace Microsoft.AspNetCore.SignalR.Tests
if (message == null)
{
if (!await Application.Reader.WaitToReadAsync())
var result = await Connection.Application.Input.ReadAsync();
var buffer = result.Buffer;
try
{
return null;
if (!buffer.IsEmpty)
{
continue;
}
if (result.IsCompleted)
{
return null;
}
}
finally
{
Connection.Application.Input.AdvanceTo(buffer.Start);
}
}
else
@ -177,18 +185,45 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public HubMessage TryRead()
{
if (Application.Reader.TryRead(out var buffer) &&
_protocolReaderWriter.ReadMessages(buffer, _invocationBinder, out var messages))
if (_messages.Count > 0)
{
return messages[0];
return _messages.Dequeue();
}
if (!Connection.Application.Input.TryRead(out var result))
{
return null;
}
var buffer = result.Buffer;
var consumed = buffer.End;
var examined = consumed;
try
{
if (_protocolReaderWriter.ReadMessages(result.Buffer, _invocationBinder, out var messages, out consumed, out examined))
{
foreach (var m in messages)
{
_messages.Enqueue(m);
}
return _messages.Dequeue();
}
}
finally
{
Connection.Application.Input.AdvanceTo(consumed, examined);
}
return null;
}
public void Dispose()
{
_cts.Cancel();
_transport.Dispose();
Connection.Application.Output.Complete();
}
private static string GetInvocationId()

View File

@ -1,8 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR.Tests
@ -11,7 +11,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
public async override Task OnConnectedAsync(ConnectionContext connection)
{
await connection.Transport.Writer.WriteAsync(await connection.Transport.Reader.ReadAsync());
var result = await connection.Transport.Input.ReadAsync();
var buffer = result.Buffer;
try
{
if (!buffer.IsEmpty)
{
await connection.Transport.Output.WriteAsync(buffer.ToArray());
}
}
finally
{
connection.Transport.Input.AdvanceTo(buffer.End);
}
}
}
}

View File

@ -92,6 +92,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.Equal(bytes, buffer.Array.AsSpan().Slice(0, result.Count).ToArray());
logger.LogInformation("Waiting for close");
result = await ws.ReceiveAsync(buffer, CancellationToken.None).OrTimeout();
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
logger.LogInformation("Closing socket");
await ws.CloseAsync(WebSocketCloseStatus.Empty, "", CancellationToken.None).OrTimeout();
logger.LogInformation("Closed socket");

View File

@ -234,13 +234,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await waitForSubscribe.Task.OrTimeout();
observable.OnNext(1);
await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).OrTimeout();
await waitForDispose.Task.OrTimeout();
Assert.Equal(1L, ((StreamItemMessage)await client.ReadAsync().OrTimeout()).Item);
var message = await client.ReadAsync().OrTimeout();
Assert.IsType<CompletionMessage>(message);
client.Dispose();
@ -257,7 +257,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var client = new TestClient())
{
// TestClient automatically writes negotiate, for this test we want to assume negotiate never gets sent
client.Connection.Transport.Reader.TryRead(out var item);
client.Connection.Transport.Input.TryRead(out var item);
client.Connection.Transport.Input.AdvanceTo(item.Buffer.End);
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
@ -283,7 +284,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var client = new TestClient())
{
// TestClient automatically writes negotiate, for this test we want to assume negotiate never gets sent
client.Connection.Transport.Reader.TryRead(out var item);
client.Connection.Transport.Input.TryRead(out var item);
client.Connection.Transport.Input.AdvanceTo(item.Buffer.End);
await endPoint.OnConnectedAsync(client.Connection).OrTimeout();
}
@ -567,7 +569,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
client.Dispose();
// Nothing should have been written
Assert.False(client.Application.Reader.TryRead(out var buffer));
Assert.False(client.Connection.Application.Input.TryRead(out var buffer));
await endPointTask.OrTimeout();
}
@ -1720,6 +1722,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await endPointLifetime.OrTimeout();
client.Connection.Transport.Output.Complete();
// We shouldn't have any ping messages
HubMessage message;
var counter = 0;
@ -1760,6 +1764,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await endPointLifetime.OrTimeout();
client.Connection.Transport.Output.Complete();
// We should have all pings
HubMessage message;
var counter = 0;

View File

@ -81,12 +81,29 @@ namespace Microsoft.AspNetCore.Sockets.Tests
connection.ApplicationTask = Task.Run(async () =>
{
Assert.False(await connection.Transport.Reader.WaitToReadAsync());
var result = await connection.Transport.Input.ReadAsync();
try
{
Assert.True(result.IsCompleted);
}
finally
{
connection.Transport.Input.AdvanceTo(result.Buffer.End);
}
});
connection.TransportTask = Task.Run(async () =>
{
Assert.False(await connection.Application.Reader.WaitToReadAsync());
var result = await connection.Application.Input.ReadAsync();
try
{
Assert.True(result.IsCompleted);
}
finally
{
connection.Application.Input.AdvanceTo(result.Buffer.End);
}
});
connectionManager.CloseConnections();
@ -188,15 +205,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
var appLifetime = new TestApplicationLifetime();
var connectionManager = CreateConnectionManager(appLifetime);
var tcs = new TaskCompletionSource<object>();
appLifetime.Start();
var connection = connectionManager.CreateConnection();
connection.Application.Output.OnReaderCompleted((error, state) =>
{
tcs.TrySetResult(null);
},
null);
appLifetime.StopApplication();
// Connection should be disposed so this should complete immediately
Assert.False(await connection.Application.Writer.WaitToWriteAsync().OrTimeout());
await tcs.Task.OrTimeout();
}
private static ConnectionManager CreateConnectionManager(IApplicationLifetime lifetime = null)

View File

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Security.Claims;
using System.Text;
@ -199,6 +200,52 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
[Theory]
[InlineData(TransportType.LongPolling)]
[InlineData(TransportType.ServerSentEvents)]
public async Task PostSendsToConnection(TransportType transportType)
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var manager = CreateConnectionManager(loggerFactory);
var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory);
var connection = manager.CreateConnection();
connection.Metadata[ConnectionMetadataNames.Transport] = transportType;
using (var requestBody = new MemoryStream())
using (var responseBody = new MemoryStream())
{
var bytes = Encoding.UTF8.GetBytes("Hello World");
requestBody.Write(bytes, 0, bytes.Length);
requestBody.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
context.Request.Body = requestBody;
context.Response.Body = responseBody;
var services = new ServiceCollection();
services.AddEndPoint<TestEndPoint>();
services.AddOptions();
context.Request.Path = "/foo";
context.Request.Method = "POST";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new SocketBuilder(services.BuildServiceProvider());
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
Assert.True(connection.Transport.Input.TryRead(out var result));
Assert.Equal("Hello World", Encoding.UTF8.GetString(result.Buffer.ToArray()));
connection.Transport.Input.AdvanceTo(result.Buffer.End);
}
}
}
[Theory]
[InlineData(TransportType.ServerSentEvents)]
[InlineData(TransportType.LongPolling)]
@ -570,7 +617,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var buffer = Encoding.UTF8.GetBytes("Hello World");
// Write to the transport so the poll yields
await connection.Transport.Writer.WriteAsync(buffer);
await connection.Transport.Output.WriteAsync(buffer);
await task;
@ -605,7 +652,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var buffer = Encoding.UTF8.GetBytes("Hello World");
// Write to the application
await connection.Application.Writer.WriteAsync(buffer);
await connection.Application.Output.WriteAsync(buffer);
await task;
@ -638,7 +685,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var buffer = Encoding.UTF8.GetBytes("Hello World");
// Write to the application
await connection.Application.Writer.WriteAsync(buffer);
await connection.Application.Output.WriteAsync(buffer);
await task;
@ -674,7 +721,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
await task1.OrTimeout();
// Send a message from the app to complete Task 2
await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World"));
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World"));
await task2.OrTimeout();
@ -855,7 +902,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout();
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout();
await endPointTask.OrTimeout();
@ -936,7 +983,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}));
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout();
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout();
await endPointTask.OrTimeout();
@ -993,7 +1040,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout();
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).OrTimeout();
await endPointTask.OrTimeout();
@ -1229,7 +1276,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
public override Task OnConnectedAsync(ConnectionContext connection)
{
connection.Transport.Reader.WaitToReadAsync().Wait();
var waitHandle = new ManualResetEventSlim();
var awaiter = connection.Transport.Input.ReadAsync();
awaiter.OnCompleted(waitHandle.Set);
waitHandle.Wait();
var result = awaiter.GetResult();
connection.Transport.Input.AdvanceTo(result.Buffer.End);
return Task.CompletedTask;
}
}
@ -1254,8 +1308,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
public override async Task OnConnectedAsync(ConnectionContext connection)
{
while (await connection.Transport.Reader.WaitToReadAsync())
while (true)
{
var result = await connection.Transport.Input.ReadAsync();
try
{
if (result.IsCompleted)
{
break;
}
}
finally
{
connection.Transport.Input.AdvanceTo(result.Buffer.End);
}
}
}
}

View File

@ -3,12 +3,11 @@
using System;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.AspNetCore.Sockets.Internal.Transports;
using Microsoft.Extensions.Logging;
using Xunit;
@ -20,14 +19,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task Set204StatusCodeWhenChannelComplete()
{
var toApplication = Channel.CreateUnbounded<byte[]>();
var toTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var connection = new DefaultConnectionContext("foo", toTransport, toApplication);
var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory());
Assert.True(toTransport.Writer.TryComplete());
connection.Transport.Output.Complete();
await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout();
@ -37,13 +35,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task Set200StatusCodeWhenTimeoutTokenFires()
{
var toApplication = Channel.CreateUnbounded<byte[]>();
var toTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var connection = new DefaultConnectionContext("foo", toTransport, toApplication);
var timeoutToken = new CancellationToken(true);
var poll = new LongPollingTransport(timeoutToken, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var poll = new LongPollingTransport(timeoutToken, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory());
using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutToken, context.RequestAborted))
{
@ -57,18 +54,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task FrameSentAsSingleResponse()
{
var toApplication = Channel.CreateUnbounded<byte[]>();
var toTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var connection = new DefaultConnectionContext("foo", toTransport, toApplication);
var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var ms = new MemoryStream();
context.Response.Body = ms;
await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello World"));
Assert.True(toTransport.Writer.TryComplete());
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello World"));
connection.Transport.Output.Complete();
await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout();
@ -79,20 +74,19 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task MultipleFramesSentAsSingleResponse()
{
var toApplication = Channel.CreateUnbounded<byte[]>();
var toTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var connection = new DefaultConnectionContext("foo", toTransport, toApplication);
var poll = new LongPollingTransport(CancellationToken.None, toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var ms = new MemoryStream();
context.Response.Body = ms;
await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello"));
await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes(" "));
await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes("World"));
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello"));
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes(" "));
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("World"));
Assert.True(toTransport.Writer.TryComplete());
connection.Transport.Output.Complete();
await poll.ProcessRequestAsync(context, context.RequestAborted).OrTimeout();

View File

@ -92,9 +92,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
public override async Task OnConnectedAsync(ConnectionContext connection)
{
while (!await connection.Transport.Reader.WaitToReadAsync())
while (true)
{
var result = await connection.Transport.Input.ReadAsync();
if (result.IsCompleted)
{
break;
}
// Consume nothing
connection.Transport.Input.AdvanceTo(result.Buffer.Start);
}
}
}

View File

@ -2,8 +2,9 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Channels;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
@ -18,14 +19,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task SSESetsContentType()
{
var toApplication = Channel.CreateUnbounded<byte[]>();
var toTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var connection = new DefaultConnectionContext("foo", toTransport, toApplication);
var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory());
Assert.True(toTransport.Writer.TryComplete());
connection.Transport.Output.Complete();
await sse.ProcessRequestAsync(context, context.RequestAborted);
@ -36,16 +36,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task SSETurnsResponseBufferingOff()
{
var toApplication = Channel.CreateUnbounded<byte[]>();
var toTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var connection = new DefaultConnectionContext("foo", toTransport, toApplication);
var feature = new HttpBufferingFeature();
context.Features.Set<IHttpBufferingFeature>(feature);
var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: connection.ConnectionId, loggerFactory: new LoggerFactory());
Assert.True(toTransport.Writer.TryComplete());
connection.Transport.Output.Complete();
await sse.ProcessRequestAsync(context, context.RequestAborted);
@ -55,25 +54,22 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[Fact]
public async Task SSEWritesMessages()
{
var toApplication = Channel.CreateUnbounded<byte[]>();
var toTransport = Channel.CreateUnbounded<byte[]>(new UnboundedChannelOptions
{
AllowSynchronousContinuations = true
});
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, new PipeOptions(readerScheduler: PipeScheduler.Inline));
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var connection = new DefaultConnectionContext("foo", toTransport, toApplication);
var ms = new MemoryStream();
context.Response.Body = ms;
var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var task = sse.ProcessRequestAsync(context, context.RequestAborted);
await toTransport.Writer.WriteAsync(Encoding.ASCII.GetBytes("Hello"));
await connection.Transport.Output.WriteAsync(Encoding.ASCII.GetBytes("Hello"));
Assert.Equal(":\r\ndata: Hello\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray()));
toTransport.Writer.TryComplete();
connection.Transport.Output.Complete();
await task.OrTimeout();
}
@ -84,18 +80,17 @@ namespace Microsoft.AspNetCore.Sockets.Tests
[InlineData("Hello\r\nWorld", ":\r\ndata: Hello\r\ndata: World\r\n\r\n")]
public async Task SSEAddsAppropriateFraming(string message, string expected)
{
var toApplication = Channel.CreateUnbounded<byte[]>();
var toTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var connection = new DefaultConnectionContext("foo", toTransport, toApplication);
var sse = new ServerSentEventsTransport(toTransport.Reader, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var ms = new MemoryStream();
context.Response.Body = ms;
await toTransport.Writer.WriteAsync(Encoding.UTF8.GetBytes(message));
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes(message));
Assert.True(toTransport.Writer.TryComplete());
connection.Transport.Output.Complete();
await sse.ProcessRequestAsync(context, context.RequestAborted);

View File

@ -2,12 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Internal.Transports;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
@ -30,15 +29,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
using (var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory);
var ws = new WebSocketsTransport(new WebSocketOptions(), connection.Application, connectionContext, loggerFactory);
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -54,10 +51,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
cancellationToken: CancellationToken.None);
await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
var buffer = await applicationSide.Reader.ReadAsync();
Assert.Equal("Hello", Encoding.UTF8.GetString(buffer));
var result = await connection.Transport.Input.ReadAsync();
var buffer = result.Buffer;
Assert.Equal("Hello", Encoding.UTF8.GetString(buffer.ToArray()));
connection.Transport.Input.AdvanceTo(buffer.End);
Assert.True(applicationSide.Writer.TryComplete());
connection.Transport.Output.Complete();
// The transport should finish now
await transport;
@ -77,15 +76,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
using (var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var connectionContext = new DefaultConnectionContext(string.Empty, null, null) { TransferMode = transferMode };
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory);
var ws = new WebSocketsTransport(new WebSocketOptions(), connection.Application, connectionContext, loggerFactory);
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -94,8 +91,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Write to the output channel, and then complete it
await applicationSide.Writer.WriteAsync(Encoding.UTF8.GetBytes("Hello"));
Assert.True(applicationSide.Writer.TryComplete());
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello"));
connection.Transport.Output.Complete();
// The client should finish now, as should the server
var clientSummary = await client;
@ -115,24 +112,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
using (var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
async Task CompleteApplicationAfterTransportCompletes()
{
// Wait until the transport completes so that we can end the application
await applicationSide.Reader.WaitToReadAsync();
var result = await connection.Transport.Input.ReadAsync();
connection.Transport.Input.AdvanceTo(result.Buffer.End);
// Complete the application so that the connection unwinds without aborting
applicationSide.Writer.TryComplete();
connection.Transport.Output.Complete();
}
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory);
var ws = new WebSocketsTransport(new WebSocketOptions(), connection.Application, connectionContext, loggerFactory);
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -150,8 +146,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// Wait for the transport
await Assert.ThrowsAsync<WebSocketException>(() => transport).OrTimeout();
var summary = await client.OrTimeout();
Assert.Equal(WebSocketCloseStatus.InternalServerError, summary.CloseResult.CloseStatus);
await client.OrTimeout();
}
}
}
@ -161,15 +156,13 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
using (var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionContext, loggerFactory);
var ws = new WebSocketsTransport(new WebSocketOptions(), connection.Application, connectionContext, loggerFactory);
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
@ -178,7 +171,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Fail in the app
Assert.True(applicationSide.Writer.TryComplete(new InvalidOperationException("Catastrophic failure.")));
connection.Transport.Output.Complete(new InvalidOperationException("Catastrophic failure."));
var clientSummary = await client.OrTimeout();
Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.CloseStatus);
@ -196,11 +189,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
using (var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
@ -209,14 +200,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
};
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory);
var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory);
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(serverSocket);
// End the app
applicationSide.Dispose();
connection.Transport.Output.Complete();
await transport.OrTimeout(TimeSpan.FromSeconds(10));
@ -233,11 +224,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
using (var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions
@ -246,7 +235,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
};
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory);
var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory);
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
@ -256,7 +245,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// fail the client to server channel
applicationToTransport.Writer.TryComplete(new Exception());
connection.Transport.Output.Complete(new Exception());
await Assert.ThrowsAsync<Exception>(() => transport).OrTimeout();
@ -270,11 +259,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
using (var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions
@ -284,7 +271,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
};
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory);
var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory);
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
@ -294,7 +281,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// close the client to server channel
applicationToTransport.Writer.TryComplete();
connection.Transport.Output.Complete();
_ = await client.OrTimeout();
@ -312,11 +299,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var transportToApplication = Channel.CreateUnbounded<byte[]>();
var applicationToTransport = Channel.CreateUnbounded<byte[]>();
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
using (var transportSide = ChannelConnection.Create<byte[]>(applicationToTransport, transportToApplication))
using (var applicationSide = ChannelConnection.Create<byte[]>(transportToApplication, applicationToTransport))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions
@ -325,7 +310,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
CloseTimeout = TimeSpan.FromSeconds(20)
};
var connectionContext = new DefaultConnectionContext(string.Empty, null, null);
var ws = new WebSocketsTransport(options, transportSide, connectionContext, loggerFactory);
var ws = new WebSocketsTransport(options, connection.Application, connectionContext, loggerFactory);
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
@ -337,7 +322,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
await feature.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).OrTimeout();
// close the client to server channel
applicationToTransport.Writer.TryComplete();
connection.Transport.Output.Complete();
_ = await client.OrTimeout();