Converting pipelines to channels

This commit is contained in:
moozzyk 2017-01-13 13:21:54 -08:00
parent 1119bcf1b3
commit c997ea8165
21 changed files with 533 additions and 226 deletions

View File

@ -28,11 +28,11 @@ namespace ClientSample
var logger = loggerFactory.CreateLogger<Program>();
using (var httpClient = new HttpClient(new LoggingMessageHandler(loggerFactory, new HttpClientHandler())))
using (var pipelineFactory = new PipelineFactory())
{
logger.LogInformation("Connecting to {0}", baseUrl);
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await HubConnection.ConnectAsync(new Uri(baseUrl), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory))
using (var connection = await HubConnection.ConnectAsync(new Uri(baseUrl),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
logger.LogInformation("Connected to {0}", baseUrl);

View File

@ -7,6 +7,7 @@ using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
@ -27,11 +28,10 @@ namespace ClientSample
var logger = loggerFactory.CreateLogger<Program>();
using (var httpClient = new HttpClient(new LoggingMessageHandler(loggerFactory, new HttpClientHandler())))
using (var pipelineFactory = new PipelineFactory())
{
logger.LogInformation("Connecting to {0}", baseUrl);
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await Connection.ConnectAsync(new Uri(baseUrl), transport, httpClient, pipelineFactory, loggerFactory))
using (var connection = await Connection.ConnectAsync(new Uri(baseUrl), transport, httpClient, loggerFactory))
{
logger.LogInformation("Connected to {0}", baseUrl);
@ -44,8 +44,10 @@ namespace ClientSample
};
// Ready to start the loops
var receive = StartReceiving(loggerFactory.CreateLogger("ReceiveLoop"), connection, cts.Token);
var send = StartSending(loggerFactory.CreateLogger("SendLoop"), connection, cts.Token);
var receive =
StartReceiving(loggerFactory.CreateLogger("ReceiveLoop"), connection, cts.Token).ContinueWith(_ => cts.Cancel());
var send =
StartSending(loggerFactory.CreateLogger("SendLoop"), connection, cts.Token).ContinueWith(_ => cts.Cancel());
await Task.WhenAll(receive, send);
}
@ -60,7 +62,9 @@ namespace ClientSample
var line = Console.ReadLine();
logger.LogInformation("Sending: {0}", line);
await connection.Output.WriteAsync(Encoding.UTF8.GetBytes(line));
await connection.Output.WriteAsync(new Message(
ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(),
Format.Text));
}
logger.LogInformation("Send loop terminated");
}
@ -68,30 +72,31 @@ namespace ClientSample
private static async Task StartReceiving(ILogger logger, Connection connection, CancellationToken cancellationToken)
{
logger.LogInformation("Receive loop starting");
using (cancellationToken.Register(() => connection.Input.Complete()))
try
{
while (!cancellationToken.IsCancellationRequested)
while (await connection.Input.WaitToReadAsync(cancellationToken))
{
var result = await connection.Input.ReadAsync();
var buffer = result.Buffer;
try
Message message;
if (!connection.Input.TryRead(out message))
{
if (!buffer.IsEmpty)
{
var message = Encoding.UTF8.GetString(buffer.ToArray());
logger.LogInformation("Received: {0}", message);
}
continue;
}
finally
using (message)
{
connection.Input.Advance(buffer.End);
}
if (result.IsCompleted)
{
break;
logger.LogInformation("Received: {0}", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray()));
}
}
}
catch (OperationCanceledException)
{
logger.LogInformation("Connection is closing");
}
catch (Exception ex)
{
logger.LogError(0, ex, "Connection terminated due to an exception");
}
logger.LogInformation("Receive loop terminated");
}
}

View File

@ -1,13 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets;
namespace SocketsSample.EndPoints

View File

