Converting static ConnectAsync to instance StartAsync

This commit is contained in:
moozzyk 2017-02-08 12:46:20 -08:00
parent 70d97dd7b8
commit 3ba95b98af
6 changed files with 152 additions and 155 deletions

View File

@ -30,8 +30,10 @@ namespace ClientSample
{
logger.LogInformation("Connecting to {0}", baseUrl);
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await Connection.ConnectAsync(new Uri(baseUrl), transport, httpClient, loggerFactory))
using (var connection = new Connection(new Uri(baseUrl), loggerFactory))
{
await connection.StartAsync(transport, httpClient);
logger.LogInformation("Connected to {0}", baseUrl);
var cts = new CancellationTokenSource();
@ -49,6 +51,7 @@ namespace ClientSample
StartSending(loggerFactory.CreateLogger("SendLoop"), connection, cts.Token).ContinueWith(_ => cts.Cancel());
await Task.WhenAll(receive, send);
await connection.StopAsync();
}
}
}

View File

@ -118,9 +118,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
public static async Task<HubConnection> ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory)
{
// Connect the underlying connection
var connection = await Connection.ConnectAsync(url, transport, httpClient, loggerFactory);
var connection = new Connection(url, loggerFactory);
await connection.StartAsync(transport, httpClient);
// Create the RPC connection wrapper
return new HubConnection(connection, adapter, loggerFactory.CreateLogger<HubConnection>());
}

View File

@ -14,24 +14,93 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
public class Connection : IDisposable
{
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
private IChannelConnection<Message> _transportChannel;
private ITransport _transport;
private readonly ILogger _logger;
public Uri Url { get; }
private Connection(Uri url, ITransport transport, IChannelConnection<Message> transportChannel, ILogger logger)
{
Url = url;
_logger = logger;
_transport = transport;
_transportChannel = transportChannel;
}
private ReadableChannel<Message> Input => _transportChannel.Input;
private WritableChannel<Message> Output => _transportChannel.Output;
public Uri Url { get; }
public Connection(Uri url)
: this(url, null)
{ }
public Connection(Uri url, ILoggerFactory loggerFactory)
{
Url = url ?? throw new ArgumentNullException(nameof(url));
_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
_logger = _loggerFactory.CreateLogger<Connection>();
}
public Task StartAsync(Uri url, ITransport transport) => StartAsync((ITransport)null, null);
public Task StartAsync(HttpClient httpClient) => StartAsync(null, httpClient);
public Task StartAsync(ITransport transport) => StartAsync(transport, null);
public async Task StartAsync(ITransport transport, HttpClient httpClient)
{
// TODO: make transport optional
_transport = transport ?? throw new ArgumentNullException(nameof(transport));
var connectUrl = await GetConnectUrl(Url, httpClient, _logger);
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
_transportChannel = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
// Start the transport, giving it one end of the pipeline
try
{
await transport.StartAsync(connectUrl, applicationSide);
}
catch (Exception ex)
{
_logger.LogError("Failed to start connection. Error starting transport '{0}': {1}", transport.GetType().Name, ex);
throw;
}
}
private static async Task<Uri> GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger)
{
var disposeHttpClient = httpClient == null;
httpClient = httpClient ?? new HttpClient();
try
{
var connectionId = await GetConnectionId(url, httpClient, logger);
return Utils.AppendQueryString(url, "id=" + connectionId);
}
finally
{
if (disposeHttpClient)
{
httpClient.Dispose();
}
}
}
private static async Task<string> GetConnectionId(Uri url, HttpClient httpClient, ILogger logger)
{
var negotiateUrl = Utils.AppendPath(url, "negotiate");
try
{
// Get a connection ID from the server
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
var connectionId = await httpClient.GetStringAsync(negotiateUrl);
logger.LogDebug("Connection Id: {0}", connectionId);
return connectionId;
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
throw;
}
}
public Task<bool> ReceiveAsync(ReceiveData receiveData)
{
return ReceiveAsync(receiveData, CancellationToken.None);
@ -102,104 +171,27 @@ namespace Microsoft.AspNetCore.Sockets.Client
public async Task StopAsync()
{
Output.TryComplete();
await _transport.StopAsync();
await DrainMessages();
if (_transportChannel != null)
{
Output.TryComplete();
}
if (_transport != null)
{
await _transport.StopAsync();
}
}
public void Dispose()
{
Output.TryComplete();
_transport.Dispose();
}
private async Task DrainMessages()
{
while (await Input.WaitToReadAsync())
if (_transportChannel != null)
{
if (Input.TryRead(out Message message))
{
message.Dispose();
}
}
}
public static Task<Connection> ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, null, null);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, null, loggerFactory);
public static Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, null);
public static async Task<Connection> ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory)
{
if (url == null)
{
throw new ArgumentNullException(nameof(url));
Output.TryComplete();
}
// TODO: Once we have websocket transport we would be able to use it as the default transport
if (transport == null)
if (_transport != null)
{
throw new ArgumentNullException(nameof(url));
}
loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
var logger = loggerFactory.CreateLogger<Connection>();
var connectUrl = await GetConnectUrl(url, httpClient, logger);
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
// Start the transport, giving it one end of the pipeline
try
{
await transport.StartAsync(connectUrl, applicationSide);
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error starting transport '{0}': {1}", transport.GetType().Name, ex);
throw;
}
// Create the connection, giving it the other end of the pipeline
return new Connection(url, transport, transportSide, logger);
}
private static async Task<Uri> GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger)
{
var disposeHttpClient = httpClient == null;
httpClient = httpClient ?? new HttpClient();
try
{
var connectionId = await GetConnectionId(url, httpClient, logger);
return Utils.AppendQueryString(url, "id=" + connectionId);
}
finally
{
if (disposeHttpClient)
{
httpClient.Dispose();
}
}
}
private static async Task<string> GetConnectionId(Uri url, HttpClient httpClient, ILogger logger)
{
var negotiateUrl = Utils.AppendPath(url, "negotiate");
try
{
// Get a connection ID from the server
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
var connectionId = await httpClient.GetStringAsync(negotiateUrl);
logger.LogDebug("Connection Id: {0}", connectionId);
return connectionId;
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
throw;
_transport.Dispose();
}
}
}

