Hiding Channels (#183)

Hiding Channels
This commit is contained in:
Pawel Kadluczka 2017-02-09 10:31:07 -08:00 committed by GitHub
parent b711128ec2
commit 0c8df245de
12 changed files with 306 additions and 190 deletions

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.IO.Pipelines;
using System.Net.Http; using System.Net.Http;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;

View File

@ -1,16 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
namespace ClientSample namespace ClientSample
{ {
public class Program public class Program

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.IO.Pipelines;
using System.Net.Http; using System.Net.Http;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
@ -62,9 +61,7 @@ namespace ClientSample
var line = Console.ReadLine(); var line = Console.ReadLine();
logger.LogInformation("Sending: {0}", line); logger.LogInformation("Sending: {0}", line);
await connection.Output.WriteAsync(new Message( await connection.SendAsync(Encoding.UTF8.GetBytes("Hello World"), Format.Text);
ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(),
Format.Text));
} }
logger.LogInformation("Send loop terminated"); logger.LogInformation("Send loop terminated");
} }
@ -74,18 +71,10 @@ namespace ClientSample
logger.LogInformation("Receive loop starting"); logger.LogInformation("Receive loop starting");
try try
{ {
while (await connection.Input.WaitToReadAsync(cancellationToken)) var receiveData = new ReceiveData();
while (await connection.ReceiveAsync(receiveData, cancellationToken))
{ {
Message message; logger.LogInformation($"Received: {Encoding.UTF8.GetString(receiveData.Data)}");
if (!connection.Input.TryRead(out message))
{
continue;
}
using (message)
{
logger.LogInformation("Received: {0}", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray()));
}
} }
} }
catch (OperationCanceledException) catch (OperationCanceledException)

View File