@ -11,6 +11,7 @@ using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
@ -19,7 +20,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
public class HubConnection : IDisposable
{
private readonly Task _reader;
private readonly Stream _stream;
private readonly ILogger _logger;
private readonly Connection _connection;
private readonly IInvocationAdapter _adapter;
@ -42,13 +42,11 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
_binder = new HubBinder(this);
_connection = connection;
_stream = connection.GetStream();
_adapter = adapter;
_logger = logger;
_reader = ReceiveMessages(_readerCts.Token);
Completion = _connection.Output.Writing.ContinueWith(
t => Shutdown(t)).Unwrap();
Completion = _connection.Input.Completion.ContinueWith(t => Shutdown(t)).Unwrap();
}
// TODO: Client return values/tasks?
@ -98,9 +96,21 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger.LogTrace("Invocation #{0}: {1} {2}({3})", descriptor.Id, returnType.FullName, methodName, argsList);
}
// Write the invocation to the stream
var ms = new MemoryStream();
await _adapter.WriteMessageAsync(descriptor, ms, cancellationToken);
_logger.LogInformation("Sending Invocation #{0}", descriptor.Id);
await _adapter.WriteMessageAsync(descriptor, _stream, cancellationToken);
// TODO: Format.Text - who, where and when decides about the format of outgoing messages
var message = new Message(ReadableBuffer.Create(ms.ToArray()).Preserve(), Format.Text);
while (await _connection.Output.WaitToWriteAsync())
{
if (_connection.Output.TryWrite(message))
{
break;
}
}
_logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id);
// Return the completion task. It will be completed by ReceiveMessages when the response is received.
@ -114,12 +124,12 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
// TODO: Clean up the API here. Negotiation of format would be better than providing an adapter instance. Similarly, we should not require a logger factory
public static Task<HubConnection> ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, PipelineFactory pipelineFactory, ILoggerFactory loggerFactory) => ConnectAsync(url, adapter, transport, new HttpClient(), pipelineFactory, loggerFactory);
public static Task<HubConnection> ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, adapter, transport, new HttpClient(), loggerFactory);
public static async Task<HubConnection> ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, HttpClient httpClient, PipelineFactory pipelineFactory, ILoggerFactory loggerFactory)
public static async Task<HubConnection> ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory)
{
// Connect the underlying connection
var connection = await Connection.ConnectAsync(url, transport, httpClient, pipelineFactory, loggerFactory);
var connection = await Connection.ConnectAsync(url, transport, httpClient, loggerFactory);
// Create the RPC connection wrapper
return new HubConnection(connection, adapter, loggerFactory.CreateLogger<HubConnection>());
@ -132,30 +142,38 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger.LogTrace("Beginning receive loop");
try
{
while (!cancellationToken.IsCancellationRequested)
while (await _connection.Input.WaitToReadAsync(cancellationToken))
{
// This is a little odd... we want to remove the InvocationRequest once and only once so we pull it out in the callback,
// and stash it here because we know the callback will have finished before the end of the await.
var message = await _adapter.ReadMessageAsync(_stream, _binder, cancellationToken);
Message incomingMessage;
while (_connection.Input.TryRead(out incomingMessage))
{
var invocationDescriptor = message as InvocationDescriptor;
if (invocationDescriptor != null)
{
DispatchInvocation(invocationDescriptor, cancellationToken);
}
else
{
var invocationResultDescriptor = message as InvocationResultDescriptor;
if (invocationResultDescriptor != null)
InvocationMessage message;
using (incomingMessage)
{
InvocationRequest irq;
lock (_pendingCallsLock)
message = await _adapter.ReadMessageAsync(
new MemoryStream(incomingMessage.Payload.Buffer.ToArray()), _binder, cancellationToken);
}
var invocationDescriptor = message as InvocationDescriptor;
if (invocationDescriptor != null)
{
DispatchInvocation(invocationDescriptor, cancellationToken);
}
else
{
var invocationResultDescriptor = message as InvocationResultDescriptor;
if (invocationResultDescriptor != null)
{
_connectionActive.Token.ThrowIfCancellationRequested();
irq = _pendingCalls[invocationResultDescriptor.Id];
_pendingCalls.Remove(invocationResultDescriptor.Id);
InvocationRequest irq;
lock (_pendingCallsLock)
{
_connectionActive.Token.ThrowIfCancellationRequested();
irq = _pendingCalls[invocationResultDescriptor.Id];
_pendingCalls.Remove(invocationResultDescriptor.Id);
}
DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken);
}
DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken);
}
}
}

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Options;

View File

@ -2,56 +2,45 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.Http;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Sockets.Client
{
public class Connection : IPipelineConnection
public class Connection : IChannelConnection<Message>
{
private IPipelineConnection _consumerPipe;
private IChannelConnection<Message> _transportChannel;
private ITransport _transport;
private readonly ILogger _logger;
public Uri Url { get; }
// TODO: Review. This is really only designed to be used from ConnectAsync
private Connection(Uri url, ITransport transport, IPipelineConnection consumerPipe, ILogger logger)
private Connection(Uri url, ITransport transport, IChannelConnection<Message> transportChannel, ILogger logger)
{
Url = url;
_logger = logger;
_transport = transport;
_consumerPipe = consumerPipe;
_consumerPipe.Output.Writing.ContinueWith(t =>
{
if (t.IsFaulted)
{
_consumerPipe.Input.Complete(t.Exception);
}
return t;
});
_transportChannel = transportChannel;
}
public IPipelineReader Input => _consumerPipe.Input;
public IPipelineWriter Output => _consumerPipe.Output;
public ReadableChannel<Message> Input => _transportChannel.Input;
public WritableChannel<Message> Output => _transportChannel.Output;
public void Dispose()
{
_consumerPipe.Dispose();
_transport.Dispose();
}
// TODO: More overloads. PipelineFactory should be optional but someone needs to dispose the pool, if we're OK with it being the GC, then this is easy.
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, PipelineFactory pipelineFactory) => ConnectAsync(url, transport, new HttpClient(), pipelineFactory, NullLoggerFactory.Instance);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, PipelineFactory pipelineFactory, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, new HttpClient(), pipelineFactory, loggerFactory);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, PipelineFactory pipelineFactory) => ConnectAsync(url, transport, httpClient, pipelineFactory, NullLoggerFactory.Instance);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, new HttpClient(), NullLoggerFactory.Instance);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, new HttpClient(), loggerFactory);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, NullLoggerFactory.Instance);
public static async Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, PipelineFactory pipelineFactory, ILoggerFactory loggerFactory)
public static async Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory)
{
if (url == null)
{
@ -68,11 +57,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
throw new ArgumentNullException(nameof(httpClient));
}
if (pipelineFactory == null)
{
throw new ArgumentNullException(nameof(pipelineFactory));
}
if (loggerFactory == null)
{
throw new ArgumentNullException(nameof(loggerFactory));
@ -97,12 +81,16 @@ namespace Microsoft.AspNetCore.Sockets.Client
var connectedUrl = Utils.AppendQueryString(url, "id=" + connectionId);
var pair = pipelineFactory.CreatePipelinePair();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
// Start the transport, giving it one end of the pipeline
try
{
await transport.StartAsync(connectedUrl, pair.Item1);
await transport.StartAsync(connectedUrl, applicationSide);
}
catch (Exception ex)
{
@ -111,7 +99,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
// Create the connection, giving it the other end of the pipeline
return new Connection(url, transport, pair.Item2, logger);
return new Connection(url, transport, transportSide, logger);
}
}
}

