From 3ba95b98af149b88658287eb143941437da3f4ad Mon Sep 17 00:00:00 2001 From: moozzyk Date: Wed, 8 Feb 2017 12:46:20 -0800 Subject: [PATCH] Converting static ConnectAsync to instance StartAsync --- samples/ClientSample/RawSample.cs | 5 +- .../HubConnection.cs | 4 +- .../Connection.cs | 196 +++++++++--------- .../WebSocketsTransport.cs | 4 +- .../EndToEndTests.cs | 32 +-- .../ConnectionTests.cs | 66 +++--- 6 files changed, 152 insertions(+), 155 deletions(-) diff --git a/samples/ClientSample/RawSample.cs b/samples/ClientSample/RawSample.cs index 2d2fcc7ba5..84791b4f21 100644 --- a/samples/ClientSample/RawSample.cs +++ b/samples/ClientSample/RawSample.cs @@ -30,8 +30,10 @@ namespace ClientSample { logger.LogInformation("Connecting to {0}", baseUrl); var transport = new LongPollingTransport(httpClient, loggerFactory); - using (var connection = await Connection.ConnectAsync(new Uri(baseUrl), transport, httpClient, loggerFactory)) + using (var connection = new Connection(new Uri(baseUrl), loggerFactory)) { + await connection.StartAsync(transport, httpClient); + logger.LogInformation("Connected to {0}", baseUrl); var cts = new CancellationTokenSource(); @@ -49,6 +51,7 @@ namespace ClientSample StartSending(loggerFactory.CreateLogger("SendLoop"), connection, cts.Token).ContinueWith(_ => cts.Cancel()); await Task.WhenAll(receive, send); + await connection.StopAsync(); } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index 0a1870c84f..347ac8bae6 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -118,9 +118,9 @@ namespace Microsoft.AspNetCore.SignalR.Client public static async Task ConnectAsync(Uri url, IInvocationAdapter adapter, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory) { // Connect the underlying connection - var connection = await Connection.ConnectAsync(url, transport, httpClient, loggerFactory); + var connection = new Connection(url, loggerFactory); + await connection.StartAsync(transport, httpClient); - // Create the RPC connection wrapper return new HubConnection(connection, adapter, loggerFactory.CreateLogger()); } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs index 40a0865596..b299b7e2d1 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs @@ -14,24 +14,93 @@ namespace Microsoft.AspNetCore.Sockets.Client { public class Connection : IDisposable { + private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; private IChannelConnection _transportChannel; private ITransport _transport; - private readonly ILogger _logger; - - public Uri Url { get; } - - private Connection(Uri url, ITransport transport, IChannelConnection transportChannel, ILogger logger) - { - Url = url; - - _logger = logger; - _transport = transport; - _transportChannel = transportChannel; - } private ReadableChannel Input => _transportChannel.Input; private WritableChannel Output => _transportChannel.Output; + public Uri Url { get; } + + public Connection(Uri url) + : this(url, null) + { } + + public Connection(Uri url, ILoggerFactory loggerFactory) + { + Url = url ?? throw new ArgumentNullException(nameof(url)); + + _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; + _logger = _loggerFactory.CreateLogger(); + } + + public Task StartAsync(Uri url, ITransport transport) => StartAsync((ITransport)null, null); + public Task StartAsync(HttpClient httpClient) => StartAsync(null, httpClient); + public Task StartAsync(ITransport transport) => StartAsync(transport, null); + + public async Task StartAsync(ITransport transport, HttpClient httpClient) + { + // TODO: make transport optional + _transport = transport ?? throw new ArgumentNullException(nameof(transport)); + + var connectUrl = await GetConnectUrl(Url, httpClient, _logger); + + var applicationToTransport = Channel.CreateUnbounded(); + var transportToApplication = Channel.CreateUnbounded(); + var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); + _transportChannel = new ChannelConnection(applicationToTransport, transportToApplication); + + + // Start the transport, giving it one end of the pipeline + try + { + await transport.StartAsync(connectUrl, applicationSide); + } + catch (Exception ex) + { + _logger.LogError("Failed to start connection. Error starting transport '{0}': {1}", transport.GetType().Name, ex); + throw; + } + } + + private static async Task GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger) + { + var disposeHttpClient = httpClient == null; + httpClient = httpClient ?? new HttpClient(); + try + { + var connectionId = await GetConnectionId(url, httpClient, logger); + return Utils.AppendQueryString(url, "id=" + connectionId); + } + finally + { + if (disposeHttpClient) + { + httpClient.Dispose(); + } + } + } + + private static async Task GetConnectionId(Uri url, HttpClient httpClient, ILogger logger) + { + var negotiateUrl = Utils.AppendPath(url, "negotiate"); + try + { + // Get a connection ID from the server + logger.LogDebug("Establishing Connection at: {0}", negotiateUrl); + var connectionId = await httpClient.GetStringAsync(negotiateUrl); + logger.LogDebug("Connection Id: {0}", connectionId); + return connectionId; + } + catch (Exception ex) + { + logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex); + throw; + } + } + public Task ReceiveAsync(ReceiveData receiveData) { return ReceiveAsync(receiveData, CancellationToken.None); @@ -102,104 +171,27 @@ namespace Microsoft.AspNetCore.Sockets.Client public async Task StopAsync() { - Output.TryComplete(); - await _transport.StopAsync(); - await DrainMessages(); + if (_transportChannel != null) + { + Output.TryComplete(); + } + + if (_transport != null) + { + await _transport.StopAsync(); + } } public void Dispose() { - Output.TryComplete(); - _transport.Dispose(); - } - - private async Task DrainMessages() - { - while (await Input.WaitToReadAsync()) + if (_transportChannel != null) { - if (Input.TryRead(out Message message)) - { - message.Dispose(); - } - } - } - - public static Task ConnectAsync(Uri url, ITransport transport) => ConnectAsync(url, transport, null, null); - public static Task ConnectAsync(Uri url, ITransport transport, ILoggerFactory loggerFactory) => ConnectAsync(url, transport, null, loggerFactory); - public static Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient) => ConnectAsync(url, transport, httpClient, null); - - public static async Task ConnectAsync(Uri url, ITransport transport, HttpClient httpClient, ILoggerFactory loggerFactory) - { - if (url == null) - { - throw new ArgumentNullException(nameof(url)); + Output.TryComplete(); } - // TODO: Once we have websocket transport we would be able to use it as the default transport - if (transport == null) + if (_transport != null) { - throw new ArgumentNullException(nameof(url)); - } - - loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; - var logger = loggerFactory.CreateLogger(); - - var connectUrl = await GetConnectUrl(url, httpClient, logger); - - var applicationToTransport = Channel.CreateUnbounded(); - var transportToApplication = Channel.CreateUnbounded(); - var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - var transportSide = new ChannelConnection(applicationToTransport, transportToApplication); - - - // Start the transport, giving it one end of the pipeline - try - { - await transport.StartAsync(connectUrl, applicationSide); - } - catch (Exception ex) - { - logger.LogError("Failed to start connection. Error starting transport '{0}': {1}", transport.GetType().Name, ex); - throw; - } - - // Create the connection, giving it the other end of the pipeline - return new Connection(url, transport, transportSide, logger); - } - - private static async Task GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger) - { - var disposeHttpClient = httpClient == null; - httpClient = httpClient ?? new HttpClient(); - try - { - var connectionId = await GetConnectionId(url, httpClient, logger); - return Utils.AppendQueryString(url, "id=" + connectionId); - } - finally - { - if (disposeHttpClient) - { - httpClient.Dispose(); - } - } - } - - private static async Task GetConnectionId(Uri url, HttpClient httpClient, ILogger logger) - { - var negotiateUrl = Utils.AppendPath(url, "negotiate"); - try - { - // Get a connection ID from the server - logger.LogDebug("Establishing Connection at: {0}", negotiateUrl); - var connectionId = await httpClient.GetStringAsync(negotiateUrl); - logger.LogDebug("Connection Id: {0}", connectionId); - return connectionId; - } - catch (Exception ex) - { - logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex); - throw; + _transport.Dispose(); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs index 751cde5155..b27b35e477 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs @@ -30,8 +30,8 @@ namespace Microsoft.AspNetCore.Sockets.Client } public Task Running { get; private set; } - - public async Task StartAsync(Uri url, IChannelConnection application) + + public async Task StartAsync(Uri url, IChannelConnection application) { if (url == null) { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index c862d5cf68..ca8ce16069 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -64,38 +64,22 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var httpClient = new HttpClient()) { var transport = new LongPollingTransport(httpClient, loggerFactory); - using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo"), transport, httpClient, loggerFactory)) + using (var connection = new ClientConnection(new Uri(baseUrl + "/echo"), loggerFactory)) { + await connection.StartAsync(transport, httpClient); + await connection.SendAsync(Encoding.UTF8.GetBytes(message), MessageType.Text); var receiveData = new ReceiveData(); Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout()); Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data)); + + await connection.StopAsync(); } } } - [ConditionalFact] - [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] - public async Task ConnectionCanSendAndReceiveSmallMessagesWebSocketsTransport() - { - const string message = "Major Key"; - var baseUrl = _serverFixture.BaseUrl; - var loggerFactory = new LoggerFactory(); - - var transport = new WebSocketsTransport(); - using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo/ws"), transport, loggerFactory)) - { - await connection.SendAsync(Encoding.UTF8.GetBytes(message), MessageType.Text); - - var receiveData = new ReceiveData(); - - Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout()); - Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data)); - } - } - public static IEnumerable MessageSizesData { get @@ -114,14 +98,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests var loggerFactory = new LoggerFactory(); var transport = new WebSocketsTransport(); - using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo/ws"), transport, loggerFactory)) + using (var connection = new ClientConnection(new Uri(baseUrl + "/echo/ws"), loggerFactory)) { + await connection.StartAsync(transport); + await connection.SendAsync(Encoding.UTF8.GetBytes(message), MessageType.Text); var receiveData = new ReceiveData(); Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout()); Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data)); + + await connection.StopAsync(); } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs index 00e4ec5f90..f9a9fc4a25 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs @@ -18,28 +18,23 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests public class ConnectionTests { [Fact] - public async Task ConnectionReturnsUrlUsedToStartTheConnection() + public void CannotCreateConnectionWithNullUrl() { - var mockHttpHandler = new Mock(); - mockHttpHandler.Protected() - .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .Returns(async (request, cancellationToken) => - { - await Task.Yield(); - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; - }); + var exception = Assert.Throws(() => new Connection(null)); + Assert.Equal("url", exception.ParamName); + } + [Fact] + public void ConnectionReturnsUrlUsedToStartTheConnection() + { var connectionUrl = new Uri("http://fakeuri.org/"); - using (var httpClient = new HttpClient(mockHttpHandler.Object)) - using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) - { - using (var connection = await Connection.ConnectAsync(connectionUrl, longPollingTransport, httpClient)) - { - Assert.Equal(connectionUrl, connection.Url); - } + Assert.Equal(connectionUrl, new Connection(connectionUrl).Url); + } - await longPollingTransport.Running.OrTimeout(); - } + [Fact] + public void CanDisposeNotStartedConnection() + { + using (new Connection(new Uri("http://fakeuri.org"))) { } } [Fact] @@ -56,8 +51,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) - using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) { + await connection.StartAsync(longPollingTransport, httpClient); + Assert.False(longPollingTransport.Running.IsCompleted); await connection.StopAsync(); @@ -81,8 +78,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) { - using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) { + await connection.StartAsync(longPollingTransport, httpClient); + Assert.False(longPollingTransport.Running.IsCompleted); } @@ -109,12 +108,16 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) - using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) { + await connection.StartAsync(longPollingTransport, httpClient); + var data = new byte[] { 1, 1, 2, 3, 5, 8 }; await connection.SendAsync(data, MessageType.Binary); Assert.Equal(data, await sendTcs.Task.OrTimeout()); + + await connection.StopAsync(); } } @@ -138,11 +141,15 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) - using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) { + await connection.StartAsync(longPollingTransport, httpClient); + var receiveData = new ReceiveData(); Assert.True(await connection.ReceiveAsync(receiveData)); Assert.Equal("42", Encoding.UTF8.GetString(receiveData.Data)); + + await connection.StopAsync(); } } @@ -160,8 +167,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) - using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) { + await connection.StartAsync(longPollingTransport, httpClient); await connection.StopAsync(); Assert.False(await connection.SendAsync(new byte[] { 1, 1, 3, 5, 8 }, MessageType.Binary)); } @@ -181,8 +189,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) - using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) { + await connection.StartAsync(longPollingTransport, httpClient); + await connection.StopAsync(); var exception = await Assert.ThrowsAsync( async () => await connection.ReceiveAsync(new ReceiveData())); @@ -211,8 +221,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) - using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) { + await connection.StartAsync(longPollingTransport, httpClient); + var receiveTask = connection.ReceiveAsync(new ReceiveData()); allowPollTcs.TrySetResult(null); await Assert.ThrowsAsync(async () => await receiveTask); @@ -241,8 +253,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) - using (var connection = await Connection.ConnectAsync(new Uri("http://fakeuri.org/"), longPollingTransport, httpClient)) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) { + await connection.StartAsync(longPollingTransport, httpClient); + var receiveTask = connection.ReceiveAsync(new ReceiveData()); allowPollTcs.TrySetResult(null); await Assert.ThrowsAsync(async () => await receiveTask);