@ -6,7 +6,6 @@ using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.IO.Pipelines;
using System.Linq; using System.Linq;
using System.Net.Http; using System.Net.Http;
using System.Threading; using System.Threading;
@ -36,8 +35,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
private int _nextId = 0; private int _nextId = 0;
public Task Completion { get; }
private HubConnection(Connection connection, IInvocationAdapter adapter, ILogger logger) private HubConnection(Connection connection, IInvocationAdapter adapter, ILogger logger)
{ {
_binder = new HubBinder(this); _binder = new HubBinder(this);
@ -46,7 +43,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger = logger; _logger = logger;
_reader = ReceiveMessages(_readerCts.Token); _reader = ReceiveMessages(_readerCts.Token);
Completion = _connection.Input.Completion.ContinueWith(t => Shutdown(t)).Unwrap();
} }
// TODO: Client return values/tasks? // TODO: Client return values/tasks?
@ -102,14 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger.LogInformation("Sending Invocation #{0}", descriptor.Id); _logger.LogInformation("Sending Invocation #{0}", descriptor.Id);
// TODO: Format.Text - who, where and when decides about the format of outgoing messages // TODO: Format.Text - who, where and when decides about the format of outgoing messages
var message = new Message(ReadableBuffer.Create(ms.ToArray()).Preserve(), Format.Text); await _connection.SendAsync(ms.ToArray(), Format.Text, cancellationToken);
while (await _connection.Output.WaitToWriteAsync())
{
if (_connection.Output.TryWrite(message))
{
break;
}
}
_logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id); _logger.LogInformation("Sending Invocation #{0} complete", descriptor.Id);
@ -142,41 +131,35 @@ namespace Microsoft.AspNetCore.SignalR.Client
_logger.LogTrace("Beginning receive loop"); _logger.LogTrace("Beginning receive loop");
try try
{ {
while (await _connection.Input.WaitToReadAsync(cancellationToken)) ReceiveData receiveData = new ReceiveData();
while (await _connection.ReceiveAsync(receiveData, cancellationToken))
{ {
Message incomingMessage; var message
while (_connection.Input.TryRead(out incomingMessage)) = await _adapter.ReadMessageAsync(new MemoryStream(receiveData.Data), _binder, cancellationToken);
switch (message)
{ {
case InvocationDescriptor invocationDescriptor:
InvocationMessage message;
using (incomingMessage)
{
message = await _adapter.ReadMessageAsync(
new MemoryStream(incomingMessage.Payload.Buffer.ToArray()), _binder, cancellationToken);
}
var invocationDescriptor = message as InvocationDescriptor;
if (invocationDescriptor != null)
{
DispatchInvocation(invocationDescriptor, cancellationToken); DispatchInvocation(invocationDescriptor, cancellationToken);
} break;
else case InvocationResultDescriptor invocationResultDescriptor:
{ InvocationRequest irq;
var invocationResultDescriptor = message as InvocationResultDescriptor; lock (_pendingCallsLock)
if (invocationResultDescriptor != null)
{ {
InvocationRequest irq; _connectionActive.Token.ThrowIfCancellationRequested();
lock (_pendingCallsLock) irq = _pendingCalls[invocationResultDescriptor.Id];
{ _pendingCalls.Remove(invocationResultDescriptor.Id);
_connectionActive.Token.ThrowIfCancellationRequested();
irq = _pendingCalls[invocationResultDescriptor.Id];
_pendingCalls.Remove(invocationResultDescriptor.Id);
}
DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken);
} }
} DispatchInvocationResult(invocationResultDescriptor, irq, cancellationToken);
break;
} }
} }
Shutdown();
}
catch (Exception ex)
{
Shutdown(ex);
throw;
} }
finally finally
{ {
@ -184,12 +167,12 @@ namespace Microsoft.AspNetCore.SignalR.Client
} }
} }
private Task Shutdown(Task completion) private void Shutdown(Exception ex = null)
{ {
_logger.LogTrace("Shutting down connection"); _logger.LogTrace("Shutting down connection");
if (completion.IsFaulted) if (ex != null)
{ {
_logger.LogError("Connection is shutting down due to an error: {0}", completion.Exception.InnerException); _logger.LogError("Connection is shutting down due to an error: {0}", ex);
} }
lock (_pendingCallsLock) lock (_pendingCallsLock)
@ -197,27 +180,23 @@ namespace Microsoft.AspNetCore.SignalR.Client
_connectionActive.Cancel(); _connectionActive.Cancel();
foreach (var call in _pendingCalls.Values) foreach (var call in _pendingCalls.Values)
{ {
if (!completion.IsFaulted) if (ex != null)
{ {
call.Completion.TrySetCanceled(); call.Completion.TrySetCanceled();
} }
else else
{ {
call.Completion.TrySetException(completion.Exception.InnerException); call.Completion.TrySetException(ex);
} }
} }
_pendingCalls.Clear(); _pendingCalls.Clear();
} }
// Return the completion anyway
return completion;
} }
private void DispatchInvocation(InvocationDescriptor invocationDescriptor, CancellationToken cancellationToken) private void DispatchInvocation(InvocationDescriptor invocationDescriptor, CancellationToken cancellationToken)
{ {
// Find the handler // Find the handler
InvocationHandler handler; if (!_handlers.TryGetValue(invocationDescriptor.Method, out InvocationHandler handler))
if (!_handlers.TryGetValue(invocationDescriptor.Method, out handler))
{ {
_logger.LogWarning("Failed to find handler for '{0}' method", invocationDescriptor.Method); _logger.LogWarning("Failed to find handler for '{0}' method", invocationDescriptor.Method);
} }
@ -271,8 +250,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Type GetReturnType(string invocationId) public Type GetReturnType(string invocationId)
{ {
InvocationRequest irq; if (!_connection._pendingCalls.TryGetValue(invocationId, out InvocationRequest irq))
if (!_connection._pendingCalls.TryGetValue(invocationId, out irq))
{ {
_connection._logger.LogError("Unsolicited response received for invocation '{0}'", invocationId); _connection._logger.LogError("Unsolicited response received for invocation '{0}'", invocationId);
return null; return null;
@ -282,8 +260,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Type[] GetParameterTypes(string methodName) public Type[] GetParameterTypes(string methodName)
{ {
InvocationHandler handler; if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler))
if (!_connection._handlers.TryGetValue(methodName, out handler))
{ {
_connection._logger.LogWarning("Failed to find handler for '{0}' method", methodName); _connection._logger.LogWarning("Failed to find handler for '{0}' method", methodName);
return Type.EmptyTypes; return Type.EmptyTypes;

View File

@ -2,15 +2,17 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.IO.Pipelines;
using System.Net.Http; using System.Net.Http;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Threading.Tasks.Channels; using System.Threading.Tasks.Channels;
using Microsoft.Extensions.Logging;
using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Sockets.Client namespace Microsoft.AspNetCore.Sockets.Client
{ {
public class Connection : IChannelConnection<Message> public class Connection : IDisposable
{ {
private IChannelConnection<Message> _transportChannel; private IChannelConnection<Message> _transportChannel;
private ITransport _transport; private ITransport _transport;
@ -18,7 +20,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
public Uri Url { get; } public Uri Url { get; }
// TODO: Review. This is really only designed to be used from ConnectAsync
private Connection(Uri url, ITransport transport, IChannelConnection<Message> transportChannel, ILogger logger) private Connection(Uri url, ITransport transport, IChannelConnection<Message> transportChannel, ILogger logger)
{ {
Url = url; Url = url;
@ -28,17 +29,92 @@ namespace Microsoft.AspNetCore.Sockets.Client
_transportChannel = transportChannel; _transportChannel = transportChannel;
} }
public ReadableChannel<Message> Input => _transportChannel.Input; private ReadableChannel<Message> Input => _transportChannel.Input;
public WritableChannel<Message> Output => _transportChannel.Output; private WritableChannel<Message> Output => _transportChannel.Output;
public Task<bool> ReceiveAsync(ReceiveData receiveData)
{
return ReceiveAsync(receiveData, CancellationToken.None);
}
public async Task<bool> ReceiveAsync(ReceiveData receiveData, CancellationToken cancellationToken)
{
if (receiveData == null)
{
throw new ArgumentNullException(nameof(receiveData));
}
if (Input.Completion.IsCompleted)
{
throw new InvalidOperationException("Cannot receive messages when the connection is stopped.");
}
try
{
while (await Input.WaitToReadAsync(cancellationToken))
{
if (Input.TryRead(out Message message))
{
using (message)
{
receiveData.Format = message.MessageFormat;
receiveData.Data = message.Payload.Buffer.ToArray();
return true;
}
}
}
await Input.Completion;
}
catch (OperationCanceledException)
{
// channel is being closed
}
catch (Exception ex)
{
Output.TryComplete(ex);
_logger.LogError("Error receiving message: {0}", ex);
throw;
}
return false;
}
public Task<bool> SendAsync(byte[] data, Format format)
{
return SendAsync(data, format, CancellationToken.None);
}
public async Task<bool> SendAsync(byte[] data, Format format, CancellationToken cancellationToken)
{
var message = new Message(ReadableBuffer.Create(data).Preserve(), format);
while (await Output.WaitToWriteAsync(cancellationToken))
{
if (Output.TryWrite(message))
{
return true;
}
}
return false;
}
public async Task StopAsync()
{
Output.TryComplete();
await _transport.StopAsync();
}
public void Dispose() public void Dispose()
{ {
Output.TryComplete();
_transport.Dispose(); _transport.Dispose();
} }
public static Task<Connection> ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, new HttpClient(), NullLoggerFactory.Instance); public static Task<Connection> ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, null, null);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, new HttpClient(), loggerFactory); public static Task<Connection> ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, null, loggerFactory);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, NullLoggerFactory.Instance); public static Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, null);
public static async Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory) public static async Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory)
{ {
@ -47,39 +123,16 @@ namespace Microsoft.AspNetCore.Sockets.Client
throw new ArgumentNullException(nameof(url)); throw new ArgumentNullException(nameof(url));
} }
// TODO: Once we have websocket transport we would be able to use it as the default transport
if (transport == null) if (transport == null)
{ {
throw new ArgumentNullException(nameof(transport)); throw new ArgumentNullException(nameof(url));
}
if (httpClient == null)
{
throw new ArgumentNullException(nameof(httpClient));
}
if (loggerFactory == null)
{
throw new ArgumentNullException(nameof(loggerFactory));
} }
loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
var logger = loggerFactory.CreateLogger<Connection>(); var logger = loggerFactory.CreateLogger<Connection>();
var negotiateUrl = Utils.AppendPath(url, "negotiate");
string connectionId; var connectUrl = await GetConnectUrl(url, httpClient, logger);
try
{
// Get a connection ID from the server
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
connectionId = await httpClient.GetStringAsync(negotiateUrl);
logger.LogDebug("Connection Id: {0}", connectionId);
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
throw;
}
var connectedUrl = Utils.AppendQueryString(url, "id=" + connectionId);
var applicationToTransport = Channel.CreateUnbounded<Message>(); var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>(); var transportToApplication = Channel.CreateUnbounded<Message>();
@ -90,7 +143,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
// Start the transport, giving it one end of the pipeline // Start the transport, giving it one end of the pipeline
try try
{ {
await transport.StartAsync(connectedUrl, applicationSide); await transport.StartAsync(connectUrl, applicationSide);
} }
catch (Exception ex) catch (Exception ex)
{ {
@ -101,5 +154,41 @@ namespace Microsoft.AspNetCore.Sockets.Client
// Create the connection, giving it the other end of the pipeline // Create the connection, giving it the other end of the pipeline
return new Connection(url, transport, transportSide, logger); return new Connection(url, transport, transportSide, logger);
} }
private static async Task<Uri> GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger)
{
var disposeHttpClient = httpClient == null;
httpClient = httpClient ?? new HttpClient();
try
{
var connectionId = await GetConnectionId(url, httpClient, logger);
return Utils.AppendQueryString(url, "id=" + connectionId);
}
finally
{
if (disposeHttpClient)
{
httpClient.Dispose();
}
}
}
private static async Task<string> GetConnectionId(Uri url, HttpClient httpClient, ILogger logger)
{
var negotiateUrl = Utils.AppendPath(url, "negotiate");
try
{
// Get a connection ID from the server
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
var connectionId = await httpClient.GetStringAsync(negotiateUrl);
logger.LogDebug("Connection Id: {0}", connectionId);
return connectionId;
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
throw;
}
}
} }
} }