View File

@ -2,13 +2,12 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets.Client
{
public interface ITransport : IDisposable
{
Task StartAsync(Uri url, IPipelineConnection pipeline);
Task StartAsync(Uri url, IChannelConnection<Message> application);
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Net.Http;
@ -20,12 +21,10 @@ namespace Microsoft.AspNetCore.Sockets.Client
private readonly HttpClient _httpClient;
private readonly ILogger _logger;
private readonly CancellationTokenSource _senderCts = new CancellationTokenSource();
private readonly CancellationTokenSource _pollCts = new CancellationTokenSource();
private IPipelineConnection _pipeline;
private IChannelConnection<Message> _application;
private Task _sender;
private Task _poller;
private readonly CancellationTokenSource _transportCts = new CancellationTokenSource();
public Task Running { get; private set; }
@ -37,26 +36,21 @@ namespace Microsoft.AspNetCore.Sockets.Client
public void Dispose()
{
_senderCts.Cancel();
_pollCts.Cancel();
_pipeline?.Dispose();
_transportCts.Cancel();
}
public Task StartAsync(Uri url, IPipelineConnection pipeline)
public Task StartAsync(Uri url, IChannelConnection<Message> application)
{
_pipeline = pipeline;
// Schedule shutdown of the poller when the output is closed
pipeline.Output.Writing.ContinueWith(_ =>
{
_pollCts.Cancel();
return TaskCache.CompletedTask;
});
_application = application;
// Start sending and polling
_poller = Poll(Utils.AppendPath(url, "poll"), _pollCts.Token);
_sender = SendMessages(Utils.AppendPath(url, "send"), _senderCts.Token);
Running = Task.WhenAll(_sender, _poller);
_poller = Poll(Utils.AppendPath(url, "poll"), _transportCts.Token);
_sender = SendMessages(Utils.AppendPath(url, "send"), _transportCts.Token);
Running = Task.WhenAll(_sender, _poller).ContinueWith(t => {
_application.Output.TryComplete(t.IsFaulted ? t.Exception.InnerException : null);
return t;
}).Unwrap();
return TaskCache.CompletedTask;
}
@ -80,64 +74,71 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
else
{
// Write the data to the output
var buffer = _pipeline.Output.Alloc();
var stream = new WriteableBufferStream(buffer);
await response.Content.CopyToAsync(stream);
await buffer.FlushAsync();
var ms = new MemoryStream();
await response.Content.CopyToAsync(ms);
var message = new Message(ReadableBuffer.Create(ms.ToArray()).Preserve(), Format.Text);
while (await _application.Output.WaitToWriteAsync(cancellationToken))
{
if (_application.Output.TryWrite(message))
{
break;
}
}
}
}
// Polling complete
_pipeline.Output.Complete();
}
catch (OperationCanceledException)
{
// transport is being closed
}
catch (Exception ex)
{
// Shut down the output pipeline and log
_logger.LogError("Error while polling '{0}': {1}", pollUrl, ex);
_pipeline.Output.Complete(ex);
_pipeline.Input.Complete(ex);
throw;
}
finally
{
// Make sure the send loop is terminated
_transportCts.Cancel();
}
}
private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken)
{
using (cancellationToken.Register(() => _pipeline.Input.Complete()))
try
{
try
while (await _application.Input.WaitToReadAsync(cancellationToken))
{
while (!cancellationToken.IsCancellationRequested)
Message message;
while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out message))
{
var result = await _pipeline.Input.ReadAsync();
var buffer = result.Buffer;
if (buffer.IsEmpty || result.IsCompleted)
using (message)
{
// No more data to send
break;
var request = new HttpRequestMessage(HttpMethod.Post, sendUrl);
request.Headers.UserAgent.Add(DefaultUserAgentHeader);
request.Content = new ReadableBufferContent(message.Payload.Buffer);
var response = await _httpClient.SendAsync(request);
response.EnsureSuccessStatusCode();
}
// Create a message to send
var message = new HttpRequestMessage(HttpMethod.Post, sendUrl);
message.Headers.UserAgent.Add(DefaultUserAgentHeader);
message.Content = new ReadableBufferContent(buffer);
// Send it
var response = await _httpClient.SendAsync(message);
response.EnsureSuccessStatusCode();
_pipeline.Input.Advance(buffer.End);
}
// Sending complete
_pipeline.Input.Complete();
}
catch (Exception ex)
{
// Shut down the input pipeline and log
_logger.LogError("Error while sending to '{0}': {1}", sendUrl, ex);
_pipeline.Input.Complete(ex);
}
}
catch (OperationCanceledException)
{
// transport is being closed
}
catch (Exception ex)
{
_logger.LogError("Error while sending to '{0}': {1}", sendUrl, ex);
throw;
}
finally
{
// Make sure the poll loop is terminated
_transportCts.Cancel();
}
}
}
}