View File

@ -30,8 +30,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
public Task Running { get; private set; }
public async Task StartAsync(Uri url, IChannelConnection<Message> application)
public async Task StartAsync(Uri url, IChannelConnection<Message> application)
{
if (url == null)
{

View File

@ -64,38 +64,22 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var httpClient = new HttpClient())
{
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo"), transport, httpClient, loggerFactory))
using (var connection = new ClientConnection(new Uri(baseUrl + "/echo"), loggerFactory))
{
await connection.StartAsync(transport, httpClient);
await connection.SendAsync(Encoding.UTF8.GetBytes(message), MessageType.Text);
var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout());
Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data));
await connection.StopAsync();
}
}
}
[ConditionalFact]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
public async Task ConnectionCanSendAndReceiveSmallMessagesWebSocketsTransport()
{
const string message = "Major Key";
var baseUrl = _serverFixture.BaseUrl;
var loggerFactory = new LoggerFactory();
var transport = new WebSocketsTransport();
using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo/ws"), transport, loggerFactory))
{
await connection.SendAsync(Encoding.UTF8.GetBytes(message), MessageType.Text);
var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout());
Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data));
}
}
public static IEnumerable<object[]> MessageSizesData
{
get
@ -114,14 +98,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var loggerFactory = new LoggerFactory();
var transport = new WebSocketsTransport();
using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo/ws"), transport, loggerFactory))
using (var connection = new ClientConnection(new Uri(baseUrl + "/echo/ws"), loggerFactory))
{
await connection.StartAsync(transport);
await connection.SendAsync(Encoding.UTF8.GetBytes(message), MessageType.Text);
var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout());
Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data));
await connection.StopAsync();
}
}
}

View File

@ -18,28 +18,23 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
public class ConnectionTests
{
[Fact]
public async Task ConnectionReturnsUrlUsedToStartTheConnection()
public void CannotCreateConnectionWithNullUrl()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
var exception = Assert.Throws<ArgumentNullException>(() => new Connection(null));
Assert.Equal("url", exception.ParamName);
}
[Fact]
public void ConnectionReturnsUrlUsedToStartTheConnection()
{
var connectionUrl = new Uri("http://fakeuri.org/");
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
using (var connection = await Connection.ConnectAsync(connectionUrl, longPollingTransport, httpClient))
{
Assert.Equal(connectionUrl, connection.Url);
}
Assert.Equal(connectionUrl, new Connection(connectionUrl).Url);
}
await longPollingTransport.Running.OrTimeout();
}
[Fact]
public void CanDisposeNotStartedConnection()
{
using (new Connection(new Uri("http://fakeuri.org"))) { }
}
[Fact]
@ -56,8 +51,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
using (var connection = new Connection(new Uri("http://fakeuri.org/")))
{
await connection.StartAsync(longPollingTransport, httpClient);
Assert.False(longPollingTransport.Running.IsCompleted);
await connection.StopAsync();
@ -81,8 +78,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
{
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
using (var connection = new Connection(new Uri("http://fakeuri.org/")))
{
await connection.StartAsync(longPollingTransport, httpClient);
Assert.False(longPollingTransport.Running.IsCompleted);
}
@ -109,12 +108,16 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
using (var connection = new Connection(new Uri("http://fakeuri.org/")))
{
await connection.StartAsync(longPollingTransport, httpClient);
var data = new byte[] { 1, 1, 2, 3, 5, 8 };
await connection.SendAsync(data, MessageType.Binary);
Assert.Equal(data, await sendTcs.Task.OrTimeout());
await connection.StopAsync();
}
}
@ -138,11 +141,15 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
using (var connection = new Connection(new Uri("http://fakeuri.org/")))
{
await connection.StartAsync(longPollingTransport, httpClient);
var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData));
Assert.Equal("42", Encoding.UTF8.GetString(receiveData.Data));
await connection.StopAsync();
}
}
@ -160,8 +167,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
using (var connection = new Connection(new Uri("http://fakeuri.org/")))
{
await connection.StartAsync(longPollingTransport, httpClient);
await connection.StopAsync();
Assert.False(await connection.SendAsync(new byte[] { 1, 1, 3, 5, 8 }, MessageType.Binary));
}
@ -181,8 +189,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
using (var connection = new Connection(new Uri("http://fakeuri.org/")))
{
await connection.StartAsync(longPollingTransport, httpClient);
await connection.StopAsync();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.ReceiveAsync(new ReceiveData()));
@ -211,8 +221,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
using (var connection = new Connection(new Uri("http://fakeuri.org/")))
{
await connection.StartAsync(longPollingTransport, httpClient);
var receiveTask = connection.ReceiveAsync(new ReceiveData());
allowPollTcs.TrySetResult(null);
await Assert.ThrowsAsync<HttpRequestException>(async () => await receiveTask);
@ -241,8 +253,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
using (var httpClient = new HttpClient(mockHttpHandler.Object))
using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()))
using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient))
using (var connection = new Connection(new Uri("http://fakeuri.org/")))
{
await connection.StartAsync(longPollingTransport, httpClient);
var receiveTask = connection.ReceiveAsync(new ReceiveData());
allowPollTcs.TrySetResult(null);
await Assert.ThrowsAsync<HttpRequestException>(async () => await receiveTask);