Hiding Channels (#183)

Hiding Channels
This commit is contained in:
Pawel Kadluczka 2017-02-09 10:31:07 -08:00 committed by GitHub
parent b711128ec2
commit 0c8df245de
12 changed files with 306 additions and 190 deletions

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

View File

@ -1,16 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
namespace ClientSample
{
public class Program

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.Http;
using System.Text;
using System.Threading;
@ -62,9 +61,7 @@ namespace ClientSample
var line = Console.ReadLine();
logger.LogInformation("Sending: {0}", line);
await connection.Output.WriteAsync(new Message(
ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(),
Format.Text));
await connection.SendAsync(Encoding.UTF8.GetBytes("Hello World"), Format.Text);
}
logger.LogInformation("Send loop terminated");
}
@ -74,18 +71,10 @@ namespace ClientSample
logger.LogInformation("Receive loop starting");
try
{
while (await connection.Input.WaitToReadAsync(cancellationToken))
var receiveData = new ReceiveData();
while (await connection.ReceiveAsync(receiveData, cancellationToken))
{
Message message;
if (!connection.Input.TryRead(out message))
{
continue;
}
using (message)
{
logger.LogInformation("Received: {0}", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray()));
}
logger.LogInformation($"Received: {Encoding.UTF8.GetString(receiveData.Data)}");
}
}
catch (OperationCanceledException)

View File