View File

@ -0,0 +1,25 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Client
{
public struct Message : IDisposable
{
public Format MessageFormat { get; }
public PreservedBuffer Payload { get; }
public Message(PreservedBuffer payload, Format messageFormat)
{
MessageFormat = messageFormat;
Payload = payload;
}
public void Dispose()
{
Payload.Dispose();
}
}
}

View File

@ -4,18 +4,16 @@
<VisualStudioVersion Condition="'$(VisualStudioVersion)' == ''">14.0</VisualStudioVersion>
<VSToolsPath Condition="'$(VSToolsPath)' == ''">$(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion)</VSToolsPath>
</PropertyGroup>
<Import Project="$(VSToolsPath)\DotNet\Microsoft.DotNet.Props" Condition="'$(VSToolsPath)' != ''" />
<PropertyGroup Label="Globals">
<ProjectGuid>623fd372-36de-41a9-a564-f6040d570dbd</ProjectGuid>
<RootNamespace>Microsoft.AspNetCore.SignalR.Client</RootNamespace>
<RootNamespace>Microsoft.AspNetCore.Sockets.Client</RootNamespace>
<BaseIntermediateOutputPath Condition="'$(BaseIntermediateOutputPath)'=='' ">.\obj</BaseIntermediateOutputPath>
<OutputPath Condition="'$(OutputPath)'=='' ">.\bin\</OutputPath>
<TargetFrameworkVersion>v4.5.2</TargetFrameworkVersion>
</PropertyGroup>
<PropertyGroup>
<SchemaVersion>2.0</SchemaVersion>
</PropertyGroup>
<Import Project="$(VSToolsPath)\DotNet\Microsoft.DotNet.targets" Condition="'$(VSToolsPath)' != ''" />
</Project>
</Project>

