diff --git a/samples/ClientSample/HubSample.cs b/samples/ClientSample/HubSample.cs index 71e948678d..38115f782f 100644 --- a/samples/ClientSample/HubSample.cs +++ b/samples/ClientSample/HubSample.cs @@ -28,11 +28,11 @@ namespace ClientSample var logger = loggerFactory.CreateLogger(); using (var httpClient = new HttpClient(new LoggingMessageHandler(loggerFactory, new HttpClientHandler()))) - using (var pipelineFactory = new PipelineFactory()) { logger.LogInformation("Connecting to {0}", baseUrl); var transport = new LongPollingTransport(httpClient, loggerFactory); - using (var connection = await HubConnection.ConnectAsync(new Uri(baseUrl), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory)) + using (var connection = await HubConnection.ConnectAsync(new Uri(baseUrl), + new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { logger.LogInformation("Connected to {0}", baseUrl); diff --git a/samples/ClientSample/RawSample.cs b/samples/ClientSample/RawSample.cs index 0e33084633..7ee6bec583 100644 --- a/samples/ClientSample/RawSample.cs +++ b/samples/ClientSample/RawSample.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.Logging; @@ -27,11 +28,10 @@ namespace ClientSample var logger = loggerFactory.CreateLogger(); using (var httpClient = new HttpClient(new LoggingMessageHandler(loggerFactory, new HttpClientHandler()))) - using (var pipelineFactory = new PipelineFactory()) { logger.LogInformation("Connecting to {0}", baseUrl); var transport = new LongPollingTransport(httpClient, loggerFactory); - using (var connection = await Connection.ConnectAsync(new Uri(baseUrl), transport, httpClient, pipelineFactory, loggerFactory)) + using (var connection = await Connection.ConnectAsync(new Uri(baseUrl), transport, httpClient, loggerFactory)) { logger.LogInformation("Connected to {0}", baseUrl); @@ -44,8 +44,10 @@ namespace ClientSample }; // Ready to start the loops - var receive = StartReceiving(loggerFactory.CreateLogger("ReceiveLoop"), connection, cts.Token); - var send = StartSending(loggerFactory.CreateLogger("SendLoop"), connection, cts.Token); + var receive = + StartReceiving(loggerFactory.CreateLogger("ReceiveLoop"), connection, cts.Token).ContinueWith(_ => cts.Cancel()); + var send = + StartSending(loggerFactory.CreateLogger("SendLoop"), connection, cts.Token).ContinueWith(_ => cts.Cancel()); await Task.WhenAll(receive, send); } @@ -60,7 +62,9 @@ namespace ClientSample var line = Console.ReadLine(); logger.LogInformation("Sending: {0}", line); - await connection.Output.WriteAsync(Encoding.UTF8.GetBytes(line)); + await connection.Output.WriteAsync(new Message( + ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(), + Format.Text)); } logger.LogInformation("Send loop terminated"); } @@ -68,30 +72,31 @@ namespace ClientSample private static async Task StartReceiving(ILogger logger, Connection connection, CancellationToken cancellationToken) { logger.LogInformation("Receive loop starting"); - using (cancellationToken.Register(() => connection.Input.Complete())) + try { - while (!cancellationToken.IsCancellationRequested) + while (await connection.Input.WaitToReadAsync(cancellationToken)) { - var result = await connection.Input.ReadAsync(); - var buffer = result.Buffer; - try + Message message; + if (!connection.Input.TryRead(out message)) { - if (!buffer.IsEmpty) - { - var message = Encoding.UTF8.GetString(buffer.ToArray()); - logger.LogInformation("Received: {0}", message); - } + continue; } - finally + + using (message) { - connection.Input.Advance(buffer.End); - } - if (result.IsCompleted) - { - break; + logger.LogInformation("Received: {0}", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray())); } } } + catch (OperationCanceledException) + { + logger.LogInformation("Connection is closing"); + } + catch (Exception ex) + { + logger.LogError(0, ex, "Connection terminated due to an exception"); + } + logger.LogInformation("Receive loop terminated"); } } diff --git a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs index 1ef2a6d811..e014769f69 100644 --- a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs +++ b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs @@ -1,13 +1,10 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Collections.Generic; using System.IO.Pipelines; -using System.Linq; using System.Text; using System.Threading.Tasks; -using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Sockets; namespace SocketsSample.EndPoints diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index 16777390fe..966a311eff 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -11,6 +11,7 @@ using System.Linq; using System.Net.Http; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.Logging; @@ -19,7 +20,6 @@ namespace Microsoft.AspNetCore.SignalR.Client public class HubConnection : IDisposable { private readonly Task _reader; - private readonly Stream _stream; private readonly ILogger _logger; private readonly Connection _connection; private readonly IInvocationAdapter _adapter; @@ -42,13 +42,11 @@ namespace Microsoft.AspNetCore.SignalR.Client { _binder = new HubBinder(this); _connection = connection; - _stream = connection.GetStream(); _adapter = adapter; _logger = logger; _reader = ReceiveMessages(_readerCts.Token); - Completion = _connection.Output.Writing.ContinueWith( - t => Shutdown(t)).Unwrap(); + Completion = _connection.Input.Completion.ContinueWith(t => Shutdown(t)).Unwrap(); } // TODO: Client return values/tasks? @@ -98,9 +96,21 @@ namespace Microsoft.AspNetCore.SignalR.Client _logger.LogTrace("Invocation #{0}: {1} {2}({3})", descriptor.Id, returnType.FullName, methodName, argsList); } - // Write the invocation to the stream + var ms = new MemoryStream(); + await _adapter.WriteMessageAsync(descriptor, ms, cancellationToken); + _logger.LogInformation("Sending Invocation #{0}", descriptor.Id); - await _adapter.WriteMessageAsync(descriptor, _stream, cancellationToken); + + // TODO: Format.Text - who, where and when decides about the format of outgoing messages + var message = new Message(ReadableBuffer.Create(ms.ToArray()).Preserve(), Format.Text); + while (await _connection.Output.WaitToWriteAsync()) + { + if (_connection.Output.TryWrite(message)) + { + break; + } + } + _logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id); // Return the completion task. It will be completed by ReceiveMessages when the response is received. @@ -114,12 +124,12 @@ namespace Microsoft.AspNetCore.SignalR.Client } // TODO: Clean up the API here. Negotiation of format would be better than providing an adapter instance. Similarly, we should not require a logger factory - public static Task ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, PipelineFactory pipelineFactory, ILoggerFactory loggerFactory) => ConnectAsync(url, adapter, transport, new HttpClient(), pipelineFactory, loggerFactory); + public static Task ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, adapter, transport, new HttpClient(), loggerFactory); - public static async Task ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, HttpClient httpClient, PipelineFactory pipelineFactory, ILoggerFactory loggerFactory) + public static async Task ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory) { // Connect the underlying connection - var connection = await Connection.ConnectAsync(url, transport, httpClient, pipelineFactory, loggerFactory); + var connection = await Connection.ConnectAsync(url, transport, httpClient, loggerFactory); // Create the RPC connection wrapper return new HubConnection(connection, adapter, loggerFactory.CreateLogger()); @@ -132,30 +142,38 @@ namespace Microsoft.AspNetCore.SignalR.Client _logger.LogTrace("Beginning receive loop"); try { - while (!cancellationToken.IsCancellationRequested) + while (await _connection.Input.WaitToReadAsync(cancellationToken)) { - // This is a little odd... we want to remove the InvocationRequest once and only once so we pull it out in the callback, - // and stash it here because we know the callback will have finished before the end of the await. - var message = await _adapter.ReadMessageAsync(_stream, _binder, cancellationToken); + Message incomingMessage; + while (_connection.Input.TryRead(out incomingMessage)) + { - var invocationDescriptor = message as InvocationDescriptor; - if (invocationDescriptor != null) - { - DispatchInvocation(invocationDescriptor, cancellationToken); - } - else - { - var invocationResultDescriptor = message as InvocationResultDescriptor; - if (invocationResultDescriptor != null) + InvocationMessage message; + using (incomingMessage) { - InvocationRequest irq; - lock (_pendingCallsLock) + message = await _adapter.ReadMessageAsync( + new MemoryStream(incomingMessage.Payload.Buffer.ToArray()), _binder, cancellationToken); + } + + var invocationDescriptor = message as InvocationDescriptor; + if (invocationDescriptor != null) + { + DispatchInvocation(invocationDescriptor, cancellationToken); + } + else + { + var invocationResultDescriptor = message as InvocationResultDescriptor; + if (invocationResultDescriptor != null) { - _connectionActive.Token.ThrowIfCancellationRequested(); - irq = _pendingCalls[invocationResultDescriptor.Id]; - _pendingCalls.Remove(invocationResultDescriptor.Id); + InvocationRequest irq; + lock (_pendingCallsLock) + { + _connectionActive.Token.ThrowIfCancellationRequested(); + irq = _pendingCalls[invocationResultDescriptor.Id]; + _pendingCalls.Remove(invocationResultDescriptor.Id); + } + DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken); } - DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken); } } } diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs index 0d5b20aee7..bceb79c04f 100644 --- a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.Options; diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs index cf6cbb40c6..12e2472381 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs @@ -2,56 +2,45 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Net.Http; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Client { - public class Connection : IPipelineConnection + public class Connection : IChannelConnection { - private IPipelineConnection _consumerPipe; + private IChannelConnection _transportChannel; private ITransport _transport; private readonly ILogger _logger; public Uri Url { get; } // TODO: Review. This is really only designed to be used from ConnectAsync - private Connection(Uri url, ITransport transport, IPipelineConnection consumerPipe, ILogger logger) + private Connection(Uri url, ITransport transport, IChannelConnection transportChannel, ILogger logger) { Url = url; _logger = logger; _transport = transport; - _consumerPipe = consumerPipe; - - _consumerPipe.Output.Writing.ContinueWith(t => - { - if (t.IsFaulted) - { - _consumerPipe.Input.Complete(t.Exception); - } - - return t; - }); + _transportChannel = transportChannel; } - public IPipelineReader Input => _consumerPipe.Input; - public IPipelineWriter Output => _consumerPipe.Output; + public ReadableChannel Input => _transportChannel.Input; + public WritableChannel Output => _transportChannel.Output; public void Dispose() { - _consumerPipe.Dispose(); _transport.Dispose(); } - // TODO: More overloads. PipelineFactory should be optional but someone needs to dispose the pool, if we're OK with it being the GC, then this is easy. - public static Task ConnectAsync(Uri url, ITransport transport, PipelineFactory pipelineFactory) => ConnectAsync(url, transport, new HttpClient(), pipelineFactory, NullLoggerFactory.Instance); - public static Task ConnectAsync(Uri url, ITransport transport, PipelineFactory pipelineFactory, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, new HttpClient(), pipelineFactory, loggerFactory); - public static Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, PipelineFactory pipelineFactory) => ConnectAsync(url, transport, httpClient, pipelineFactory, NullLoggerFactory.Instance); + public static Task ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, new HttpClient(), NullLoggerFactory.Instance); + public static Task ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, new HttpClient(), loggerFactory); + public static Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, NullLoggerFactory.Instance); - public static async Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, PipelineFactory pipelineFactory, ILoggerFactory loggerFactory) + public static async Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory) { if (url == null) { @@ -68,11 +57,6 @@ namespace Microsoft.AspNetCore.Sockets.Client throw new ArgumentNullException(nameof(httpClient)); } - if (pipelineFactory == null) - { - throw new ArgumentNullException(nameof(pipelineFactory)); - } - if (loggerFactory == null) { throw new ArgumentNullException(nameof(loggerFactory)); @@ -97,12 +81,16 @@ namespace Microsoft.AspNetCore.Sockets.Client var connectedUrl = Utils.AppendQueryString(url, "id=" + connectionId); - var pair = pipelineFactory.CreatePipelinePair(); + var applicationToTransport = Channel.CreateUnbounded(); + var transportToApplication = Channel.CreateUnbounded(); + var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); + var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); + // Start the transport, giving it one end of the pipeline try { - await transport.StartAsync(connectedUrl, pair.Item1); + await transport.StartAsync(connectedUrl, applicationSide); } catch (Exception ex) { @@ -111,7 +99,7 @@ namespace Microsoft.AspNetCore.Sockets.Client } // Create the connection, giving it the other end of the pipeline - return new Connection(url, transport, pair.Item2, logger); + return new Connection(url, transport, transportSide, logger); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/ITransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/ITransport.cs index 1424714e47..a4a2ce8bec 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/ITransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/ITransport.cs @@ -2,13 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Threading.Tasks; namespace Microsoft.AspNetCore.Sockets.Client { public interface ITransport : IDisposable { - Task StartAsync(Uri url, IPipelineConnection pipeline); + Task StartAsync(Uri url, IChannelConnection application); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs index c4c45ca20f..6a6af78ab3 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.IO.Pipelines; using System.Net; using System.Net.Http; @@ -20,12 +21,10 @@ namespace Microsoft.AspNetCore.Sockets.Client private readonly HttpClient _httpClient; private readonly ILogger _logger; - private readonly CancellationTokenSource _senderCts = new CancellationTokenSource(); - private readonly CancellationTokenSource _pollCts = new CancellationTokenSource(); - - private IPipelineConnection _pipeline; + private IChannelConnection _application; private Task _sender; private Task _poller; + private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); public Task Running { get; private set; } @@ -37,26 +36,21 @@ namespace Microsoft.AspNetCore.Sockets.Client public void Dispose() { - _senderCts.Cancel(); - _pollCts.Cancel(); - _pipeline?.Dispose(); + _transportCts.Cancel(); } - public Task StartAsync(Uri url, IPipelineConnection pipeline) + public Task StartAsync(Uri url, IChannelConnection application) { - _pipeline = pipeline; - - // Schedule shutdown of the poller when the output is closed - pipeline.Output.Writing.ContinueWith(_ => - { - _pollCts.Cancel(); - return TaskCache.CompletedTask; - }); + _application = application; // Start sending and polling - _poller = Poll(Utils.AppendPath(url, "poll"), _pollCts.Token); - _sender = SendMessages(Utils.AppendPath(url, "send"), _senderCts.Token); - Running = Task.WhenAll(_sender, _poller); + _poller = Poll(Utils.AppendPath(url, "poll"), _transportCts.Token); + _sender = SendMessages(Utils.AppendPath(url, "send"), _transportCts.Token); + + Running = Task.WhenAll(_sender, _poller).ContinueWith(t => { + _application.Output.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); + return t; + }).Unwrap(); return TaskCache.CompletedTask; } @@ -80,64 +74,71 @@ namespace Microsoft.AspNetCore.Sockets.Client } else { - // Write the data to the output - var buffer = _pipeline.Output.Alloc(); - var stream = new WriteableBufferStream(buffer); - await response.Content.CopyToAsync(stream); - await buffer.FlushAsync(); + var ms = new MemoryStream(); + await response.Content.CopyToAsync(ms); + var message = new Message(ReadableBuffer.Create(ms.ToArray()).Preserve(), Format.Text); + + while (await _application.Output.WaitToWriteAsync(cancellationToken)) + { + if (_application.Output.TryWrite(message)) + { + break; + } + } } } - - // Polling complete - _pipeline.Output.Complete(); + } + catch (OperationCanceledException) + { + // transport is being closed } catch (Exception ex) { - // Shut down the output pipeline and log _logger.LogError("Error while polling '{0}': {1}", pollUrl, ex); - _pipeline.Output.Complete(ex); - _pipeline.Input.Complete(ex); + throw; + } + finally + { + // Make sure the send loop is terminated + _transportCts.Cancel(); } } private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken) { - using (cancellationToken.Register(() => _pipeline.Input.Complete())) + try { - try + while (await _application.Input.WaitToReadAsync(cancellationToken)) { - while (!cancellationToken.IsCancellationRequested) + Message message; + while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out message)) { - var result = await _pipeline.Input.ReadAsync(); - var buffer = result.Buffer; - if (buffer.IsEmpty || result.IsCompleted) + using (message) { - // No more data to send - break; + var request = new HttpRequestMessage(HttpMethod.Post, sendUrl); + request.Headers.UserAgent.Add(DefaultUserAgentHeader); + request.Content = new ReadableBufferContent(message.Payload.Buffer); + + var response = await _httpClient.SendAsync(request); + response.EnsureSuccessStatusCode(); } - - // Create a message to send - var message = new HttpRequestMessage(HttpMethod.Post, sendUrl); - message.Headers.UserAgent.Add(DefaultUserAgentHeader); - message.Content = new ReadableBufferContent(buffer); - - // Send it - var response = await _httpClient.SendAsync(message); - response.EnsureSuccessStatusCode(); - - _pipeline.Input.Advance(buffer.End); } - - // Sending complete - _pipeline.Input.Complete(); - } - catch (Exception ex) - { - // Shut down the input pipeline and log - _logger.LogError("Error while sending to '{0}': {1}", sendUrl, ex); - _pipeline.Input.Complete(ex); } } + catch (OperationCanceledException) + { + // transport is being closed + } + catch (Exception ex) + { + _logger.LogError("Error while sending to '{0}': {1}", sendUrl, ex); + throw; + } + finally + { + // Make sure the poll loop is terminated + _transportCts.Cancel(); + } } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Message.cs b/src/Microsoft.AspNetCore.Sockets.Client/Message.cs new file mode 100644 index 0000000000..0e45869608 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Client/Message.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; + +namespace Microsoft.AspNetCore.Sockets.Client +{ + public struct Message : IDisposable + { + public Format MessageFormat { get; } + public PreservedBuffer Payload { get; } + + public Message(PreservedBuffer payload, Format messageFormat) + { + MessageFormat = messageFormat; + Payload = payload; + } + + public void Dispose() + { + Payload.Dispose(); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.xproj b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.xproj index be598871ad..611f959153 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.xproj +++ b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.xproj @@ -4,18 +4,16 @@ 14.0 $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) - 623fd372-36de-41a9-a564-f6040d570dbd - Microsoft.AspNetCore.SignalR.Client + Microsoft.AspNetCore.Sockets.Client .\obj .\bin\ v4.5.2 - 2.0 - + \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Client/PipelineConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client/PipelineConnection.cs deleted file mode 100644 index f0f2a05b5f..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Client/PipelineConnection.cs +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.IO.Pipelines; - -namespace Microsoft.AspNetCore.Sockets.Client -{ - internal class PipelineConnection : IPipelineConnection - { - public IPipelineReader Input { get; } - public IPipelineWriter Output { get; } - - public PipelineConnection(IPipelineReader input, IPipelineWriter output) - { - Input = input; - Output = output; - } - - public void Dispose() - { - Input.Complete(); - Output.Complete(); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets.Client/PipelineFactoryExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client/PipelineFactoryExtensions.cs deleted file mode 100644 index 946db002b5..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Client/PipelineFactoryExtensions.cs +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.IO.Pipelines; - -namespace Microsoft.AspNetCore.Sockets.Client -{ - // TODO: Move to System.IO.Pipelines - public static class PipelineFactoryExtensions - { - // TODO: Use a named tuple? Though there aren't really good names for these ... client/server? left/right? - public static Tuple CreatePipelinePair(this PipelineFactory self) - { - // Create a pair of pipelines for "Server" and "Client" - var clientToServer = self.Create(); - var serverToClient = self.Create(); - - // "Server" reads from clientToServer and writes to serverToClient - var server = new PipelineConnection( - input: clientToServer, - output: serverToClient); - - // "Client" reads from serverToClient and writes to clientToServer - var client = new PipelineConnection( - input: serverToClient, - output: clientToServer); - - return Tuple.Create((IPipelineConnection)server, (IPipelineConnection)client); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets.Client/BufferContent.cs b/src/Microsoft.AspNetCore.Sockets.Client/ReadableBufferContent.cs similarity index 100% rename from src/Microsoft.AspNetCore.Sockets.Client/BufferContent.cs rename to src/Microsoft.AspNetCore.Sockets.Client/ReadableBufferContent.cs diff --git a/src/Microsoft.AspNetCore.Sockets.Client/project.json b/src/Microsoft.AspNetCore.Sockets.Client/project.json index 678d6e13be..28e8122483 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/project.json +++ b/src/Microsoft.AspNetCore.Sockets.Client/project.json @@ -22,7 +22,8 @@ "compile": { "include": [ "../Microsoft.AspNetCore.Sockets/IChannelConnection.cs", - "../Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs" + "../Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs", + "../Microsoft.AspNetCore.Sockets/Format.cs" ] } }, diff --git a/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs index be5d93aec1..5c0b250ab6 100644 --- a/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/SocketsDependencyInjectionExtensions.cs @@ -1,10 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; -using System.IO.Pipelines; -using System.Linq; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -18,7 +14,6 @@ namespace Microsoft.Extensions.DependencyInjection services.AddRouting(); services.TryAddSingleton(); services.TryAddEnumerable(ServiceDescriptor.Singleton()); - services.TryAddSingleton(); services.TryAddSingleton(); return services; } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 40beb84601..71c798b038 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; @@ -49,10 +48,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests var loggerFactory = CreateLogger(); using (var httpClient = _testServer.CreateClient()) - using (var pipelineFactory = new PipelineFactory()) { var transport = new LongPollingTransport(httpClient, loggerFactory); - using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory)) + using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), + new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { EnsureConnectionEstablished(connection); @@ -70,10 +69,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests const string originalMessage = "SignalR"; using (var httpClient = _testServer.CreateClient()) - using (var pipelineFactory = new PipelineFactory()) { var transport = new LongPollingTransport(httpClient, loggerFactory); - using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory)) + using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), + new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { EnsureConnectionEstablished(connection); @@ -91,10 +90,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests const string originalMessage = "SignalR"; using (var httpClient = _testServer.CreateClient()) - using (var pipelineFactory = new PipelineFactory()) { var transport = new LongPollingTransport(httpClient, loggerFactory); - using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory)) + using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), + new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { var tcs = new TaskCompletionSource(); connection.On("Echo", new[] { typeof(string) }, a => @@ -118,10 +117,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests var loggerFactory = CreateLogger(); using (var httpClient = _testServer.CreateClient()) - using (var pipelineFactory = new PipelineFactory()) { var transport = new LongPollingTransport(httpClient, loggerFactory); - using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, pipelineFactory, loggerFactory)) + using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), + new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { EnsureConnectionEstablished(connection); diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs new file mode 100644 index 0000000000..78c9b7629a --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs @@ -0,0 +1,168 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Moq; +using Moq.Protected; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Client.Tests +{ + public class ConnectionTests + { + [Fact] + public async Task ConnectionReturnsUrlUsedToStartTheConnection() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + var connectionUrl = new Uri("http://fakeuri.org/"); + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + { + using (var connection = await Connection.ConnectAsync(connectionUrl, longPollingTransport, httpClient)) + { + Assert.Equal(connectionUrl, connection.Url); + } + + Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running)); + } + } + + [Fact] + public async Task TransportIsClosedWhenConnectionIsDisposed() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + { + using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + { + Assert.False(longPollingTransport.Running.IsCompleted); + } + + Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running)); + } + } + + [Fact] + public async Task CanSendData() + { + var sendTcs = new TaskCompletionSource(); + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (request.RequestUri.AbsolutePath.EndsWith("/send")) + { + sendTcs.SetResult(await request.Content.ReadAsByteArrayAsync()); + } + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + { + Assert.False(connection.Input.Completion.IsCompleted); + + var data = new byte[] { 1, 1, 2, 3, 5, 8 }; + connection.Output.TryWrite( + new Message(ReadableBuffer.Create(data).Preserve(), Format.Binary)); + + Assert.Equal(sendTcs.Task, await Task.WhenAny(Task.Delay(1000), sendTcs.Task)); + Assert.Equal(data, sendTcs.Task.Result); + } + } + + [Fact] + public async Task CanReceiveData() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + + var content = string.Empty; + if (request.RequestUri.AbsolutePath.EndsWith("/poll")) + { + content = "42"; + } + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(content) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + { + Assert.False(connection.Input.Completion.IsCompleted); + + await connection.Input.WaitToReadAsync(); + Message message; + connection.Input.TryRead(out message); + using (message) + { + Assert.Equal("42", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray(), 0, message.Payload.Buffer.Length)); + } + } + } + + [Fact] + public async Task CanCloseConnection() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + { + Assert.False(connection.Input.Completion.IsCompleted); + connection.Output.TryComplete(); + + var whenAnyTask = Task.WhenAny(Task.Delay(1000), connection.Input.Completion); + + // The channel needs to be drained for the Completion task to be completed + Message message; + while (!whenAnyTask.IsCompleted) + { + connection.Input.TryRead(out message); + message.Dispose(); + } + + Assert.Equal(connection.Input.Completion, await whenAnyTask); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/LongPollingTransportTests.cs new file mode 100644 index 0000000000..8b067637b1 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/LongPollingTransportTests.cs @@ -0,0 +1,169 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Sockets.Internal; +using Microsoft.Extensions.Logging; +using Moq; +using Moq.Protected; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Client.Tests +{ + public class LongPollingTransportTests + { + [Fact] + public async Task LongPollingTransportStopsPollAndSendLoopsWhenTransportDisposed() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + Task transportActiveTask; + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection); + + transportActiveTask = longPollingTransport.Running; + + Assert.False(transportActiveTask.IsCompleted); + } + + Assert.Equal(transportActiveTask, await Task.WhenAny(Task.Delay(1000), transportActiveTask)); + } + + [Fact] + public async Task LongPollingTransportStopsWhenPollReceives204() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.NoContent) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection); + + Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running)); + Assert.True(transportToConnection.In.Completion.IsCompleted); + } + } + + [Fact] + public async Task LongPollingTransportStopsWhenPollRequestFails() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection); + + Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running)); + var exception = await Assert.ThrowsAsync(async () => await transportToConnection.In.Completion); + Assert.Contains(" 500 ", exception.Message); + } + } + + [Fact] + public async Task LongPollingTransportStopsWhenSendRequestFails() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + var statusCode = request.RequestUri.AbsolutePath.EndsWith("send") + ? HttpStatusCode.InternalServerError + : HttpStatusCode.OK; + return new HttpResponseMessage(statusCode) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection); + + await connectionToTransport.Out.WriteAsync(new Message()); + + Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running)); + + await Assert.ThrowsAsync(async () => await longPollingTransport.Running); + + // The channel needs to be drained for the Completion task to be completed + Message message; + while (transportToConnection.In.TryRead(out message)) + { + message.Dispose(); + } + + var exception = await Assert.ThrowsAsync(async () => await transportToConnection.In.Completion); + Assert.Contains(" 500 ", exception.Message); + } + } + + [Fact] + public async Task LongPollingTransportShutsDownWhenChannelIsClosed() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection); + + connectionToTransport.Out.Complete(); + + Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running)); + Assert.Equal(connectionToTransport.In.Completion, await Task.WhenAny(Task.Delay(1000), connectionToTransport.In.Completion)); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/Microsoft.AspNetCore.Sockets.Client.Tests.xproj b/test/Microsoft.AspNetCore.Sockets.Client.Tests/Microsoft.AspNetCore.Sockets.Client.Tests.xproj index 0291384bd2..e89145f6e5 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/Microsoft.AspNetCore.Sockets.Client.Tests.xproj +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/Microsoft.AspNetCore.Sockets.Client.Tests.xproj @@ -11,9 +11,11 @@ .\obj .\bin\ - 2.0 + + + \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/project.json b/test/Microsoft.AspNetCore.Sockets.Client.Tests/project.json index 6a20cb4b1c..277a80267f 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/project.json +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/project.json @@ -4,7 +4,10 @@ }, "dependencies": { + "Microsoft.AspNetCore.Sockets.Client": "1.0.0-*", + "Microsoft.Extensions.Logging": "1.2.0-*", "dotnet-test-xunit": "2.2.0-*", + "Moq": "4.6.36-*", "xunit": "2.2.0-*" }, "frameworks": { diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs index 035c81db65..1bcdfcaa4c 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs @@ -1,11 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; using System.IO; using System.IO.Pipelines; -using System.Linq; using System.Text; using System.Threading.Tasks; using System.Threading.Tasks.Channels;