@ -6,7 +6,6 @@ using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Net.Http;
using System.Threading;
@ -36,8 +35,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
private int _nextId = 0;
public Task Completion { get; }
private HubConnection(Connection connection, IInvocationAdapter adapter, ILogger logger)
{
_binder = new HubBinder(this);
@ -46,7 +43,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger = logger;
_reader = ReceiveMessages(_readerCts.Token);
Completion = _connection.Input.Completion.ContinueWith(t => Shutdown(t)).Unwrap();
}
// TODO: Client return values/tasks?
@ -102,14 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger.LogInformation("Sending Invocation #{0}", descriptor.Id);
// TODO: Format.Text - who, where and when decides about the format of outgoing messages
var message = new Message(ReadableBuffer.Create(ms.ToArray()).Preserve(), Format.Text);
while (await _connection.Output.WaitToWriteAsync())
{
if (_connection.Output.TryWrite(message))
{
break;
}
}
await _connection.SendAsync(ms.ToArray(), Format.Text, cancellationToken);
_logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id);
@ -142,41 +131,35 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger.LogTrace("Beginning receive loop");
try
{
while (await _connection.Input.WaitToReadAsync(cancellationToken))
ReceiveData receiveData = new ReceiveData();
while (await _connection.ReceiveAsync(receiveData, cancellationToken))
{
Message incomingMessage;
while (_connection.Input.TryRead(out incomingMessage))
var message
= await _adapter.ReadMessageAsync(new MemoryStream(receiveData.Data), _binder, cancellationToken);
switch (message)
{
InvocationMessage message;
using (incomingMessage)
{
message = await _adapter.ReadMessageAsync(
new MemoryStream(incomingMessage.Payload.Buffer.ToArray()), _binder, cancellationToken);
}
var invocationDescriptor = message as InvocationDescriptor;
if (invocationDescriptor != null)
{
case InvocationDescriptor invocationDescriptor:
DispatchInvocation(invocationDescriptor, cancellationToken);
}
else
{
var invocationResultDescriptor = message as InvocationResultDescriptor;
if (invocationResultDescriptor != null)
break;
case InvocationResultDescriptor invocationResultDescriptor:
InvocationRequest irq;
lock (_pendingCallsLock)
{
InvocationRequest irq;
lock (_pendingCallsLock)
{
_connectionActive.Token.ThrowIfCancellationRequested();
irq = _pendingCalls[invocationResultDescriptor.Id];
_pendingCalls.Remove(invocationResultDescriptor.Id);
}
DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken);
_connectionActive.Token.ThrowIfCancellationRequested();
irq = _pendingCalls[invocationResultDescriptor.Id];
_pendingCalls.Remove(invocationResultDescriptor.Id);
}
}
DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken);
break;
}
}
Shutdown();
}
catch (Exception ex)
{
Shutdown(ex);
throw;
}
finally
{
@ -184,12 +167,12 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
}
private Task Shutdown(Task completion)
private void Shutdown(Exception ex = null)
{
_logger.LogTrace("Shutting down connection");
if (completion.IsFaulted)
if (ex != null)
{
_logger.LogError("Connection is shutting down due to an error: {0}", completion.Exception.InnerException);
_logger.LogError("Connection is shutting down due to an error: {0}", ex);
}
lock (_pendingCallsLock)
@ -197,27 +180,23 @@ namespace Microsoft.AspNetCore.SignalR.Client
_connectionActive.Cancel();
foreach (var call in _pendingCalls.Values)
{
if (!completion.IsFaulted)
if (ex != null)
{
call.Completion.TrySetCanceled();
}
else
{
call.Completion.TrySetException(completion.Exception.InnerException);
call.Completion.TrySetException(ex);
}
}
_pendingCalls.Clear();
}
// Return the completion anyway
return completion;
}
private void DispatchInvocation(InvocationDescriptor invocationDescriptor, CancellationToken cancellationToken)
{
// Find the handler
InvocationHandler handler;
if (!_handlers.TryGetValue(invocationDescriptor.Method, out handler))
if (!_handlers.TryGetValue(invocationDescriptor.Method, out InvocationHandler handler))
{
_logger.LogWarning("Failed to find handler for '{0}' method", invocationDescriptor.Method);
}
@ -271,8 +250,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Type GetReturnType(string invocationId)
{
InvocationRequest irq;
if (!_connection._pendingCalls.TryGetValue(invocationId, out irq))
if (!_connection._pendingCalls.TryGetValue(invocationId, out InvocationRequest irq))
{
_connection._logger.LogError("Unsolicited response received for invocation '{0}'", invocationId);
return null;
@ -282,8 +260,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Type[] GetParameterTypes(string methodName)
{
InvocationHandler handler;
if (!_connection._handlers.TryGetValue(methodName, out handler))
if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler))
{
_connection._logger.LogWarning("Failed to find handler for '{0}' method", methodName);
return Type.EmptyTypes;

View File

@ -2,15 +2,17 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.Extensions.Logging;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Sockets.Client
{
public class Connection : IChannelConnection<Message>
public class Connection : IDisposable
{
private IChannelConnection<Message> _transportChannel;
private ITransport _transport;
@ -18,7 +20,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
public Uri Url { get; }
// TODO: Review. This is really only designed to be used from ConnectAsync
private Connection(Uri url, ITransport transport, IChannelConnection<Message> transportChannel, ILogger logger)
{
Url = url;
@ -28,17 +29,92 @@ namespace Microsoft.AspNetCore.Sockets.Client
_transportChannel = transportChannel;
}
public ReadableChannel<Message> Input => _transportChannel.Input;
public WritableChannel<Message> Output => _transportChannel.Output;
private ReadableChannel<Message> Input => _transportChannel.Input;
private WritableChannel<Message> Output => _transportChannel.Output;
public Task<bool> ReceiveAsync(ReceiveData receiveData)
{
return ReceiveAsync(receiveData, CancellationToken.None);
}
public async Task<bool> ReceiveAsync(ReceiveData receiveData, CancellationToken cancellationToken)
{
if (receiveData == null)
{
throw new ArgumentNullException(nameof(receiveData));
}
if (Input.Completion.IsCompleted)
{
throw new InvalidOperationException("Cannot receive messages when the connection is stopped.");
}
try
{
while (await Input.WaitToReadAsync(cancellationToken))
{
if (Input.TryRead(out Message message))
{
using (message)
{
receiveData.Format = message.MessageFormat;
receiveData.Data = message.Payload.Buffer.ToArray();
return true;
}
}
}
await Input.Completion;
}
catch (OperationCanceledException)
{
// channel is being closed
}
catch (Exception ex)
{
Output.TryComplete(ex);
_logger.LogError("Error receiving message: {0}", ex);
throw;
}
return false;
}
public Task<bool> SendAsync(byte[] data, Format format)
{
return SendAsync(data, format, CancellationToken.None);
}
public async Task<bool> SendAsync(byte[] data, Format format, CancellationToken cancellationToken)
{
var message = new Message(ReadableBuffer.Create(data).Preserve(), format);
while (await Output.WaitToWriteAsync(cancellationToken))
{
if (Output.TryWrite(message))
{
return true;
}
}
return false;
}
public async Task StopAsync()
{
Output.TryComplete();
await _transport.StopAsync();
}
public void Dispose()
{
Output.TryComplete();
_transport.Dispose();
}
public static Task<Connection> ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, new HttpClient(), NullLoggerFactory.Instance);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, new HttpClient(), loggerFactory);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, NullLoggerFactory.Instance);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, null, null);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, null, loggerFactory);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, null);
public static async Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory)
{
@ -47,39 +123,16 @@ namespace Microsoft.AspNetCore.Sockets.Client
throw new ArgumentNullException(nameof(url));
}
// TODO: Once we have websocket transport we would be able to use it as the default transport
if (transport == null)
{
throw new ArgumentNullException(nameof(transport));
}
if (httpClient == null)
{
throw new ArgumentNullException(nameof(httpClient));
}
if (loggerFactory == null)
{
throw new ArgumentNullException(nameof(loggerFactory));
throw new ArgumentNullException(nameof(url));
}
loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
var logger = loggerFactory.CreateLogger<Connection>();
var negotiateUrl = Utils.AppendPath(url, "negotiate");
string connectionId;
try
{
// Get a connection ID from the server
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
connectionId = await httpClient.GetStringAsync(negotiateUrl);
logger.LogDebug("Connection Id: {0}", connectionId);
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
throw;
}
var connectedUrl = Utils.AppendQueryString(url, "id=" + connectionId);
var connectUrl = await GetConnectUrl(url, httpClient, logger);
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
@ -90,7 +143,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
// Start the transport, giving it one end of the pipeline
try
{
await transport.StartAsync(connectedUrl, applicationSide);
await transport.StartAsync(connectUrl, applicationSide);
}
catch (Exception ex)
{
@ -101,5 +154,41 @@ namespace Microsoft.AspNetCore.Sockets.Client
// Create the connection, giving it the other end of the pipeline
return new Connection(url, transport, transportSide, logger);
}
private static async Task<Uri> GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger)
{
var disposeHttpClient = httpClient == null;
httpClient = httpClient ?? new HttpClient();
try
{
var connectionId = await GetConnectionId(url, httpClient, logger);
return Utils.AppendQueryString(url, "id=" + connectionId);
}
finally
{
if (disposeHttpClient)
{
httpClient.Dispose();
}
}
}
private static async Task<string> GetConnectionId(Uri url, HttpClient httpClient, ILogger logger)
{
var negotiateUrl = Utils.AppendPath(url, "negotiate");
try
{
// Get a connection ID from the server
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
var connectionId = await httpClient.GetStringAsync(negotiateUrl);
logger.LogDebug("Connection Id: {0}", connectionId);
return connectionId;
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
throw;
}
}
}
}