View File

@ -1,25 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Client
{
internal class PipelineConnection : IPipelineConnection
{
public IPipelineReader Input { get; }
public IPipelineWriter Output { get; }
public PipelineConnection(IPipelineReader input, IPipelineWriter output)
{
Input = input;
Output = output;
}
public void Dispose()
{
Input.Complete();
Output.Complete();
}
}
}

View File

@ -1,32 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Sockets.Client
{
// TODO: Move to System.IO.Pipelines
public static class PipelineFactoryExtensions
{
// TODO: Use a named tuple? Though there aren't really good names for these ... client/server? left/right?
public static Tuple<IPipelineConnection, IPipelineConnection> CreatePipelinePair(this PipelineFactory self)
{
// Create a pair of pipelines for "Server" and "Client"
var clientToServer = self.Create();
var serverToClient = self.Create();
// "Server" reads from clientToServer and writes to serverToClient
var server = new PipelineConnection(
input: clientToServer,
output: serverToClient);
// "Client" reads from serverToClient and writes to clientToServer
var client = new PipelineConnection(
input: serverToClient,
output: clientToServer);
return Tuple.Create((IPipelineConnection)server, (IPipelineConnection)client);
}
}
}

View File

@ -22,7 +22,8 @@
"compile": {
"include": [
"../Microsoft.AspNetCore.Sockets/IChannelConnection.cs",
"../Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs"
"../Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs",
"../Microsoft.AspNetCore.Sockets/Format.cs"
]
}
},

View File

