From 0133153bc92d2d16514bf3b66b0fcea3be386d3c Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Wed, 15 Mar 2017 18:03:23 -0700 Subject: [PATCH] use new protocol for '/send' (#297) --- .../LongPollingTransport.cs | 102 ++++++++++++++---- ...Microsoft.AspNetCore.Sockets.Client.csproj | 2 + .../Message.cs | 5 + .../HttpConnectionDispatcher.cs | 77 ++++++++++--- .../Transports/LongPollingTransport.cs | 9 +- .../Formatters => Common}/ArrayOutput.cs | 2 +- .../HubConnectionTests.cs | 2 + ...Core.SignalR.Client.FunctionalTests.csproj | 1 + .../ConnectionTests.cs | 34 ++++-- .../LongPollingTransportTests.cs | 63 ++++++++++- .../Microsoft.AspNetCore.Client.Tests.csproj | 1 + ...oft.AspNetCore.Sockets.Common.Tests.csproj | 4 + .../HttpConnectionDispatcherTests.cs | 78 +++++++++++++- 13 files changed, 324 insertions(+), 56 deletions(-) rename test/{Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters => Common}/ArrayOutput.cs (96%) diff --git a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs index c8c36a25e3..42d1995215 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs @@ -1,18 +1,23 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.IO.Pipelines.Text.Primitives; +using System.Net; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text; +using System.Text.Formatting; +using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using System; -using System.Buffers; -using System.Collections.Generic; -using System.Net; -using System.Net.Http; -using System.Net.Http.Headers; -using System.Threading; -using System.Threading.Tasks; namespace Microsoft.AspNetCore.Sockets.Client { @@ -171,43 +176,80 @@ namespace Microsoft.AspNetCore.Sockets.Client private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken) { _logger.LogInformation("Starting the send loop"); - - TaskCompletionSource sendTcs = null; + IList messages = null; try { while (await _application.Input.WaitToReadAsync(cancellationToken)) { + // Grab as many messages as we can from the channel + messages = new List(); while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out SendMessage message)) { - sendTcs = message.SendResult; + messages.Add(message); + } + + if (messages.Count > 0) + { + _logger.LogDebug("Sending {0} message(s) to the server using url: {1}", messages.Count, sendUrl); + + // Send them in a single post var request = new HttpRequestMessage(HttpMethod.Post, sendUrl); request.Headers.UserAgent.Add(DefaultUserAgentHeader); - if (message.Payload != null && message.Payload.Length > 0) - { - request.Content = new ByteArrayContent(message.Payload); - } + // TODO: We can probably use a pipeline here or some kind of pooled memory. + // But where do we get the pool from? ArrayBufferPool.Instance? + var memoryStream = new MemoryStream(); - _logger.LogDebug("Sending a message to the server using url: '{0}'. Message type {1}", sendUrl, message.Type); + // Write the messages to the stream + var pipe = memoryStream.AsPipelineWriter(); + var output = new PipelineTextOutput(pipe, TextEncoder.Utf8); // We don't need the Encoder, but it's harmless to set. + await WriteMessagesAsync(messages, output, MessageFormat.Binary); + + // Seek back to the start + memoryStream.Seek(0, SeekOrigin.Begin); + + // Set the, now filled, stream as the content + request.Content = new StreamContent(memoryStream); + request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(MessageFormatter.GetContentType(MessageFormat.Binary)); var response = await _httpClient.SendAsync(request); response.EnsureSuccessStatusCode(); - _logger.LogDebug("Message sent successfully"); - - sendTcs.SetResult(null); + _logger.LogDebug("Message(s) sent successfully"); + foreach (var message in messages) + { + message.SendResult?.TrySetResult(null); + } + } + else + { + _logger.LogDebug("No messages in batch to send"); } } } catch (OperationCanceledException) { // transport is being closed - sendTcs?.TrySetCanceled(); + if (messages != null) + { + foreach (var message in messages) + { + // This will no-op for any messages that were already marked as completed. + message.SendResult?.TrySetCanceled(); + } + } } catch (Exception ex) { _logger.LogError("Error while sending to '{0}': {1}", sendUrl, ex); - sendTcs?.TrySetException(ex); + if (messages != null) + { + foreach (var message in messages) + { + // This will no-op for any messages that were already marked as completed. + message.SendResult?.TrySetException(ex); + } + } throw; } finally @@ -218,5 +260,23 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger.LogInformation("Send loop stopped"); } + + private async Task WriteMessagesAsync(IList messages, PipelineTextOutput output, MessageFormat format) + { + output.Append(MessageFormatter.GetFormatIndicator(format), TextEncoder.Utf8); + + foreach (var message in messages) + { + _logger.LogDebug("Writing '{0}' message to the server", message.Type); + + var payload = message.Payload ?? Array.Empty(); + if (!MessageFormatter.TryWriteMessage(new Message(payload, message.Type, endOfMessage: true), output, format)) + { + // We didn't get any more memory! + throw new InvalidOperationException("Unable to write message to pipeline"); + } + await output.FlushAsync(); + } + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj index 320a3717b7..ce299285e6 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj @@ -17,7 +17,9 @@ + + diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Message.cs b/src/Microsoft.AspNetCore.Sockets.Common/Message.cs index c8c8f51fc4..ad0db7a791 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common/Message.cs +++ b/src/Microsoft.AspNetCore.Sockets.Common/Message.cs @@ -23,6 +23,11 @@ namespace Microsoft.AspNetCore.Sockets public Message(byte[] payload, MessageType type, bool endOfMessage) { + if (payload == null) + { + throw new ArgumentNullException(nameof(payload)); + } + Type = type; EndOfMessage = endOfMessage; Payload = payload; diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 4f7ca13971..8e4806f1a0 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; +using System.Collections.Generic; using System.IO; using System.IO.Pipelines; using System.Text; @@ -9,6 +11,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Sockets.Internal; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.AspNetCore.Sockets.Transports; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; @@ -286,7 +289,7 @@ namespace Microsoft.AspNetCore.Sockets private async Task ExecuteApplication(EndPoint endpoint, Connection connection) { - // Jump onto the thread pool thread so blocking user code doesn't block the setup of the + // Jump onto the thread pool thread so blocking user code doesn't block the setup of the // connection and transport await AwaitableThreadPool.Yield(); @@ -316,31 +319,52 @@ namespace Microsoft.AspNetCore.Sockets return; } - // Collect the message and write it to the channel - // TODO: Need to use some kind of pooled memory here. + // Read the entire payload to a byte array for now because Pipelines and ReadOnlyBytes + // don't play well with each other yet. byte[] buffer; using (var stream = new MemoryStream()) { await context.Request.Body.CopyToAsync(stream); + await stream.FlushAsync(); buffer = stream.ToArray(); } - var format = - string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase) - ? MessageType.Binary - : MessageType.Text; + IList messages; + if (string.Equals(context.Request.ContentType, MessageFormatter.TextContentType, StringComparison.OrdinalIgnoreCase)) + { + var reader = new BytesReader(buffer); + messages = ParseSendBatch(ref reader, MessageFormat.Text); + } + else if (string.Equals(context.Request.ContentType, MessageFormatter.BinaryContentType, StringComparison.OrdinalIgnoreCase)) + { + var reader = new BytesReader(buffer); + messages = ParseSendBatch(ref reader, MessageFormat.Binary); + } + else + { + // Legacy, single message raw format + + var format = + string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase) + ? MessageType.Binary + : MessageType.Text; + messages = new List() + { + new Message(buffer, format, endOfMessage: true) + }; + } - var message = new Message( - buffer, - format, - endOfMessage: true); // REVIEW: Do we want to return a specific status code here if the connection has ended? - while (await state.Application.Output.WaitToWriteAsync()) + _logger.LogDebug("Received batch of {0} message(s) in '/send'", messages.Count); + foreach (var message in messages) { - if (state.Application.Output.TryWrite(message)) + while (!state.Application.Output.TryWrite(message)) { - break; + if (!await state.Application.Output.WaitToWriteAsync()) + { + return; + } } } } @@ -407,5 +431,30 @@ namespace Microsoft.AspNetCore.Sockets return connectionState; } + + private IList ParseSendBatch(ref BytesReader payload, MessageFormat messageFormat) + { + var messages = new List(); + + if (payload.Unread.Length == 0) + { + return messages; + } + + if (payload.Unread[0] != MessageFormatter.GetFormatIndicator(messageFormat)) + { + throw new FormatException($"Format indicator '{(char)payload.Unread[0]}' does not match format determined by Content-Type '{MessageFormatter.GetContentType(messageFormat)}'"); + } + + payload.Advance(1); + + // REVIEW: This needs a little work. We could probably new up exactly the right parser, if we tinkered with the inheritance hierarchy a bit. + var parser = new MessageParser(); + while (parser.TryParseMessage(ref payload, messageFormat, out var message)) + { + messages.Add(message); + } + return messages; + } } } diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs index dc8f7997d8..50dbcc3228 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs @@ -1,9 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Sockets.Internal.Formatters; -using Microsoft.Extensions.Logging; using System; using System.IO.Pipelines; using System.IO.Pipelines.Text.Primitives; @@ -12,14 +9,14 @@ using System.Text.Formatting; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Transports { public class LongPollingTransport : IHttpTransport { - // REVIEW: This size? - internal const int MaxBufferSize = 4096; - public static readonly string Name = "longPolling"; private readonly ReadableChannel _application; private readonly ILogger _logger; diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ArrayOutput.cs b/test/Common/ArrayOutput.cs similarity index 96% rename from test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ArrayOutput.cs rename to test/Common/ArrayOutput.cs index 8b851a4486..463bf20580 100644 --- a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ArrayOutput.cs +++ b/test/Common/ArrayOutput.cs @@ -7,7 +7,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; -namespace Microsoft.AspNetCore.Sockets.Tests.Internal.Formatters +namespace Microsoft.AspNetCore.Sockets.Tests.Internal { internal class ArrayOutput : IOutput { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 047c228f70..6d703a7653 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -11,6 +11,7 @@ using Microsoft.Extensions.Logging; using System; using System.Threading.Tasks; using Xunit; +using System.Diagnostics; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { @@ -182,6 +183,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { var loggerFactory = new LoggerFactory(); loggerFactory.AddConsole(_verbose ? LogLevel.Trace : LogLevel.Error); + loggerFactory.AddDebug(LogLevel.Trace); return loggerFactory; } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj index 30a39f1465..9d0bca5f3d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj @@ -24,6 +24,7 @@ + diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs index a20512b6c4..a5c73f6b30 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs @@ -1,17 +1,21 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Net; -using System.Net.Http; -using System.Text; -using System.Threading; -using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Microsoft.AspNetCore.Sockets.Tests.Internal; using Microsoft.Extensions.Logging; using Moq; using Moq.Protected; +using System; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Text.Formatting; +using System.Threading; +using System.Threading.Tasks; using Xunit; -using Microsoft.AspNetCore.SignalR.Tests.Common; namespace Microsoft.AspNetCore.Sockets.Client.Tests { @@ -460,6 +464,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests [Fact] public async Task CanSendData() { + var data = new byte[] { 1, 1, 2, 3, 5, 8 }; + var message = new Message(data, MessageType.Binary); + var expectedPayload = FormatMessageToArray(message, MessageFormat.Binary); + var sendTcs = new TaskCompletionSource(); var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -476,17 +484,15 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); var connection = new Connection(new Uri("http://fakeuri.org/")); try { await connection.StartAsync(longPollingTransport, httpClient); - var data = new byte[] { 1, 1, 2, 3, 5, 8 }; await connection.SendAsync(data, MessageType.Binary); - Assert.Equal(data, await sendTcs.Task.OrTimeout()); + Assert.Equal(expectedPayload, await sendTcs.Task.OrTimeout()); } finally { @@ -659,5 +665,13 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests } } } + + private byte[] FormatMessageToArray(Message message, MessageFormat binary, int bufferSize = 1024) + { + var output = new ArrayOutput(bufferSize); + output.Append('B', TextEncoder.Utf8); + Assert.True(MessageFormatter.TryWriteMessage(message, output, binary)); + return output.ToArray(); + } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/LongPollingTransportTests.cs index 97dd13b463..2bc976126e 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/LongPollingTransportTests.cs @@ -2,17 +2,19 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Net; using System.Net.Http; +using System.Text; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; using Moq; using Moq.Protected; using Xunit; -using Microsoft.AspNetCore.Sockets.Internal; -using Microsoft.AspNetCore.SignalR.Tests.Common; namespace Microsoft.AspNetCore.Sockets.Client.Tests { @@ -201,5 +203,62 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests } } } + + [Fact] + public async Task LongPollingTransportSendsAvailableMessagesWhenTheyArrive() + { + var sentRequests = new List(); + + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (request.RequestUri.LocalPath.EndsWith("send")) + { + // Build a new request object, but convert the entire payload to string + sentRequests.Add(await request.Content.ReadAsByteArrayAsync()); + } + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + try + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + + var tcs1 = new TaskCompletionSource(); + var tcs2 = new TaskCompletionSource(); + + // Pre-queue some messages + await connectionToTransport.Out.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("Hello"), MessageType.Text, tcs1)).OrTimeout(); + await connectionToTransport.Out.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), MessageType.Binary, tcs2)).OrTimeout(); + + // Start the transport + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection); + + connectionToTransport.Out.Complete(); + + await longPollingTransport.Running.OrTimeout(); + await connectionToTransport.In.Completion.OrTimeout(); + + Assert.Equal(1, sentRequests.Count); + Assert.Equal(new byte[] { + (byte)'B', + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o', + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x01, (byte)'W', (byte)'o', (byte)'r', (byte)'l', (byte)'d' + }, sentRequests[0]); + } + finally + { + await longPollingTransport.StopAsync(); + } + } + } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/Microsoft.AspNetCore.Client.Tests.csproj b/test/Microsoft.AspNetCore.Sockets.Client.Tests/Microsoft.AspNetCore.Client.Tests.csproj index 4add80dd80..b5085e23fc 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/Microsoft.AspNetCore.Client.Tests.csproj +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/Microsoft.AspNetCore.Client.Tests.csproj @@ -12,6 +12,7 @@ + diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Microsoft.AspNetCore.Sockets.Common.Tests.csproj b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Microsoft.AspNetCore.Sockets.Common.Tests.csproj index b7bfa26804..2fef7c6ca7 100644 --- a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Microsoft.AspNetCore.Sockets.Common.Tests.csproj +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Microsoft.AspNetCore.Sockets.Common.Tests.csproj @@ -10,6 +10,10 @@ true + + + + diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index a8887aaf4c..8e12651d9a 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -4,11 +4,11 @@ using System; using System.Collections.Generic; using System.IO; -using System.IO.Pipelines; using System.Text; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Internal; +using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; @@ -19,6 +19,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests { public class HttpConnectionDispatcherTests { + // Redefined from MessageFormatter because we want constants to go in the Attributes + private const string TextContentType = "application/vnd.microsoft.aspnetcore.endpoint-messages.v1+text"; + private const string BinaryContentType = "application/vnd.microsoft.aspnetcore.endpoint-messages.v1+binary"; + [Fact] public async Task NegotiateReservesConnectionIdAndReturnsIt() { @@ -314,7 +318,73 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.False(exists); } - private static DefaultHttpContext MakeRequest(string path, ConnectionState state) where TEndPoint : EndPoint + [Theory] + [InlineData("", "text", "Hello, World", "Hello, World", MessageType.Text)] // Legacy format + [InlineData("", "binary", "Hello, World", "Hello, World", MessageType.Binary)] // Legacy format + [InlineData(TextContentType, null, "T12:T:Hello, World;", "Hello, World", MessageType.Text)] + [InlineData(TextContentType, null, "T16:B:SGVsbG8sIFdvcmxk;", "Hello, World", MessageType.Binary)] + [InlineData(TextContentType, null, "T12:E:Hello, World;", "Hello, World", MessageType.Error)] + [InlineData(TextContentType, null, "T12:C:Hello, World;", "Hello, World", MessageType.Close)] + [InlineData(BinaryContentType, null, "QgAAAAAAAAAMAEhlbGxvLCBXb3JsZA==", "Hello, World", MessageType.Text)] + [InlineData(BinaryContentType, null, "QgAAAAAAAAAMAUhlbGxvLCBXb3JsZA==", "Hello, World", MessageType.Binary)] + [InlineData(BinaryContentType, null, "QgAAAAAAAAAMAkhlbGxvLCBXb3JsZA==", "Hello, World", MessageType.Error)] + [InlineData(BinaryContentType, null, "QgAAAAAAAAAMA0hlbGxvLCBXb3JsZA==", "Hello, World", MessageType.Close)] + public async Task SendPutsPayloadsInTheChannel(string contentType, string format, string encoded, string payload, MessageType type) + { + var messages = await RunSendTest(contentType, encoded, format); + + Assert.Equal(1, messages.Count); + Assert.Equal(payload, Encoding.UTF8.GetString(messages[0].Payload)); + Assert.Equal(type, messages[0].Type); + } + + [Theory] + [InlineData(TextContentType, "T12:T:Hello, World;16:B:SGVsbG8sIFdvcmxk;5:E:Error;6:C:Closed;")] + [InlineData(BinaryContentType, "QgAAAAAAAAAMAEhlbGxvLCBXb3JsZAAAAAAAAAAMAUhlbGxvLCBXb3JsZAAAAAAAAAAFAkVycm9yAAAAAAAAAAYDQ2xvc2Vk")] + public async Task SendAllowsMultipleMessages(string contentType, string encoded) + { + var messages = await RunSendTest(contentType, encoded, format: null); + + Assert.Equal(4, messages.Count); + Assert.Equal("Hello, World", Encoding.UTF8.GetString(messages[0].Payload)); + Assert.Equal(MessageType.Text, messages[0].Type); + Assert.Equal("Hello, World", Encoding.UTF8.GetString(messages[1].Payload)); + Assert.Equal(MessageType.Binary, messages[1].Type); + Assert.Equal("Error", Encoding.UTF8.GetString(messages[2].Payload)); + Assert.Equal(MessageType.Error, messages[2].Type); + Assert.Equal("Closed", Encoding.UTF8.GetString(messages[3].Payload)); + Assert.Equal(MessageType.Close, messages[3].Type); + } + + private static async Task> RunSendTest(string contentType, string encoded, string format) + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + + var context = MakeRequest("/send", state, format); + context.Request.ContentType = contentType; + var endPoint = context.RequestServices.GetRequiredService(); + + var buffer = contentType == BinaryContentType ? + Convert.FromBase64String(encoded) : + Encoding.UTF8.GetBytes(encoded); + var messages = new List(); + using (context.Request.Body = new MemoryStream(buffer, writable: false)) + { + await dispatcher.ExecuteAsync("", context).OrTimeout(); + } + + while (state.Connection.Transport.Input.TryRead(out var message)) + { + messages.Add(message); + } + + return messages; + } + + private static DefaultHttpContext MakeRequest(string path, ConnectionState state, string format = null) where TEndPoint : EndPoint { var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -323,6 +393,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests context.Request.Path = path; var values = new Dictionary(); values["id"] = state.Connection.ConnectionId; + if (format != null) + { + values["format"] = format; + } var qs = new QueryCollection(values); context.Request.Query = qs; return context;