View File

@ -9,5 +9,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
public interface ITransport : IDisposable
{
Task StartAsync(Uri url, IChannelConnection<Message> application);
Task StopAsync();
}
}

View File

@ -34,11 +34,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger = loggerFactory.CreateLogger<LongPollingTransport>();
}
public void Dispose()
{
_transportCts.Cancel();
}
public Task StartAsync(Uri url, IChannelConnection<Message> application)
{
_application = application;
@ -55,6 +50,17 @@ namespace Microsoft.AspNetCore.Sockets.Client
return TaskCache.CompletedTask;
}
public async Task StopAsync()
{
_transportCts.Cancel();
await Running;
}
public void Dispose()
{
_transportCts.Cancel();
}
private async Task Poll(Uri pollUrl, CancellationToken cancellationToken)
{
try
@ -110,8 +116,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
while (await _application.Input.WaitToReadAsync(cancellationToken))
{
Message message;
while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out message))
while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out Message message))
{
using (message)
{

View File

@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.Sockets.Client
{
public class ReceiveData
{
public byte[] Data { get; set; }
public Format Format { get; set; }
}
}

View File

@ -53,8 +53,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("HelloWorld");
Assert.Equal("Hello World!", result);
@ -74,8 +72,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("Echo", originalMessage);
Assert.Equal(originalMessage, result);
@ -95,8 +91,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("echo", originalMessage);
Assert.Equal(originalMessage, result);
@ -122,8 +116,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
tcs.TrySetResult((string)a[0]);
});
EnsureConnectionEstablished(connection);
await connection.Invoke<Task>("CallEcho", originalMessage);
var completed = await Task.WhenAny(Task.Delay(2000), tcs.Task);
Assert.True(completed == tcs.Task, "Receive timed out!");
@ -143,8 +135,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
EnsureConnectionEstablished(connection);
var ex = await Assert.ThrowsAnyAsync<Exception>(
async () => await connection.Invoke<object>("!@#$%"));
@ -153,14 +143,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
private static void EnsureConnectionEstablished(HubConnection connection)
{
if (connection.Completion.IsCompleted)
{
connection.Completion.GetAwaiter().GetResult();
}
}
public void Dispose()
{
_testServer.Dispose();

View File

@ -27,4 +27,4 @@
<Exec Command="npm run gulp -- --gulpfile $(MSBuildProjectDirectory)/../../src/Microsoft.AspNetCore.SignalR.Client.TS/gulpfile.js bundle-client --bundleOutDir $(MSBuildProjectDirectory)/wwwroot/lib/signalr-client/" />
</Target>
</Project>
</Project>

View File

@ -65,31 +65,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo"), transport, httpClient, loggerFactory))
{
await connection.Output.WriteAsync(new Message(
ReadableBuffer.Create(Encoding.UTF8.GetBytes(message)).Preserve(),
Format.Text));
await connection.SendAsync(Encoding.UTF8.GetBytes(message), Format.Text);
var received = await ReceiveMessage(connection).OrTimeout();
Assert.Equal(message, received);
var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout());
Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data));
}
}
}
private static async Task<string> ReceiveMessage(ClientConnection connection)
{
Message message;
while (await connection.Input.WaitToReadAsync())
{
if (connection.Input.TryRead(out message))
{
using (message)
{
return Encoding.UTF8.GetString(message.Payload.Buffer.ToArray());
}
}
}
return null;
}
}
}

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net;
using System.Net.Http;
using System.Text;
@ -42,6 +41,30 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
}
}
[Fact]
public async Task TransportIsStoppedWhenConnectionIsStopped()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(longPollingTransport.Running.IsCompleted);
await connection.StopAsync();
Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running));
}
}
[Fact]
public async Task TransportIsClosedWhenConnectionIsDisposed()
{
@ -87,11 +110,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(connection.Input.Completion.IsCompleted);
var data = new byte[] { 1, 1, 2, 3, 5, 8 };
connection.Output.TryWrite(
new Message(ReadableBuffer.Create(data).Preserve(), Format.Binary));
await connection.SendAsync(data, Format.Binary);
Assert.Equal(sendTcs.Task, await Task.WhenAny(Task.Delay(1000), sendTcs.Task));
Assert.Equal(data, sendTcs.Task.Result);
@ -120,20 +140,14 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(connection.Input.Completion.IsCompleted);
await connection.Input.WaitToReadAsync();
Message message;
connection.Input.TryRead(out message);
using (message)
{
Assert.Equal("42", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray(), 0, message.Payload.Buffer.Length));
}
var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData));
Assert.Equal("42", Encoding.UTF8.GetString(receiveData.Data));
}
}
[Fact]
public async Task CanCloseConnection()
public async Task CannotSendAfterConnectionIsStopped()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
@ -148,20 +162,95 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(connection.Input.Completion.IsCompleted);
connection.Output.TryComplete();
await connection.StopAsync();
Assert.False(await connection.SendAsync(new byte[] { 1, 1, 3, 5, 8 }, Format.Binary));
}
}
var whenAnyTask = Task.WhenAny(Task.Delay(1000), connection.Input.Completion);
// The channel needs to be drained for the Completion task to be completed
Message message;
while (!whenAnyTask.IsCompleted)
[Fact]
public async Task CannotReceiveAfterConnectionIsStopped()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
connection.Input.TryRead(out message);
message.Dispose();
}
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
Assert.Equal(connection.Input.Completion, await whenAnyTask);
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
await connection.StopAsync();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.ReceiveAsync(new ReceiveData()));
Assert.Equal("Cannot receive messages when the connection is stopped.", exception.Message);
}
}
[Fact]
public async Task CannotSendAfterReceiveThrewException()
{
var allowPollTcs = new TaskCompletionSource<object>();
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
{
await allowPollTcs.Task;
return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) };
}
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
var receiveTask = connection.ReceiveAsync(new ReceiveData());
allowPollTcs.TrySetResult(null);
await Assert.ThrowsAsync<HttpRequestException>(async () => await receiveTask);
Assert.False(await connection.SendAsync(new byte[] { 1, 1, 3, 5, 8 }, Format.Binary));
}
}
[Fact]
public async Task CannotReceiveAfterReceiveThrewException()
{
var allowPollTcs = new TaskCompletionSource<object>();
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
{
await allowPollTcs.Task;
return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) };
}
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
var receiveTask = connection.ReceiveAsync(new ReceiveData());
allowPollTcs.TrySetResult(null);
await Assert.ThrowsAsync<HttpRequestException>(async () => await receiveTask);
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.ReceiveAsync(new ReceiveData()));
Assert.Equal("Cannot receive messages when the connection is stopped.", exception.Message);
}
}
}