@ -1,10 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection.Extensions;
@ -18,7 +14,6 @@ namespace Microsoft.Extensions.DependencyInjection
services.AddRouting();
services.TryAddSingleton<ConnectionManager>();
services.TryAddEnumerable(ServiceDescriptor.Singleton<IHostedService, SocketsApplicationLifetimeService>());
services.TryAddSingleton<PipelineFactory>();
services.TryAddSingleton<HttpConnectionDispatcher>();
return services;
}

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
@ -49,10 +48,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
var loggerFactory = CreateLogger();
using (var httpClient = _testServer.CreateClient())
using (var pipelineFactory = new PipelineFactory())
{
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory))
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
EnsureConnectionEstablished(connection);
@ -70,10 +69,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
const string originalMessage = "SignalR";
using (var httpClient = _testServer.CreateClient())
using (var pipelineFactory = new PipelineFactory())
{
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory))
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
EnsureConnectionEstablished(connection);
@ -91,10 +90,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
const string originalMessage = "SignalR";
using (var httpClient = _testServer.CreateClient())
using (var pipelineFactory = new PipelineFactory())
{
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory))
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
var tcs = new TaskCompletionSource<string>();
connection.On("Echo", new[] { typeof(string) }, a =>
@ -118,10 +117,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
var loggerFactory = CreateLogger();
using (var httpClient = _testServer.CreateClient())
using (var pipelineFactory = new PipelineFactory())
{
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory))
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
EnsureConnectionEstablished(connection);

View File

@ -0,0 +1,168 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Moq;
using Moq.Protected;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Client.Tests
{
public class ConnectionTests
{
[Fact]
public async Task ConnectionReturnsUrlUsedToStartTheConnection()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
var connectionUrl = new Uri("http://fakeuri.org/");
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
using (var connection = await Connection.ConnectAsync(connectionUrl, longPollingTransport, httpClient))
{
Assert.Equal(connectionUrl, connection.Url);
}
Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running));
}
}
[Fact]
public async Task TransportIsClosedWhenConnectionIsDisposed()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(longPollingTransport.Running.IsCompleted);
}
Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running));
}
}
[Fact]
public async Task CanSendData()
{
var sendTcs = new TaskCompletionSource<byte[]>();
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/send"))
{
sendTcs.SetResult(await request.Content.ReadAsByteArrayAsync());
}
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(connection.Input.Completion.IsCompleted);
var data = new byte[] { 1, 1, 2, 3, 5, 8 };
connection.Output.TryWrite(
new Message(ReadableBuffer.Create(data).Preserve(), Format.Binary));
Assert.Equal(sendTcs.Task, await Task.WhenAny(Task.Delay(1000), sendTcs.Task));
Assert.Equal(data, sendTcs.Task.Result);
}
}
[Fact]
public async Task CanReceiveData()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
var content = string.Empty;
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
{
content = "42";
}
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(content) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(connection.Input.Completion.IsCompleted);
await connection.Input.WaitToReadAsync();
Message message;
connection.Input.TryRead(out message);
using (message)
{
Assert.Equal("42", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray(), 0, message.Payload.Buffer.Length));
}
}
}
[Fact]
public async Task CanCloseConnection()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(connection.Input.Completion.IsCompleted);
connection.Output.TryComplete();
var whenAnyTask = Task.WhenAny(Task.Delay(1000), connection.Input.Completion);
// The channel needs to be drained for the Completion task to be completed
Message message;
while (!whenAnyTask.IsCompleted)
{
connection.Input.TryRead(out message);
message.Dispose();
}
Assert.Equal(connection.Input.Completion, await whenAnyTask);
}
}
}
}

View File

