diff --git a/samples/ClientSample/HubSample.cs b/samples/ClientSample/HubSample.cs index 38115f782f..429c202b80 100644 --- a/samples/ClientSample/HubSample.cs +++ b/samples/ClientSample/HubSample.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Net.Http; using System.Threading; using System.Threading.Tasks; diff --git a/samples/ClientSample/Program.cs b/samples/ClientSample/Program.cs index 328ed29c4b..951fd17f38 100644 --- a/samples/ClientSample/Program.cs +++ b/samples/ClientSample/Program.cs @@ -1,16 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.IO.Pipelines; -using System.Net.Http; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR; -using Microsoft.AspNetCore.SignalR.Client; -using Microsoft.AspNetCore.Sockets.Client; -using Microsoft.Extensions.Logging; - namespace ClientSample { public class Program diff --git a/samples/ClientSample/RawSample.cs b/samples/ClientSample/RawSample.cs index fc402c63c3..8ba3487f4c 100644 --- a/samples/ClientSample/RawSample.cs +++ b/samples/ClientSample/RawSample.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Net.Http; using System.Text; using System.Threading; @@ -62,9 +61,7 @@ namespace ClientSample var line = Console.ReadLine(); logger.LogInformation("Sending: {0}", line); - await connection.Output.WriteAsync(new Message( - ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(), - Format.Text)); + await connection.SendAsync(Encoding.UTF8.GetBytes("Hello World"), Format.Text); } logger.LogInformation("Send loop terminated"); } @@ -74,18 +71,10 @@ namespace ClientSample logger.LogInformation("Receive loop starting"); try { - while (await connection.Input.WaitToReadAsync(cancellationToken)) + var receiveData = new ReceiveData(); + while (await connection.ReceiveAsync(receiveData, cancellationToken)) { - Message message; - if (!connection.Input.TryRead(out message)) - { - continue; - } - - using (message) - { - logger.LogInformation("Received: {0}", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray())); - } + logger.LogInformation($"Received: {Encoding.UTF8.GetString(receiveData.Data)}"); } } catch (OperationCanceledException) diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index 966a311eff..341bb08443 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -6,7 +6,6 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; -using System.IO.Pipelines; using System.Linq; using System.Net.Http; using System.Threading; @@ -36,8 +35,6 @@ namespace Microsoft.AspNetCore.SignalR.Client private int _nextId = 0; - public Task Completion { get; } - private HubConnection(Connection connection, IInvocationAdapter adapter, ILogger logger) { _binder = new HubBinder(this); @@ -46,7 +43,6 @@ namespace Microsoft.AspNetCore.SignalR.Client _logger = logger; _reader = ReceiveMessages(_readerCts.Token); - Completion = _connection.Input.Completion.ContinueWith(t => Shutdown(t)).Unwrap(); } // TODO: Client return values/tasks? @@ -102,14 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client _logger.LogInformation("Sending Invocation #{0}", descriptor.Id); // TODO: Format.Text - who, where and when decides about the format of outgoing messages - var message = new Message(ReadableBuffer.Create(ms.ToArray()).Preserve(), Format.Text); - while (await _connection.Output.WaitToWriteAsync()) - { - if (_connection.Output.TryWrite(message)) - { - break; - } - } + await _connection.SendAsync(ms.ToArray(), Format.Text, cancellationToken); _logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id); @@ -142,41 +131,35 @@ namespace Microsoft.AspNetCore.SignalR.Client _logger.LogTrace("Beginning receive loop"); try { - while (await _connection.Input.WaitToReadAsync(cancellationToken)) + ReceiveData receiveData = new ReceiveData(); + while (await _connection.ReceiveAsync(receiveData, cancellationToken)) { - Message incomingMessage; - while (_connection.Input.TryRead(out incomingMessage)) + var message + = await _adapter.ReadMessageAsync(new MemoryStream(receiveData.Data), _binder, cancellationToken); + + switch (message) { - - InvocationMessage message; - using (incomingMessage) - { - message = await _adapter.ReadMessageAsync( - new MemoryStream(incomingMessage.Payload.Buffer.ToArray()), _binder, cancellationToken); - } - - var invocationDescriptor = message as InvocationDescriptor; - if (invocationDescriptor != null) - { + case InvocationDescriptor invocationDescriptor: DispatchInvocation(invocationDescriptor, cancellationToken); - } - else - { - var invocationResultDescriptor = message as InvocationResultDescriptor; - if (invocationResultDescriptor != null) + break; + case InvocationResultDescriptor invocationResultDescriptor: + InvocationRequest irq; + lock (_pendingCallsLock) { - InvocationRequest irq; - lock (_pendingCallsLock) - { - _connectionActive.Token.ThrowIfCancellationRequested(); - irq = _pendingCalls[invocationResultDescriptor.Id]; - _pendingCalls.Remove(invocationResultDescriptor.Id); - } - DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken); + _connectionActive.Token.ThrowIfCancellationRequested(); + irq = _pendingCalls[invocationResultDescriptor.Id]; + _pendingCalls.Remove(invocationResultDescriptor.Id); } - } + DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken); + break; } } + Shutdown(); + } + catch (Exception ex) + { + Shutdown(ex); + throw; } finally { @@ -184,12 +167,12 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private Task Shutdown(Task completion) + private void Shutdown(Exception ex = null) { _logger.LogTrace("Shutting down connection"); - if (completion.IsFaulted) + if (ex != null) { - _logger.LogError("Connection is shutting down due to an error: {0}", completion.Exception.InnerException); + _logger.LogError("Connection is shutting down due to an error: {0}", ex); } lock (_pendingCallsLock) @@ -197,27 +180,23 @@ namespace Microsoft.AspNetCore.SignalR.Client _connectionActive.Cancel(); foreach (var call in _pendingCalls.Values) { - if (!completion.IsFaulted) + if (ex != null) { call.Completion.TrySetCanceled(); } else { - call.Completion.TrySetException(completion.Exception.InnerException); + call.Completion.TrySetException(ex); } } _pendingCalls.Clear(); } - - // Return the completion anyway - return completion; } private void DispatchInvocation(InvocationDescriptor invocationDescriptor, CancellationToken cancellationToken) { // Find the handler - InvocationHandler handler; - if (!_handlers.TryGetValue(invocationDescriptor.Method, out handler)) + if (!_handlers.TryGetValue(invocationDescriptor.Method, out InvocationHandler handler)) { _logger.LogWarning("Failed to find handler for '{0}' method", invocationDescriptor.Method); } @@ -271,8 +250,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public Type GetReturnType(string invocationId) { - InvocationRequest irq; - if (!_connection._pendingCalls.TryGetValue(invocationId, out irq)) + if (!_connection._pendingCalls.TryGetValue(invocationId, out InvocationRequest irq)) { _connection._logger.LogError("Unsolicited response received for invocation '{0}'", invocationId); return null; @@ -282,8 +260,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public Type[] GetParameterTypes(string methodName) { - InvocationHandler handler; - if (!_connection._handlers.TryGetValue(methodName, out handler)) + if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler)) { _connection._logger.LogWarning("Failed to find handler for '{0}' method", methodName); return Type.EmptyTypes; diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs index 41e5530e83..2fa57d2ac2 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs @@ -2,15 +2,17 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Net.Http; +using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; -using Microsoft.Extensions.Logging; using Microsoft.AspNetCore.Sockets.Internal; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Client { - public class Connection : IChannelConnection + public class Connection : IDisposable { private IChannelConnection _transportChannel; private ITransport _transport; @@ -18,7 +20,6 @@ namespace Microsoft.AspNetCore.Sockets.Client public Uri Url { get; } - // TODO: Review. This is really only designed to be used from ConnectAsync private Connection(Uri url, ITransport transport, IChannelConnection transportChannel, ILogger logger) { Url = url; @@ -28,17 +29,92 @@ namespace Microsoft.AspNetCore.Sockets.Client _transportChannel = transportChannel; } - public ReadableChannel Input => _transportChannel.Input; - public WritableChannel Output => _transportChannel.Output; + private ReadableChannel Input => _transportChannel.Input; + private WritableChannel Output => _transportChannel.Output; + + public Task ReceiveAsync(ReceiveData receiveData) + { + return ReceiveAsync(receiveData, CancellationToken.None); + } + + public async Task ReceiveAsync(ReceiveData receiveData, CancellationToken cancellationToken) + { + if (receiveData == null) + { + throw new ArgumentNullException(nameof(receiveData)); + } + + if (Input.Completion.IsCompleted) + { + throw new InvalidOperationException("Cannot receive messages when the connection is stopped."); + } + + try + { + while (await Input.WaitToReadAsync(cancellationToken)) + { + if (Input.TryRead(out Message message)) + { + using (message) + { + receiveData.Format = message.MessageFormat; + receiveData.Data = message.Payload.Buffer.ToArray(); + return true; + } + } + } + + await Input.Completion; + } + catch (OperationCanceledException) + { + // channel is being closed + } + catch (Exception ex) + { + Output.TryComplete(ex); + _logger.LogError("Error receiving message: {0}", ex); + throw; + } + + return false; + } + + public Task SendAsync(byte[] data, Format format) + { + return SendAsync(data, format, CancellationToken.None); + } + + public async Task SendAsync(byte[] data, Format format, CancellationToken cancellationToken) + { + var message = new Message(ReadableBuffer.Create(data).Preserve(), format); + + while (await Output.WaitToWriteAsync(cancellationToken)) + { + if (Output.TryWrite(message)) + { + return true; + } + } + + return false; + } + + public async Task StopAsync() + { + Output.TryComplete(); + await _transport.StopAsync(); + } public void Dispose() { + Output.TryComplete(); _transport.Dispose(); } - public static Task ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, new HttpClient(), NullLoggerFactory.Instance); - public static Task ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, new HttpClient(), loggerFactory); - public static Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, NullLoggerFactory.Instance); + public static Task ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, null, null); + public static Task ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, null, loggerFactory); + public static Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, null); public static async Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory) { @@ -47,39 +123,16 @@ namespace Microsoft.AspNetCore.Sockets.Client throw new ArgumentNullException(nameof(url)); } + // TODO: Once we have websocket transport we would be able to use it as the default transport if (transport == null) { - throw new ArgumentNullException(nameof(transport)); - } - - if (httpClient == null) - { - throw new ArgumentNullException(nameof(httpClient)); - } - - if (loggerFactory == null) - { - throw new ArgumentNullException(nameof(loggerFactory)); + throw new ArgumentNullException(nameof(url)); } + loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; var logger = loggerFactory.CreateLogger(); - var negotiateUrl = Utils.AppendPath(url, "negotiate"); - string connectionId; - try - { - // Get a connection ID from the server - logger.LogDebug("Establishing Connection at: {0}", negotiateUrl); - connectionId = await httpClient.GetStringAsync(negotiateUrl); - logger.LogDebug("Connection Id: {0}", connectionId); - } - catch (Exception ex) - { - logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex); - throw; - } - - var connectedUrl = Utils.AppendQueryString(url, "id=" + connectionId); + var connectUrl = await GetConnectUrl(url, httpClient, logger); var applicationToTransport = Channel.CreateUnbounded(); var transportToApplication = Channel.CreateUnbounded(); @@ -90,7 +143,7 @@ namespace Microsoft.AspNetCore.Sockets.Client // Start the transport, giving it one end of the pipeline try { - await transport.StartAsync(connectedUrl, applicationSide); + await transport.StartAsync(connectUrl, applicationSide); } catch (Exception ex) { @@ -101,5 +154,41 @@ namespace Microsoft.AspNetCore.Sockets.Client // Create the connection, giving it the other end of the pipeline return new Connection(url, transport, transportSide, logger); } + + private static async Task GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger) + { + var disposeHttpClient = httpClient == null; + httpClient = httpClient ?? new HttpClient(); + try + { + var connectionId = await GetConnectionId(url, httpClient, logger); + return Utils.AppendQueryString(url, "id=" + connectionId); + } + finally + { + if (disposeHttpClient) + { + httpClient.Dispose(); + } + } + } + + private static async Task GetConnectionId(Uri url, HttpClient httpClient, ILogger logger) + { + var negotiateUrl = Utils.AppendPath(url, "negotiate"); + try + { + // Get a connection ID from the server + logger.LogDebug("Establishing Connection at: {0}", negotiateUrl); + var connectionId = await httpClient.GetStringAsync(negotiateUrl); + logger.LogDebug("Connection Id: {0}", connectionId); + return connectionId; + } + catch (Exception ex) + { + logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex); + throw; + } + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/ITransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/ITransport.cs index a4a2ce8bec..604e6cc623 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/ITransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/ITransport.cs @@ -9,5 +9,6 @@ namespace Microsoft.AspNetCore.Sockets.Client public interface ITransport : IDisposable { Task StartAsync(Uri url, IChannelConnection application); + Task StopAsync(); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs index 6a6af78ab3..a4ec417d20 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs @@ -34,11 +34,6 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = loggerFactory.CreateLogger(); } - public void Dispose() - { - _transportCts.Cancel(); - } - public Task StartAsync(Uri url, IChannelConnection application) { _application = application; @@ -55,6 +50,17 @@ namespace Microsoft.AspNetCore.Sockets.Client return TaskCache.CompletedTask; } + public async Task StopAsync() + { + _transportCts.Cancel(); + await Running; + } + + public void Dispose() + { + _transportCts.Cancel(); + } + private async Task Poll(Uri pollUrl, CancellationToken cancellationToken) { try @@ -110,8 +116,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { while (await _application.Input.WaitToReadAsync(cancellationToken)) { - Message message; - while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out message)) + while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out Message message)) { using (message) { diff --git a/src/Microsoft.AspNetCore.Sockets.Client/ReceiveData.cs b/src/Microsoft.AspNetCore.Sockets.Client/ReceiveData.cs new file mode 100644 index 0000000000..75668fe11d --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Client/ReceiveData.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Sockets.Client +{ + public class ReceiveData + { + public byte[] Data { get; set; } + + public Format Format { get; set; } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index cde8d074c0..748e5802ee 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -53,8 +53,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { - EnsureConnectionEstablished(connection); - var result = await connection.Invoke("HelloWorld"); Assert.Equal("Hello World!", result); @@ -74,8 +72,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { - EnsureConnectionEstablished(connection); - var result = await connection.Invoke("Echo", originalMessage); Assert.Equal(originalMessage, result); @@ -95,8 +91,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { - EnsureConnectionEstablished(connection); - var result = await connection.Invoke("echo", originalMessage); Assert.Equal(originalMessage, result); @@ -122,8 +116,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests tcs.TrySetResult((string)a[0]); }); - EnsureConnectionEstablished(connection); - await connection.Invoke("CallEcho", originalMessage); var completed = await Task.WhenAny(Task.Delay(2000), tcs.Task); Assert.True(completed == tcs.Task, "Receive timed out!"); @@ -143,8 +135,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) { - EnsureConnectionEstablished(connection); - var ex = await Assert.ThrowsAnyAsync( async () => await connection.Invoke("!@#$%")); @@ -153,14 +143,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - private static void EnsureConnectionEstablished(HubConnection connection) - { - if (connection.Completion.IsCompleted) - { - connection.Completion.GetAwaiter().GetResult(); - } - } - public void Dispose() { _testServer.Dispose(); diff --git a/test/Microsoft.AspNetCore.SignalR.Test.Server/Microsoft.AspNetCore.SignalR.Test.Server.csproj b/test/Microsoft.AspNetCore.SignalR.Test.Server/Microsoft.AspNetCore.SignalR.Test.Server.csproj index f161a75653..93768cdc63 100644 --- a/test/Microsoft.AspNetCore.SignalR.Test.Server/Microsoft.AspNetCore.SignalR.Test.Server.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Test.Server/Microsoft.AspNetCore.SignalR.Test.Server.csproj @@ -27,4 +27,4 @@ - + \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 73c36285e1..4bb55a32c1 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -65,31 +65,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests var transport = new LongPollingTransport(httpClient, loggerFactory); using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo"), transport, httpClient, loggerFactory)) { - await connection.Output.WriteAsync(new Message( - ReadableBuffer.Create(Encoding.UTF8.GetBytes(message)).Preserve(), - Format.Text)); + await connection.SendAsync(Encoding.UTF8.GetBytes(message), Format.Text); - var received = await ReceiveMessage(connection).OrTimeout(); - Assert.Equal(message, received); + var receiveData = new ReceiveData(); + + Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout()); + Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data)); } } } - - private static async Task ReceiveMessage(ClientConnection connection) - { - Message message; - while (await connection.Input.WaitToReadAsync()) - { - if (connection.Input.TryRead(out message)) - { - using (message) - { - return Encoding.UTF8.GetString(message.Payload.Buffer.ToArray()); - } - } - } - - return null; - } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs index 78c9b7629a..efc4e1b50e 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Net; using System.Net.Http; using System.Text; @@ -42,6 +41,30 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests } } + [Fact] + public async Task TransportIsStoppedWhenConnectionIsStopped() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + { + Assert.False(longPollingTransport.Running.IsCompleted); + + await connection.StopAsync(); + + Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running)); + } + } + [Fact] public async Task TransportIsClosedWhenConnectionIsDisposed() { @@ -87,11 +110,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) { - Assert.False(connection.Input.Completion.IsCompleted); - var data = new byte[] { 1, 1, 2, 3, 5, 8 }; - connection.Output.TryWrite( - new Message(ReadableBuffer.Create(data).Preserve(), Format.Binary)); + await connection.SendAsync(data, Format.Binary); Assert.Equal(sendTcs.Task, await Task.WhenAny(Task.Delay(1000), sendTcs.Task)); Assert.Equal(data, sendTcs.Task.Result); @@ -120,20 +140,14 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) { - Assert.False(connection.Input.Completion.IsCompleted); - - await connection.Input.WaitToReadAsync(); - Message message; - connection.Input.TryRead(out message); - using (message) - { - Assert.Equal("42", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray(), 0, message.Payload.Buffer.Length)); - } + var receiveData = new ReceiveData(); + Assert.True(await connection.ReceiveAsync(receiveData)); + Assert.Equal("42", Encoding.UTF8.GetString(receiveData.Data)); } } [Fact] - public async Task CanCloseConnection() + public async Task CannotSendAfterConnectionIsStopped() { var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -148,20 +162,95 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) { - Assert.False(connection.Input.Completion.IsCompleted); - connection.Output.TryComplete(); + await connection.StopAsync(); + Assert.False(await connection.SendAsync(new byte[] { 1, 1, 3, 5, 8 }, Format.Binary)); + } + } - var whenAnyTask = Task.WhenAny(Task.Delay(1000), connection.Input.Completion); - - // The channel needs to be drained for the Completion task to be completed - Message message; - while (!whenAnyTask.IsCompleted) + [Fact] + public async Task CannotReceiveAfterConnectionIsStopped() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => { - connection.Input.TryRead(out message); - message.Dispose(); - } + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); - Assert.Equal(connection.Input.Completion, await whenAnyTask); + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + { + await connection.StopAsync(); + var exception = await Assert.ThrowsAsync( + async () => await connection.ReceiveAsync(new ReceiveData())); + + Assert.Equal("Cannot receive messages when the connection is stopped.", exception.Message); + } + } + + [Fact] + public async Task CannotSendAfterReceiveThrewException() + { + var allowPollTcs = new TaskCompletionSource(); + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (request.RequestUri.AbsolutePath.EndsWith("/poll")) + { + await allowPollTcs.Task; + return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) }; + } + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + { + var receiveTask = connection.ReceiveAsync(new ReceiveData()); + allowPollTcs.TrySetResult(null); + await Assert.ThrowsAsync(async () => await receiveTask); + + Assert.False(await connection.SendAsync(new byte[] { 1, 1, 3, 5, 8 }, Format.Binary)); + } + } + + [Fact] + public async Task CannotReceiveAfterReceiveThrewException() + { + var allowPollTcs = new TaskCompletionSource(); + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (request.RequestUri.AbsolutePath.EndsWith("/poll")) + { + await allowPollTcs.Task; + return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) }; + } + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + { + var receiveTask = connection.ReceiveAsync(new ReceiveData()); + allowPollTcs.TrySetResult(null); + await Assert.ThrowsAsync(async () => await receiveTask); + + var exception = await Assert.ThrowsAsync( + async () => await connection.ReceiveAsync(new ReceiveData())); + + Assert.Equal("Cannot receive messages when the connection is stopped.", exception.Message); } } }