View File

@ -9,5 +9,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
public interface ITransport : IDisposable public interface ITransport : IDisposable
{ {
Task StartAsync(Uri url, IChannelConnection<Message> application); Task StartAsync(Uri url, IChannelConnection<Message> application);
Task StopAsync();
} }
} }

View File

@ -34,11 +34,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger = loggerFactory.CreateLogger<LongPollingTransport>(); _logger = loggerFactory.CreateLogger<LongPollingTransport>();
} }
public void Dispose()
{
_transportCts.Cancel();
}
public Task StartAsync(Uri url, IChannelConnection<Message> application) public Task StartAsync(Uri url, IChannelConnection<Message> application)
{ {
_application = application; _application = application;
@ -55,6 +50,17 @@ namespace Microsoft.AspNetCore.Sockets.Client
return TaskCache.CompletedTask; return TaskCache.CompletedTask;
} }
public async Task StopAsync()
{
_transportCts.Cancel();
await Running;
}
public void Dispose()
{
_transportCts.Cancel();
}
private async Task Poll(Uri pollUrl, CancellationToken cancellationToken) private async Task Poll(Uri pollUrl, CancellationToken cancellationToken)
{ {
try try
@ -110,8 +116,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
{ {
while (await _application.Input.WaitToReadAsync(cancellationToken)) while (await _application.Input.WaitToReadAsync(cancellationToken))
{ {
Message message; while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out Message message))
while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out message))
{ {
using (message) using (message)
{ {

View File

@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.Sockets.Client
{
public class ReceiveData
{
public byte[] Data { get; set; }
public Format Format { get; set; }
}
}

View File

@ -53,8 +53,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{ {
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("HelloWorld"); var result = await connection.Invoke<string>("HelloWorld");
Assert.Equal("Hello World!", result); Assert.Equal("Hello World!", result);
@ -74,8 +72,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{ {
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("Echo", originalMessage); var result = await connection.Invoke<string>("Echo", originalMessage);
Assert.Equal(originalMessage, result); Assert.Equal(originalMessage, result);
@ -95,8 +91,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{ {
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("echo", originalMessage); var result = await connection.Invoke<string>("echo", originalMessage);
Assert.Equal(originalMessage, result); Assert.Equal(originalMessage, result);
@ -122,8 +116,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
tcs.TrySetResult((string)a[0]); tcs.TrySetResult((string)a[0]);
}); });
EnsureConnectionEstablished(connection);
await connection.Invoke<Task>("CallEcho", originalMessage); await connection.Invoke<Task>("CallEcho", originalMessage);
var completed = await Task.WhenAny(Task.Delay(2000), tcs.Task); var completed = await Task.WhenAny(Task.Delay(2000), tcs.Task);
Assert.True(completed == tcs.Task, "Receive timed out!"); Assert.True(completed == tcs.Task, "Receive timed out!");
@ -143,8 +135,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{ {
EnsureConnectionEstablished(connection);
var ex = await Assert.ThrowsAnyAsync<Exception>( var ex = await Assert.ThrowsAnyAsync<Exception>(
async () => await connection.Invoke<object>("!@#$%")); async () => await connection.Invoke<object>("!@#$%"));
@ -153,14 +143,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
} }
} }
private static void EnsureConnectionEstablished(HubConnection connection)
{
if (connection.Completion.IsCompleted)
{
connection.Completion.GetAwaiter().GetResult();
}
}
public void Dispose() public void Dispose()
{ {
_testServer.Dispose(); _testServer.Dispose();

View File

@ -65,31 +65,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var transport = new LongPollingTransport(httpClient, loggerFactory); var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo"), transport, httpClient, loggerFactory)) using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo"), transport, httpClient, loggerFactory))
{ {
await connection.Output.WriteAsync(new Message( await connection.SendAsync(Encoding.UTF8.GetBytes(message), Format.Text);
ReadableBuffer.Create(Encoding.UTF8.GetBytes(message)).Preserve(),
Format.Text));
var received = await ReceiveMessage(connection).OrTimeout(); var receiveData = new ReceiveData();
Assert.Equal(message, received);
Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout());
Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data));
} }
} }
} }
private static async Task<string> ReceiveMessage(ClientConnection connection)
{
Message message;
while (await connection.Input.WaitToReadAsync())
{
if (connection.Input.TryRead(out message))
{
using (message)
{
return Encoding.UTF8.GetString(message.Payload.Buffer.ToArray());
}
}
}
return null;
}
} }
} }

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.IO.Pipelines;
using System.Net; using System.Net;
using System.Net.Http; using System.Net.Http;
using System.Text; using System.Text;
@ -42,6 +41,30 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
} }
} }
[Fact]
public async Task TransportIsStoppedWhenConnectionIsStopped()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
Assert.False(longPollingTransport.Running.IsCompleted);
await connection.StopAsync();
Assert.Equal(longPollingTransport.Running, await Task.WhenAny(Task.Delay(1000), longPollingTransport.Running));
}
}
[Fact] [Fact]
public async Task TransportIsClosedWhenConnectionIsDisposed() public async Task TransportIsClosedWhenConnectionIsDisposed()
{ {
@ -87,11 +110,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{ {
Assert.False(connection.Input.Completion.IsCompleted);
var data = new byte[] { 1, 1, 2, 3, 5, 8 }; var data = new byte[] { 1, 1, 2, 3, 5, 8 };
connection.Output.TryWrite( await connection.SendAsync(data, Format.Binary);
new Message(ReadableBuffer.Create(data).Preserve(), Format.Binary));
Assert.Equal(sendTcs.Task, await Task.WhenAny(Task.Delay(1000), sendTcs.Task)); Assert.Equal(sendTcs.Task, await Task.WhenAny(Task.Delay(1000), sendTcs.Task));
Assert.Equal(data, sendTcs.Task.Result); Assert.Equal(data, sendTcs.Task.Result);
@ -120,20 +140,14 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{ {
Assert.False(connection.Input.Completion.IsCompleted); var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData));
await connection.Input.WaitToReadAsync(); Assert.Equal("42", Encoding.UTF8.GetString(receiveData.Data));
Message message;
connection.Input.TryRead(out message);
using (message)
{
Assert.Equal("42", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray(), 0, message.Payload.Buffer.Length));
}
} }
} }
[Fact] [Fact]
public async Task CanCloseConnection() public async Task CannotSendAfterConnectionIsStopped()
{ {
var mockHttpHandler = new Mock<HttpMessageHandler>(); var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected() mockHttpHandler.Protected()
@ -148,20 +162,95 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{ {
Assert.False(connection.Input.Completion.IsCompleted); await connection.StopAsync();
connection.Output.TryComplete(); Assert.False(await connection.SendAsync(new byte[] { 1, 1, 3, 5, 8 }, Format.Binary));
}
}
var whenAnyTask = Task.WhenAny(Task.Delay(1000), connection.Input.Completion); [Fact]
public async Task CannotReceiveAfterConnectionIsStopped()
// The channel needs to be drained for the Completion task to be completed {
Message message; var mockHttpHandler = new Mock<HttpMessageHandler>();
while (!whenAnyTask.IsCompleted) mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{ {
connection.Input.TryRead(out message); await Task.Yield();
message.Dispose(); return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
} });
Assert.Equal(connection.Input.Completion, await whenAnyTask); using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
await connection.StopAsync();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.ReceiveAsync(new ReceiveData()));
Assert.Equal("Cannot receive messages when the connection is stopped.", exception.Message);
}
}
[Fact]
public async Task CannotSendAfterReceiveThrewException()
{
var allowPollTcs = new TaskCompletionSource<object>();
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
{
await allowPollTcs.Task;
return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) };
}
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
var receiveTask = connection.ReceiveAsync(new ReceiveData());
allowPollTcs.TrySetResult(null);
await Assert.ThrowsAsync<HttpRequestException>(async () => await receiveTask);
Assert.False(await connection.SendAsync(new byte[] { 1, 1, 3, 5, 8 }, Format.Binary));
}
}
[Fact]
public async Task CannotReceiveAfterReceiveThrewException()
{
var allowPollTcs = new TaskCompletionSource<object>();
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (request.RequestUri.AbsolutePath.EndsWith("/poll"))
{
await allowPollTcs.Task;
return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) };
}
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
{
var receiveTask = connection.ReceiveAsync(new ReceiveData());
allowPollTcs.TrySetResult(null);
await Assert.ThrowsAsync<HttpRequestException>(async () => await receiveTask);
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.ReceiveAsync(new ReceiveData()));
Assert.Equal("Cannot receive messages when the connection is stopped.", exception.Message);
} }
} }
} }