@ -0,0 +1,169 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
using Moq;
using Moq.Protected;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Client.Tests
{
public class LongPollingTransportTests
{
[Fact]
public async Task LongPollingTransportStopsPollAndSendLoopsWhenTransportDisposed()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
Task transportActiveTask;
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
var connectionToTransport = Channel.CreateUnbounded<Message>();
var transportToConnection = Channel.CreateUnbounded<Message>();
var channelConnection = new ChannelConnection<Message>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection);
transportActiveTask = longPollingTransport.Running;
Assert.False(transportActiveTask.IsCompleted);
}
Assert.Equal(transportActiveTask, await Task.WhenAny(Task.Delay(1000), transportActiveTask));
}
[Fact]
public async Task LongPollingTransportStopsWhenPollReceives204()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.NoContent) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
var connectionToTransport = Channel.CreateUnbounded<Message>();
var transportToConnection = Channel.CreateUnbounded<Message>();
var channelConnection = new ChannelConnection<Message>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection);
Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running));
Assert.True(transportToConnection.In.Completion.IsCompleted);
}
}
[Fact]
public async Task LongPollingTransportStopsWhenPollRequestFails()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
var connectionToTransport = Channel.CreateUnbounded<Message>();
var transportToConnection = Channel.CreateUnbounded<Message>();
var channelConnection = new ChannelConnection<Message>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection);
Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running));
var exception = await Assert.ThrowsAsync<HttpRequestException>(async () => await transportToConnection.In.Completion);
Assert.Contains(" 500 ", exception.Message);
}
}
[Fact]
public async Task LongPollingTransportStopsWhenSendRequestFails()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
var statusCode = request.RequestUri.AbsolutePath.EndsWith("send")
? HttpStatusCode.InternalServerError
: HttpStatusCode.OK;
return new HttpResponseMessage(statusCode) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
var connectionToTransport = Channel.CreateUnbounded<Message>();
var transportToConnection = Channel.CreateUnbounded<Message>();
var channelConnection = new ChannelConnection<Message>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection);
await connectionToTransport.Out.WriteAsync(new Message());
Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running));
await Assert.ThrowsAsync<HttpRequestException>(async () => await longPollingTransport.Running);
// The channel needs to be drained for the Completion task to be completed
Message message;
while (transportToConnection.In.TryRead(out message))
{
message.Dispose();
}
var exception = await Assert.ThrowsAsync<HttpRequestException>(async () => await transportToConnection.In.Completion);
Assert.Contains(" 500 ", exception.Message);
}
}
[Fact]
public async Task LongPollingTransportShutsDownWhenChannelIsClosed()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
var connectionToTransport = Channel.CreateUnbounded<Message>();
var transportToConnection = Channel.CreateUnbounded<Message>();
var channelConnection = new ChannelConnection<Message>(connectionToTransport, transportToConnection);
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection);
connectionToTransport.Out.Complete();
Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running));
Assert.Equal(connectionToTransport.In.Completion, await Task.WhenAny(Task.Delay(1000), connectionToTransport.In.Completion));
}
}
}
}

View File

@ -11,9 +11,11 @@
<BaseIntermediateOutputPath Condition="'$(BaseIntermediateOutputPath)'=='' ">.\obj</BaseIntermediateOutputPath>
<OutputPath Condition="'$(OutputPath)'=='' ">.\bin\</OutputPath>
</PropertyGroup>
<PropertyGroup>
<SchemaVersion>2.0</SchemaVersion>
</PropertyGroup>
<ItemGroup>
<Service Include="{82a7f48d-3b50-4b1e-b82e-3ada8210c358}" />
</ItemGroup>
<Import Project="$(VSToolsPath)\DotNet\Microsoft.DotNet.targets" Condition="'$(VSToolsPath)' != ''" />
</Project>

View File

@ -4,7 +4,10 @@
},
"dependencies": {
"Microsoft.AspNetCore.Sockets.Client": "1.0.0-*",
"Microsoft.Extensions.Logging": "1.2.0-*",
"dotnet-test-xunit": "2.2.0-*",
"Moq": "4.6.36-*",
"xunit": "2.2.0-*"
},
"frameworks": {

View File

@ -1,11 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;