diff --git a/src/Common/PipeWriterStream.cs b/src/Common/PipeWriterStream.cs index ecf2b239ca..8c294b95b6 100644 --- a/src/Common/PipeWriterStream.cs +++ b/src/Common/PipeWriterStream.cs @@ -57,17 +57,28 @@ namespace System.IO.Pipelines public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - Write(buffer, offset, count); - return Task.CompletedTask; + return WriteCoreAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); } #if NETCOREAPP2_1 public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) { - _pipeWriter.Write(source.Span); - _length += source.Length; - return default; + return WriteCoreAsync(source, cancellationToken); } #endif + + private ValueTask WriteCoreAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + _length += source.Length; + var task = _pipeWriter.WriteAsync(source); + if (!task.IsCompletedSuccessfully) + { + return WriteSlowAsync(task); + } + + return default; + + async ValueTask WriteSlowAsync(ValueTask flushTask) => await flushTask; + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 3351195c26..25be1fe0d7 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -516,6 +516,8 @@ namespace Microsoft.AspNetCore.SignalR.Client { var result = await _connectionState.Connection.Transport.Input.ReadAsync(); var buffer = result.Buffer; + var consumed = buffer.Start; + var examined = buffer.End; try { @@ -524,6 +526,12 @@ namespace Microsoft.AspNetCore.SignalR.Client { if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message)) { + // Adjust consumed and examined to point to the end of the handshake + // response, this handles the case where invocations are sent in the same payload + // as the the negotiate response. + consumed = buffer.Start; + examined = consumed; + if (message.Error != null) { Log.HandshakeServerError(_logger, message.Error); @@ -543,10 +551,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } finally { - // The buffer was sliced up to where it was consumed, so we can just advance to the start. - // We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data - // before yielding the read again. - _connectionState.Connection.Transport.Input.AdvanceTo(buffer.Start, buffer.End); + _connectionState.Connection.Transport.Input.AdvanceTo(consumed, examined); } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index d9395b76f5..e2a757d618 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -231,7 +231,10 @@ namespace Microsoft.AspNetCore.SignalR { using (var cts = new CancellationTokenSource()) { - cts.CancelAfter(timeout); + if (!Debugger.IsAttached) + { + cts.CancelAfter(timeout); + } while (true) { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.cs index 21925cac62..72b5bebf04 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.cs @@ -124,7 +124,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal var stream = new PipeWriterStream(_application.Output); await response.Content.CopyToAsync(stream); - await _application.Output.FlushAsync(); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index e22a4a0367..15aab4ea1f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -452,14 +452,10 @@ namespace Microsoft.AspNetCore.Sockets return; } - // Until the parsers are incremental, we buffer the entire request body before - // flushing the buffer. Using CopyToAsync allows us to avoid allocating a single giant - // buffer before writing. var pipeWriterStream = new PipeWriterStream(connection.Application.Output); await context.Request.Body.CopyToAsync(pipeWriterStream); Log.ReceivedBytes(_logger, pipeWriterStream.Length); - await connection.Application.Output.FlushAsync(); } private async Task EnsureConnectionStateAsync(HttpConnectionContext connection, HttpContext context, TransportType transportType, TransportType supportedTransports, ConnectionLogScope logScope, HttpConnectionOptions options) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs index ff25d20291..3493d16475 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs @@ -160,16 +160,13 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports _application.Output.Advance(receiveResult.Count); - if (receiveResult.EndOfMessage) - { - var flushResult = await _application.Output.FlushAsync(); + var flushResult = await _application.Output.FlushAsync(); - // We canceled in the middle of applying back pressure - // or if the consumer is done - if (flushResult.IsCanceled || flushResult.IsCompleted) - { - break; - } + // We canceled in the middle of applying back pressure + // or if the consumer is done + if (flushResult.IsCanceled || flushResult.IsCompleted) + { + break; } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Protocol.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Protocol.cs index 2cabce8d80..2f2db62e1f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Protocol.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Protocol.cs @@ -102,7 +102,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await hubConnection.StartAsync().OrTimeout(); - await connection.ReceiveJsonMessage(new {type = 7}).OrTimeout(); + await connection.ReceiveJsonMessage(new { type = 7 }).OrTimeout(); Exception closeException = await closedTcs.Task.OrTimeout(); Assert.Null(closeException); @@ -127,7 +127,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await hubConnection.StartAsync().OrTimeout(); - await connection.ReceiveJsonMessage(new {type = 7, error = "Error!"}).OrTimeout(); + await connection.ReceiveJsonMessage(new { type = 7, error = "Error!" }).OrTimeout(); Exception closeException = await closedTcs.Task.OrTimeout(); Assert.NotNull(closeException); @@ -156,7 +156,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal("{\"type\":4,\"invocationId\":\"1\",\"target\":\"Foo\",\"arguments\":[]}", invokeMessage); // Complete the channel - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); await channel.Completion; } finally @@ -177,7 +177,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); await invokeTask.OrTimeout(); } @@ -199,7 +199,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); Assert.Empty(await channel.ReadAllAsync()); } @@ -221,7 +221,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3, result = 42}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = 42 }).OrTimeout(); Assert.Equal(42, await invokeTask.OrTimeout()); } @@ -243,7 +243,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3, error = "An error occurred"}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); Assert.Equal("An error occurred", ex.Message); @@ -266,7 +266,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3, result = "Oops"}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = "Oops" }).OrTimeout(); var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); Assert.Equal("Server provided a result in a completion response to a streamed invocation.", ex.Message); @@ -289,7 +289,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3, error = "An error occurred"}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); Assert.Equal("An error occurred", ex.Message); @@ -312,7 +312,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 2, item = 42}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = 42 }).OrTimeout(); var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); Assert.Equal("Streaming hub methods must be invoked with the 'HubConnection.StreamAsChannelAsync' method.", ex.Message); @@ -335,14 +335,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 2, item = "1"}).OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 2, item = "2"}).OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 2, item = "3"}).OrTimeout(); - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "1" }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "2" }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "3" }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); var notifications = await channel.ReadAllAsync().OrTimeout(); - Assert.Equal(new[] {"1", "2", "3",}, notifications.ToArray()); + Assert.Equal(new[] { "1", "2", "3", }, notifications.ToArray()); } finally { @@ -361,10 +361,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await hubConnection.StartAsync().OrTimeout(); - hubConnection.On("Foo", (r1, r2, r3) => handlerCalled.TrySetResult(new object[] {r1, r2, r3})); + hubConnection.On("Foo", (r1, r2, r3) => handlerCalled.TrySetResult(new object[] { r1, r2, r3 })); - var args = new object[] {1, "Foo", 2.0f}; - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 1, target = "Foo", arguments = args}).OrTimeout(); + var args = new object[] { 1, "Foo", 2.0f }; + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 1, target = "Foo", arguments = args }).OrTimeout(); Assert.Equal(args, await handlerCalled.Task.OrTimeout()); } @@ -389,10 +389,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); // Receive the ping mid-invocation so we can see that the rest of the flow works fine - await connection.ReceiveJsonMessage(new {type = 6}).OrTimeout(); + await connection.ReceiveJsonMessage(new { type = 6 }).OrTimeout(); // Receive a completion - await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); // Ensure the invokeTask completes properly await invokeTask.OrTimeout(); @@ -403,6 +403,100 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.DisposeAsync().OrTimeout(); } } + + [Fact] + public async Task PartialHandshakeResponseWorks() + { + var connection = new TestConnection(synchronousCallbacks: true, autoNegotiate: false); + var hubConnection = CreateHubConnection(connection); + try + { + var task = hubConnection.StartAsync(); + + await connection.ReceiveTextAsync("{"); + + Assert.False(task.IsCompleted); + + await connection.ReceiveTextAsync("}"); + + Assert.False(task.IsCompleted); + + await connection.ReceiveTextAsync("\u001e"); + + await task.OrTimeout(); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task HandshakeAndInvocationInSameBufferWorks() + { + var payload = "{}\u001e{\"type\":1, \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e"; + var connection = new TestConnection(synchronousCallbacks: true, autoNegotiate: false); + var hubConnection = CreateHubConnection(connection); + try + { + var tcs = new TaskCompletionSource(); + hubConnection.On("Echo", data => + { + tcs.TrySetResult(data); + }); + + await connection.ReceiveTextAsync(payload); + + await hubConnection.StartAsync(); + + var response = await tcs.Task.OrTimeout(); + Assert.Equal("hello", response); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task PartialInvocationWorks() + { + var connection = new TestConnection(synchronousCallbacks: true); + var hubConnection = CreateHubConnection(connection); + try + { + var tcs = new TaskCompletionSource(); + hubConnection.On("Echo", data => + { + tcs.TrySetResult(data); + }); + + await hubConnection.StartAsync().OrTimeout(); + + await connection.ReceiveTextAsync("{\"type\":1, "); + + Assert.False(tcs.Task.IsCompleted); + + await connection.ReceiveTextAsync("\"target\": \"Echo\", \"arguments\""); + + Assert.False(tcs.Task.IsCompleted); + + await connection.ReceiveTextAsync(":[\"hello\"]}\u001e"); + + Assert.True(tcs.Task.IsCompleted); + + var response = await tcs.Task; + + Assert.Equal("hello", response); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index 56883a367e..4a1f8fc7be 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -38,13 +38,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public IFeatureCollection Features { get; } = new FeatureCollection(); public int DisposeCount => _disposeCount; - public TestConnection(Func onStart = null, Func onDispose = null, bool autoNegotiate = true) + public TestConnection(Func onStart = null, Func onDispose = null, bool autoNegotiate = true, bool synchronousCallbacks = false) { _autoNegotiate = autoNegotiate; _onStart = onStart ?? (() => Task.CompletedTask); _onDispose = onDispose ?? (() => Task.CompletedTask); - var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var scheduler = synchronousCallbacks ? PipeScheduler.Inline : null; + var options = new PipeOptions(readerScheduler: scheduler, writerScheduler: scheduler, useSynchronizationContext: false); + + var pair = DuplexPipe.CreateConnectionPair(options, options); Application = pair.Application; Transport = pair.Transport; @@ -88,6 +91,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return Application.Output.WriteAsync(bytes).AsTask(); } + public Task ReceiveTextAsync(string rawText) + { + return ReceiveBytesAsync(Encoding.UTF8.GetBytes(rawText)); + } + + public Task ReceiveBytesAsync(byte[] bytes) + { + return Application.Output.WriteAsync(bytes).AsTask(); + } + public async Task ReadSentTextMessageAsync() { // Read a single text message from the Application Input pipe diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs index bb68a0f7c1..df2af40261 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs @@ -82,5 +82,26 @@ namespace System.IO.Pipelines pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End); } } + + public static async Task ReadAsync(this PipeReader pipeReader, int numBytes) + { + while (true) + { + var result = await pipeReader.ReadAsync(); + if (result.Buffer.Length < numBytes) + { + pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End); + continue; + } + + var buffer = result.Buffer.Slice(0, numBytes); + + var bytes = buffer.ToArray(); + + pipeReader.AdvanceTo(buffer.End); + + return bytes; + } + } } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs index 8dacbe9b65..31eafa92fc 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -38,7 +38,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false) { - var options = new PipeOptions(readerScheduler: synchronousCallbacks ? PipeScheduler.Inline : null); + var scheduler = synchronousCallbacks ? PipeScheduler.Inline : null; + var options = new PipeOptions(readerScheduler: scheduler, writerScheduler: scheduler, useSynchronizationContext: false); var pair = DuplexPipe.CreateConnectionPair(options, options); Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application); @@ -236,11 +237,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests else { // read first message out of the incoming data - if (!HandshakeProtocol.TryParseResponseMessage(ref buffer, out var responseMessage)) + if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var responseMessage)) { - throw new InvalidDataException("Unable to parse payload as a handshake response message."); + return responseMessage; } - return responseMessage; } } finally diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EchoConnectionHandler.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EchoConnectionHandler.cs index f9367ca2c6..398ccc60f1 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EchoConnectionHandler.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EchoConnectionHandler.cs @@ -13,18 +13,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public override async Task OnConnectedAsync(ConnectionContext connection) { - var result = await connection.Transport.Input.ReadAsync(); - var buffer = result.Buffer; - - try + while (true) { + var result = await connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + if (!buffer.IsEmpty) { await connection.Transport.Output.WriteAsync(buffer.ToArray()); } - } - finally - { + else if (result.IsCompleted) + { + break; + } + connection.Transport.Input.AdvanceTo(buffer.End); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 98e88fe958..fde6de4969 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var bytes = Encoding.UTF8.GetBytes(message); logger.LogInformation("Sending {length} byte frame", bytes.Length); - await ws.SendAsync(new ArraySegment(bytes), WebSocketMessageType.Binary, true, CancellationToken.None).OrTimeout(); + await ws.SendAsync(new ArraySegment(bytes), WebSocketMessageType.Binary, endOfMessage: true, CancellationToken.None).OrTimeout(); logger.LogInformation("Receiving frame"); var buffer = new ArraySegment(new byte[1024]); @@ -107,11 +107,49 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Equal(bytes, buffer.Array.AsSpan().Slice(0, result.Count).ToArray()); + logger.LogInformation("Closing socket"); + await ws.CloseOutputAsync(WebSocketCloseStatus.Empty, "", CancellationToken.None).OrTimeout(); logger.LogInformation("Waiting for close"); result = await ws.ReceiveAsync(buffer, CancellationToken.None).OrTimeout(); Assert.Equal(WebSocketMessageType.Close, result.MessageType); + logger.LogInformation("Closed socket"); + } + } + } + + [ConditionalFact] + [WebSocketsSupportedCondition] + public async Task WebSocketsReceivesAndSendsPartialFramesTest() + { + using (StartLog(out var loggerFactory)) + { + var logger = loggerFactory.CreateLogger(); + + const string message = "Hello, World!"; + using (var ws = new ClientWebSocket()) + { + var socketUrl = _serverFixture.WebSocketsUrl + "/echo"; + + logger.LogInformation("Connecting WebSocket to {socketUrl}", socketUrl); + await ws.ConnectAsync(new Uri(socketUrl), CancellationToken.None).OrTimeout(); + + var bytes = Encoding.UTF8.GetBytes(message); + logger.LogInformation("Sending {length} byte frame", bytes.Length); + // We're sending a partial frame, we should still get the data + await ws.SendAsync(new ArraySegment(bytes), WebSocketMessageType.Binary, endOfMessage: false, CancellationToken.None).OrTimeout(); + + logger.LogInformation("Receiving frame"); + var buffer = new ArraySegment(new byte[1024]); + var result = await ws.ReceiveAsync(buffer, CancellationToken.None).OrTimeout(); + logger.LogInformation("Received {length} byte frame", result.Count); + + Assert.Equal(bytes, buffer.Array.AsSpan().Slice(0, result.Count).ToArray()); + logger.LogInformation("Closing socket"); - await ws.CloseAsync(WebSocketCloseStatus.Empty, "", CancellationToken.None).OrTimeout(); + await ws.CloseOutputAsync(WebSocketCloseStatus.Empty, "", CancellationToken.None).OrTimeout(); + logger.LogInformation("Waiting for close"); + result = await ws.ReceiveAsync(buffer, CancellationToken.None).OrTimeout(); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); logger.LogInformation("Closed socket"); } } @@ -141,7 +179,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests await connection.Transport.Output.WriteAsync(message).OrTimeout(); - var receivedData = await connection.Transport.Input.ReadAllAsync(); + var receivedData = await connection.Transport.Input.ReadAsync(1); Assert.Equal(message, receivedData); } catch (Exception ex) @@ -194,7 +232,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests logger.LogInformation("Sent message"); logger.LogInformation("Receiving message"); - Assert.Equal(message, Encoding.UTF8.GetString(await connection.Transport.Input.ReadAllAsync())); + Assert.Equal(message, Encoding.UTF8.GetString(await connection.Transport.Input.ReadAsync(bytes.Length))); logger.LogInformation("Completed receive"); } catch (Exception ex) @@ -245,7 +283,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests logger.LogInformation("Receiving message"); // Big timeout here because it can take a while to receive all the bytes - var receivedData = await connection.Transport.Input.ReadAllAsync(); + var receivedData = await connection.Transport.Input.ReadAsync(bytes.Length); Assert.Equal(message, Encoding.UTF8.GetString(receivedData)); logger.LogInformation("Completed receive"); } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index d88e36e8bc..f198761582 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -352,6 +352,121 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task SendingHandshakeRequestInChunksWorks() + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT)); + var part1 = Encoding.UTF8.GetBytes("{\"protocol\": \"json\""); + var part2 = Encoding.UTF8.GetBytes(",\"version\": 1}"); + var part3 = Encoding.UTF8.GetBytes("\u001e"); + + using (var client = new TestClient(synchronousCallbacks: true)) + { + client.SupportedFormats = TransferFormat.Text; + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler, + sendHandshakeRequestMessage: false, + expectedHandshakeResponseMessage: false); + + // Wait for the handshake response + var task = client.ReadAsync(isHandshake: true); + + await client.Connection.Application.Output.WriteAsync(part1); + + Assert.False(task.IsCompleted); + + await client.Connection.Application.Output.WriteAsync(part2); + + Assert.False(task.IsCompleted); + + await client.Connection.Application.Output.WriteAsync(part3); + + Assert.True(task.IsCompleted); + + var response = (await task) as HandshakeResponseMessage; + Assert.NotNull(response); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + [Fact] + public async Task SendingInvocatonInChunksWorks() + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT)); + var part1 = Encoding.UTF8.GetBytes("{\"type\":1, \"invocationId\":\"1\", "); + var part2 = Encoding.UTF8.GetBytes("\"target\": \"Echo\", \"arguments\""); + var part3 = Encoding.UTF8.GetBytes(":[\"hello\"]}\u001e"); + + using (var client = new TestClient(synchronousCallbacks: true)) + { + client.SupportedFormats = TransferFormat.Text; + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Wait for the hub completion + var task = client.ReadAsync(); + + await client.Connection.Application.Output.WriteAsync(part1); + + Assert.False(task.IsCompleted); + + await client.Connection.Application.Output.WriteAsync(part2); + + Assert.False(task.IsCompleted); + + await client.Connection.Application.Output.WriteAsync(part3); + + Assert.True(task.IsCompleted); + + var completionMessage = await task as CompletionMessage; + Assert.NotNull(completionMessage); + Assert.Equal("hello", completionMessage.Result); + Assert.Equal("1", completionMessage.InvocationId); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + [Fact] + public async Task SendingHandshakeRequestAndInvocationInSamePayloadParsesHandshakeAndInvocation() + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT)); + var payload = Encoding.UTF8.GetBytes("{\"protocol\": \"json\",\"version\": 1}\u001e{\"type\":1, \"invocationId\":\"1\", \"target\": \"Echo\", \"arguments\":[\"hello\"]}\u001e"); + + using (var client = new TestClient(synchronousCallbacks: true)) + { + client.SupportedFormats = TransferFormat.Text; + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler, + sendHandshakeRequestMessage: false, + expectedHandshakeResponseMessage: false); + + // Wait for the handshake response + var task = client.ReadAsync(isHandshake: true); + + await client.Connection.Application.Output.WriteAsync(payload); + + Assert.True(task.IsCompleted); + + var response = await task as HandshakeResponseMessage; + Assert.NotNull(response); + + var completionMessage = await client.ReadAsync() as CompletionMessage; + Assert.NotNull(completionMessage); + Assert.Equal("hello", completionMessage.Result); + Assert.Equal("1", completionMessage.InvocationId); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + [Fact] public async Task HandshakeSuccessSendsResponseWithoutError() { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Startup.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Startup.cs index 0cd9a6eece..b6ec249ed7 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Startup.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Startup.cs @@ -13,15 +13,21 @@ namespace Microsoft.AspNetCore.SignalR.Tests { services.AddConnections(); services.AddSignalR(); - services.AddSingleton(); - services.AddSingleton(); } public void Configure(IApplicationBuilder app, IHostingEnvironment env) { - app.UseConnections(options => options.MapConnectionHandler("/echo")); - app.UseConnections(options => options.MapConnectionHandler("/httpheader")); - app.UseSignalR(options => options.MapHub("/uncreatable")); + app.UseConnections(routes => + { + routes.MapConnectionHandler("/echo"); + routes.MapConnectionHandler("/echoAndClose"); + routes.MapConnectionHandler("/httpheader"); + }); + + app.UseSignalR(routes => + { + routes.MapHub("/uncreatable"); + }); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index 039d37ba85..8254ec2a6f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -138,11 +138,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application, transferFormat, connection: Mock.Of()); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echoAndClose"), pair.Application, transferFormat, connection: Mock.Of()); await pair.Transport.Output.WriteAsync(new byte[] { 0x42 }); - // The echo endpoint closes the connection immediately after sending response which should stop the transport + // The echoAndClose endpoint closes the connection immediately after sending response which should stop the transport await webSocketsTransport.Running.OrTimeout(); Assert.True(pair.Transport.Input.TryRead(out var result)); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WriteThenCloseConnectionHandler.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WriteThenCloseConnectionHandler.cs new file mode 100644 index 0000000000..bf96db4266 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WriteThenCloseConnectionHandler.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.IO.Pipelines; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Sockets; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class WriteThenCloseConnectionHandler : ConnectionHandler + { + public override async Task OnConnectedAsync(ConnectionContext connection) + { + var result = await connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + + if (!buffer.IsEmpty) + { + await connection.Transport.Output.WriteAsync(buffer.ToArray()); + } + + connection.Transport.Input.AdvanceTo(buffer.End); + } + } +}