diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs index b299b7e2d1..0c6a6ad859 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs @@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; + private int _connectionState = ConnectionState.Initial; private IChannelConnection _transportChannel; private ITransport _transport; @@ -36,7 +37,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = _loggerFactory.CreateLogger(); } - public Task StartAsync(Uri url, ITransport transport) => StartAsync((ITransport)null, null); + public Task StartAsync() => StartAsync(null, null); public Task StartAsync(HttpClient httpClient) => StartAsync(null, httpClient); public Task StartAsync(ITransport transport) => StartAsync(transport, null); @@ -45,24 +46,24 @@ namespace Microsoft.AspNetCore.Sockets.Client // TODO: make transport optional _transport = transport ?? throw new ArgumentNullException(nameof(transport)); - var connectUrl = await GetConnectUrl(Url, httpClient, _logger); + if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial) + != ConnectionState.Initial) + { + throw new InvalidOperationException("Cannot start a connection that is not in the Initial state."); + } - var applicationToTransport = Channel.CreateUnbounded(); - var transportToApplication = Channel.CreateUnbounded(); - var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); - _transportChannel = new ChannelConnection(applicationToTransport, transportToApplication); - - - // Start the transport, giving it one end of the pipeline try { - await transport.StartAsync(connectUrl, applicationSide); + var connectUrl = await GetConnectUrl(Url, httpClient, _logger); + await StartTransport(connectUrl); } - catch (Exception ex) + catch { - _logger.LogError("Failed to start connection. Error starting transport '{0}': {1}", transport.GetType().Name, ex); + Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); throw; } + + Interlocked.Exchange(ref _connectionState, ConnectionState.Connected); } private static async Task GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger) @@ -101,6 +102,29 @@ namespace Microsoft.AspNetCore.Sockets.Client } } + private async Task StartTransport(Uri connectUrl) + { + var applicationToTransport = Channel.CreateUnbounded(); + var transportToApplication = Channel.CreateUnbounded(); + var applicationSide = new ChannelConnection(transportToApplication, applicationToTransport); + + _transportChannel = new ChannelConnection(applicationToTransport, transportToApplication); +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + Input.Completion.ContinueWith(t => Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected)); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + + // Start the transport, giving it one end of the pipeline + try + { + await _transport.StartAsync(connectUrl, applicationSide); + } + catch (Exception ex) + { + _logger.LogError("Failed to start connection. Error starting transport '{0}': {1}", _transport.GetType().Name, ex); + throw; + } + } + public Task ReceiveAsync(ReceiveData receiveData) { return ReceiveAsync(receiveData, CancellationToken.None); @@ -156,6 +180,12 @@ namespace Microsoft.AspNetCore.Sockets.Client public async Task SendAsync(byte[] data, MessageType type, CancellationToken cancellationToken) { + // TODO: data == null? + if (_connectionState != ConnectionState.Connected) + { + return false; + } + var message = new Message(ReadableBuffer.Create(data).Preserve(), type); while (await Output.WaitToWriteAsync(cancellationToken)) @@ -171,6 +201,8 @@ namespace Microsoft.AspNetCore.Sockets.Client public async Task StopAsync() { + Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); + if (_transportChannel != null) { Output.TryComplete(); @@ -184,6 +216,8 @@ namespace Microsoft.AspNetCore.Sockets.Client public void Dispose() { + Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); + if (_transportChannel != null) { Output.TryComplete(); @@ -194,5 +228,13 @@ namespace Microsoft.AspNetCore.Sockets.Client _transport.Dispose(); } } + + private class ConnectionState + { + public const int Initial = 0; + public const int Connecting = 1; + public const int Connected = 2; + public const int Disconnected = 3; + } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs index f9a9fc4a25..38651dc50c 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs @@ -37,6 +37,114 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (new Connection(new Uri("http://fakeuri.org"))) { } } + [Fact] + public async Task CannotStartRunningConnection() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) + { + await connection.StartAsync(longPollingTransport, httpClient); + var exception = + await Assert.ThrowsAsync( + async () => await connection.StartAsync(longPollingTransport)); + Assert.Equal("Cannot start a connection that is not in the Initial state.", exception.Message); + + await connection.StopAsync(); + } + } + + [Fact] + public async Task CannotStartStoppedConnection() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) + { + await connection.StartAsync(longPollingTransport, httpClient); + await connection.StopAsync(); + var exception = + await Assert.ThrowsAsync( + async () => await connection.StartAsync(longPollingTransport)); + + Assert.Equal("Cannot start a connection that is not in the Initial state.", exception.Message); + } + } + + [Fact] + public async Task CannotStartDisposedConnection() + { + using (var httpClient = new HttpClient()) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + { + var connection = new Connection(new Uri("http://fakeuri.org/")); + connection.Dispose(); + var exception = + await Assert.ThrowsAsync( + async () => await connection.StartAsync(longPollingTransport)); + + Assert.Equal("Cannot start a connection that is not in the Initial state.", exception.Message); + } + } + + [Fact] + public async Task SendReturnsFalseIfConnectionIsNotStarted() + { + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) + { + Assert.False(await connection.SendAsync(new byte[0], MessageType.Binary)); + } + } + + [Fact] + public async Task SendReturnsFalseIfConnectionIsStopped() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + using (var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory())) + using (var connection = new Connection(new Uri("http://fakeuri.org/"))) + { + await connection.StartAsync(longPollingTransport, httpClient); + await connection.StopAsync(); + + Assert.False(await connection.SendAsync(new byte[0], MessageType.Binary)); + } + } + + [Fact] + public async Task SendReturnsFalseIfConnectionIsDisposed() + { + var connection = new Connection(new Uri("http://fakeuri.org/")); + connection.Dispose(); + Assert.False(await connection.SendAsync(new byte[0], MessageType.Binary)); + } + [Fact] public async Task TransportIsStoppedWhenConnectionIsStopped() {