use new protocol for '/send' (#297)
This commit is contained in:
parent
cd246adb6f
commit
0133153bc9
|
|
@ -1,18 +1,23 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.IO.Pipelines.Text.Primitives;
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Net.Http.Headers;
|
||||
using System.Text;
|
||||
using System.Text.Formatting;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.Extensions.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Net.Http.Headers;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Client
|
||||
{
|
||||
|
|
@ -171,43 +176,80 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken)
|
||||
{
|
||||
_logger.LogInformation("Starting the send loop");
|
||||
|
||||
TaskCompletionSource<object> sendTcs = null;
|
||||
IList<SendMessage> messages = null;
|
||||
try
|
||||
{
|
||||
while (await _application.Input.WaitToReadAsync(cancellationToken))
|
||||
{
|
||||
// Grab as many messages as we can from the channel
|
||||
messages = new List<SendMessage>();
|
||||
while (!cancellationToken.IsCancellationRequested && _application.Input.TryRead(out SendMessage message))
|
||||
{
|
||||
sendTcs = message.SendResult;
|
||||
messages.Add(message);
|
||||
}
|
||||
|
||||
if (messages.Count > 0)
|
||||
{
|
||||
_logger.LogDebug("Sending {0} message(s) to the server using url: {1}", messages.Count, sendUrl);
|
||||
|
||||
// Send them in a single post
|
||||
var request = new HttpRequestMessage(HttpMethod.Post, sendUrl);
|
||||
request.Headers.UserAgent.Add(DefaultUserAgentHeader);
|
||||
|
||||
if (message.Payload != null && message.Payload.Length > 0)
|
||||
{
|
||||
request.Content = new ByteArrayContent(message.Payload);
|
||||
}
|
||||
// TODO: We can probably use a pipeline here or some kind of pooled memory.
|
||||
// But where do we get the pool from? ArrayBufferPool.Instance?
|
||||
var memoryStream = new MemoryStream();
|
||||
|
||||
_logger.LogDebug("Sending a message to the server using url: '{0}'. Message type {1}", sendUrl, message.Type);
|
||||
// Write the messages to the stream
|
||||
var pipe = memoryStream.AsPipelineWriter();
|
||||
var output = new PipelineTextOutput(pipe, TextEncoder.Utf8); // We don't need the Encoder, but it's harmless to set.
|
||||
await WriteMessagesAsync(messages, output, MessageFormat.Binary);
|
||||
|
||||
// Seek back to the start
|
||||
memoryStream.Seek(0, SeekOrigin.Begin);
|
||||
|
||||
// Set the, now filled, stream as the content
|
||||
request.Content = new StreamContent(memoryStream);
|
||||
request.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(MessageFormatter.GetContentType(MessageFormat.Binary));
|
||||
|
||||
var response = await _httpClient.SendAsync(request);
|
||||
response.EnsureSuccessStatusCode();
|
||||
|
||||
_logger.LogDebug("Message sent successfully");
|
||||
|
||||
sendTcs.SetResult(null);
|
||||
_logger.LogDebug("Message(s) sent successfully");
|
||||
foreach (var message in messages)
|
||||
{
|
||||
message.SendResult?.TrySetResult(null);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
_logger.LogDebug("No messages in batch to send");
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
// transport is being closed
|
||||
sendTcs?.TrySetCanceled();
|
||||
if (messages != null)
|
||||
{
|
||||
foreach (var message in messages)
|
||||
{
|
||||
// This will no-op for any messages that were already marked as completed.
|
||||
message.SendResult?.TrySetCanceled();
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError("Error while sending to '{0}': {1}", sendUrl, ex);
|
||||
sendTcs?.TrySetException(ex);
|
||||
if (messages != null)
|
||||
{
|
||||
foreach (var message in messages)
|
||||
{
|
||||
// This will no-op for any messages that were already marked as completed.
|
||||
message.SendResult?.TrySetException(ex);
|
||||
}
|
||||
}
|
||||
throw;
|
||||
}
|
||||
finally
|
||||
|
|
@ -218,5 +260,23 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
|
||||
_logger.LogInformation("Send loop stopped");
|
||||
}
|
||||
|
||||
private async Task WriteMessagesAsync(IList<SendMessage> messages, PipelineTextOutput output, MessageFormat format)
|
||||
{
|
||||
output.Append(MessageFormatter.GetFormatIndicator(format), TextEncoder.Utf8);
|
||||
|
||||
foreach (var message in messages)
|
||||
{
|
||||
_logger.LogDebug("Writing '{0}' message to the server", message.Type);
|
||||
|
||||
var payload = message.Payload ?? Array.Empty<byte>();
|
||||
if (!MessageFormatter.TryWriteMessage(new Message(payload, message.Type, endOfMessage: true), output, format))
|
||||
{
|
||||
// We didn't get any more memory!
|
||||
throw new InvalidOperationException("Unable to write message to pipeline");
|
||||
}
|
||||
await output.FlushAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@
|
|||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="System.Text.Formatting" Version="$(CoreFxLabsVersion)" />
|
||||
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
|
||||
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
|
||||
<PackageReference Include="System.Net.WebSockets.Client" Version="$(CoreFxVersion)" />
|
||||
<PackageReference Include="System.Threading.Tasks.Channels" Version="$(CoreFxLabsVersion)" />
|
||||
</ItemGroup>
|
||||
|
|
|
|||
|
|
@ -23,6 +23,11 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
public Message(byte[] payload, MessageType type, bool endOfMessage)
|
||||
{
|
||||
if (payload == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(payload));
|
||||
}
|
||||
|
||||
Type = type;
|
||||
EndOfMessage = endOfMessage;
|
||||
Payload = payload;
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Text;
|
||||
|
|
@ -9,6 +11,7 @@ using System.Threading;
|
|||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.AspNetCore.Sockets.Transports;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
|
@ -286,7 +289,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
private async Task ExecuteApplication(EndPoint endpoint, Connection connection)
|
||||
{
|
||||
// Jump onto the thread pool thread so blocking user code doesn't block the setup of the
|
||||
// Jump onto the thread pool thread so blocking user code doesn't block the setup of the
|
||||
// connection and transport
|
||||
await AwaitableThreadPool.Yield();
|
||||
|
||||
|
|
@ -316,31 +319,52 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
return;
|
||||
}
|
||||
|
||||
// Collect the message and write it to the channel
|
||||
// TODO: Need to use some kind of pooled memory here.
|
||||
// Read the entire payload to a byte array for now because Pipelines and ReadOnlyBytes
|
||||
// don't play well with each other yet.
|
||||
byte[] buffer;
|
||||
using (var stream = new MemoryStream())
|
||||
{
|
||||
await context.Request.Body.CopyToAsync(stream);
|
||||
await stream.FlushAsync();
|
||||
buffer = stream.ToArray();
|
||||
}
|
||||
|
||||
var format =
|
||||
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
|
||||
? MessageType.Binary
|
||||
: MessageType.Text;
|
||||
IList<Message> messages;
|
||||
if (string.Equals(context.Request.ContentType, MessageFormatter.TextContentType, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
var reader = new BytesReader(buffer);
|
||||
messages = ParseSendBatch(ref reader, MessageFormat.Text);
|
||||
}
|
||||
else if (string.Equals(context.Request.ContentType, MessageFormatter.BinaryContentType, StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
var reader = new BytesReader(buffer);
|
||||
messages = ParseSendBatch(ref reader, MessageFormat.Binary);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Legacy, single message raw format
|
||||
|
||||
var format =
|
||||
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
|
||||
? MessageType.Binary
|
||||
: MessageType.Text;
|
||||
messages = new List<Message>()
|
||||
{
|
||||
new Message(buffer, format, endOfMessage: true)
|
||||
};
|
||||
}
|
||||
|
||||
var message = new Message(
|
||||
buffer,
|
||||
format,
|
||||
endOfMessage: true);
|
||||
|
||||
// REVIEW: Do we want to return a specific status code here if the connection has ended?
|
||||
while (await state.Application.Output.WaitToWriteAsync())
|
||||
_logger.LogDebug("Received batch of {0} message(s) in '/send'", messages.Count);
|
||||
foreach (var message in messages)
|
||||
{
|
||||
if (state.Application.Output.TryWrite(message))
|
||||
while (!state.Application.Output.TryWrite(message))
|
||||
{
|
||||
break;
|
||||
if (!await state.Application.Output.WaitToWriteAsync())
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -407,5 +431,30 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
return connectionState;
|
||||
}
|
||||
|
||||
private IList<Message> ParseSendBatch(ref BytesReader payload, MessageFormat messageFormat)
|
||||
{
|
||||
var messages = new List<Message>();
|
||||
|
||||
if (payload.Unread.Length == 0)
|
||||
{
|
||||
return messages;
|
||||
}
|
||||
|
||||
if (payload.Unread[0] != MessageFormatter.GetFormatIndicator(messageFormat))
|
||||
{
|
||||
throw new FormatException($"Format indicator '{(char)payload.Unread[0]}' does not match format determined by Content-Type '{MessageFormatter.GetContentType(messageFormat)}'");
|
||||
}
|
||||
|
||||
payload.Advance(1);
|
||||
|
||||
// REVIEW: This needs a little work. We could probably new up exactly the right parser, if we tinkered with the inheritance hierarchy a bit.
|
||||
var parser = new MessageParser();
|
||||
while (parser.TryParseMessage(ref payload, messageFormat, out var message))
|
||||
{
|
||||
messages.Add(message);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using System;
|
||||
using System.IO.Pipelines;
|
||||
using System.IO.Pipelines.Text.Primitives;
|
||||
|
|
@ -12,14 +9,14 @@ using System.Text.Formatting;
|
|||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Transports
|
||||
{
|
||||
public class LongPollingTransport : IHttpTransport
|
||||
{
|
||||
// REVIEW: This size?
|
||||
internal const int MaxBufferSize = 4096;
|
||||
|
||||
public static readonly string Name = "longPolling";
|
||||
private readonly ReadableChannel<Message> _application;
|
||||
private readonly ILogger _logger;
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ using System.Collections.Generic;
|
|||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Tests.Internal.Formatters
|
||||
namespace Microsoft.AspNetCore.Sockets.Tests.Internal
|
||||
{
|
||||
internal class ArrayOutput : IOutput
|
||||
{
|
||||
|
|
@ -11,6 +11,7 @@ using Microsoft.Extensions.Logging;
|
|||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using Xunit;
|
||||
using System.Diagnostics;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
||||
{
|
||||
|
|
@ -182,6 +183,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
{
|
||||
var loggerFactory = new LoggerFactory();
|
||||
loggerFactory.AddConsole(_verbose ? LogLevel.Trace : LogLevel.Error);
|
||||
loggerFactory.AddDebug(LogLevel.Trace);
|
||||
|
||||
return loggerFactory;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.TestHost" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(TestSdkVersion)" />
|
||||
<PackageReference Include="xunit.runner.visualstudio" Version="$(XunitVersion)" />
|
||||
<PackageReference Include="xunit" Version="$(XunitVersion)" />
|
||||
|
|
|
|||
|
|
@ -1,17 +1,21 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.AspNetCore.Sockets.Tests.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Moq;
|
||||
using Moq.Protected;
|
||||
using System;
|
||||
using System.Linq;
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Text;
|
||||
using System.Text.Formatting;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Xunit;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Client.Tests
|
||||
{
|
||||
|
|
@ -460,6 +464,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
|
|||
[Fact]
|
||||
public async Task CanSendData()
|
||||
{
|
||||
var data = new byte[] { 1, 1, 2, 3, 5, 8 };
|
||||
var message = new Message(data, MessageType.Binary);
|
||||
var expectedPayload = FormatMessageToArray(message, MessageFormat.Binary);
|
||||
|
||||
var sendTcs = new TaskCompletionSource<byte[]>();
|
||||
var mockHttpHandler = new Mock<HttpMessageHandler>();
|
||||
mockHttpHandler.Protected()
|
||||
|
|
@ -476,17 +484,15 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
|
|||
|
||||
using (var httpClient = new HttpClient(mockHttpHandler.Object))
|
||||
{
|
||||
|
||||
var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory());
|
||||
var connection = new Connection(new Uri("http://fakeuri.org/"));
|
||||
try
|
||||
{
|
||||
await connection.StartAsync(longPollingTransport, httpClient);
|
||||
|
||||
var data = new byte[] { 1, 1, 2, 3, 5, 8 };
|
||||
await connection.SendAsync(data, MessageType.Binary);
|
||||
|
||||
Assert.Equal(data, await sendTcs.Task.OrTimeout());
|
||||
Assert.Equal(expectedPayload, await sendTcs.Task.OrTimeout());
|
||||
}
|
||||
finally
|
||||
{
|
||||
|
|
@ -659,5 +665,13 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] FormatMessageToArray(Message message, MessageFormat binary, int bufferSize = 1024)
|
||||
{
|
||||
var output = new ArrayOutput(bufferSize);
|
||||
output.Append('B', TextEncoder.Utf8);
|
||||
Assert.True(MessageFormatter.TryWriteMessage(message, output, binary));
|
||||
return output.ToArray();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,17 +2,19 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Moq;
|
||||
using Moq.Protected;
|
||||
using Xunit;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Client.Tests
|
||||
{
|
||||
|
|
@ -201,5 +203,62 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task LongPollingTransportSendsAvailableMessagesWhenTheyArrive()
|
||||
{
|
||||
var sentRequests = new List<byte[]>();
|
||||
|
||||
var mockHttpHandler = new Mock<HttpMessageHandler>();
|
||||
mockHttpHandler.Protected()
|
||||
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
|
||||
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
|
||||
{
|
||||
await Task.Yield();
|
||||
if (request.RequestUri.LocalPath.EndsWith("send"))
|
||||
{
|
||||
// Build a new request object, but convert the entire payload to string
|
||||
sentRequests.Add(await request.Content.ReadAsByteArrayAsync());
|
||||
}
|
||||
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
|
||||
});
|
||||
|
||||
using (var httpClient = new HttpClient(mockHttpHandler.Object))
|
||||
{
|
||||
var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory());
|
||||
try
|
||||
{
|
||||
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
|
||||
var transportToConnection = Channel.CreateUnbounded<Message>();
|
||||
var channelConnection = new ChannelConnection<SendMessage, Message>(connectionToTransport, transportToConnection);
|
||||
|
||||
var tcs1 = new TaskCompletionSource<object>();
|
||||
var tcs2 = new TaskCompletionSource<object>();
|
||||
|
||||
// Pre-queue some messages
|
||||
await connectionToTransport.Out.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("Hello"), MessageType.Text, tcs1)).OrTimeout();
|
||||
await connectionToTransport.Out.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), MessageType.Binary, tcs2)).OrTimeout();
|
||||
|
||||
// Start the transport
|
||||
await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection);
|
||||
|
||||
connectionToTransport.Out.Complete();
|
||||
|
||||
await longPollingTransport.Running.OrTimeout();
|
||||
await connectionToTransport.In.Completion.OrTimeout();
|
||||
|
||||
Assert.Equal(1, sentRequests.Count);
|
||||
Assert.Equal(new byte[] {
|
||||
(byte)'B',
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o',
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x01, (byte)'W', (byte)'o', (byte)'r', (byte)'l', (byte)'d'
|
||||
}, sentRequests[0]);
|
||||
}
|
||||
finally
|
||||
{
|
||||
await longPollingTransport.StopAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
<ItemGroup>
|
||||
<Compile Include="..\Common\TaskExtensions.cs" Link="TaskExtensions.cs" />
|
||||
<Compile Include="..\Common\ArrayOutput.cs" Link="ArrayOutput.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@
|
|||
<GenerateBindingRedirectsOutputType>true</GenerateBindingRedirectsOutputType>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<Compile Include="..\Common\ArrayOutput.cs" Link="ArrayOutput.cs" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
|
||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(TestSdkVersion)" />
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Http.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
|
@ -19,6 +19,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
{
|
||||
public class HttpConnectionDispatcherTests
|
||||
{
|
||||
// Redefined from MessageFormatter because we want constants to go in the Attributes
|
||||
private const string TextContentType = "application/vnd.microsoft.aspnetcore.endpoint-messages.v1+text";
|
||||
private const string BinaryContentType = "application/vnd.microsoft.aspnetcore.endpoint-messages.v1+binary";
|
||||
|
||||
[Fact]
|
||||
public async Task NegotiateReservesConnectionIdAndReturnsIt()
|
||||
{
|
||||
|
|
@ -314,7 +318,73 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
Assert.False(exists);
|
||||
}
|
||||
|
||||
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state) where TEndPoint : EndPoint
|
||||
[Theory]
|
||||
[InlineData("", "text", "Hello, World", "Hello, World", MessageType.Text)] // Legacy format
|
||||
[InlineData("", "binary", "Hello, World", "Hello, World", MessageType.Binary)] // Legacy format
|
||||
[InlineData(TextContentType, null, "T12:T:Hello, World;", "Hello, World", MessageType.Text)]
|
||||
[InlineData(TextContentType, null, "T16:B:SGVsbG8sIFdvcmxk;", "Hello, World", MessageType.Binary)]
|
||||
[InlineData(TextContentType, null, "T12:E:Hello, World;", "Hello, World", MessageType.Error)]
|
||||
[InlineData(TextContentType, null, "T12:C:Hello, World;", "Hello, World", MessageType.Close)]
|
||||
[InlineData(BinaryContentType, null, "QgAAAAAAAAAMAEhlbGxvLCBXb3JsZA==", "Hello, World", MessageType.Text)]
|
||||
[InlineData(BinaryContentType, null, "QgAAAAAAAAAMAUhlbGxvLCBXb3JsZA==", "Hello, World", MessageType.Binary)]
|
||||
[InlineData(BinaryContentType, null, "QgAAAAAAAAAMAkhlbGxvLCBXb3JsZA==", "Hello, World", MessageType.Error)]
|
||||
[InlineData(BinaryContentType, null, "QgAAAAAAAAAMA0hlbGxvLCBXb3JsZA==", "Hello, World", MessageType.Close)]
|
||||
public async Task SendPutsPayloadsInTheChannel(string contentType, string format, string encoded, string payload, MessageType type)
|
||||
{
|
||||
var messages = await RunSendTest(contentType, encoded, format);
|
||||
|
||||
Assert.Equal(1, messages.Count);
|
||||
Assert.Equal(payload, Encoding.UTF8.GetString(messages[0].Payload));
|
||||
Assert.Equal(type, messages[0].Type);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(TextContentType, "T12:T:Hello, World;16:B:SGVsbG8sIFdvcmxk;5:E:Error;6:C:Closed;")]
|
||||
[InlineData(BinaryContentType, "QgAAAAAAAAAMAEhlbGxvLCBXb3JsZAAAAAAAAAAMAUhlbGxvLCBXb3JsZAAAAAAAAAAFAkVycm9yAAAAAAAAAAYDQ2xvc2Vk")]
|
||||
public async Task SendAllowsMultipleMessages(string contentType, string encoded)
|
||||
{
|
||||
var messages = await RunSendTest(contentType, encoded, format: null);
|
||||
|
||||
Assert.Equal(4, messages.Count);
|
||||
Assert.Equal("Hello, World", Encoding.UTF8.GetString(messages[0].Payload));
|
||||
Assert.Equal(MessageType.Text, messages[0].Type);
|
||||
Assert.Equal("Hello, World", Encoding.UTF8.GetString(messages[1].Payload));
|
||||
Assert.Equal(MessageType.Binary, messages[1].Type);
|
||||
Assert.Equal("Error", Encoding.UTF8.GetString(messages[2].Payload));
|
||||
Assert.Equal(MessageType.Error, messages[2].Type);
|
||||
Assert.Equal("Closed", Encoding.UTF8.GetString(messages[3].Payload));
|
||||
Assert.Equal(MessageType.Close, messages[3].Type);
|
||||
}
|
||||
|
||||
private static async Task<List<Message>> RunSendTest(string contentType, string encoded, string format)
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
||||
var context = MakeRequest<TestEndPoint>("/send", state, format);
|
||||
context.Request.ContentType = contentType;
|
||||
var endPoint = context.RequestServices.GetRequiredService<TestEndPoint>();
|
||||
|
||||
var buffer = contentType == BinaryContentType ?
|
||||
Convert.FromBase64String(encoded) :
|
||||
Encoding.UTF8.GetBytes(encoded);
|
||||
var messages = new List<Message>();
|
||||
using (context.Request.Body = new MemoryStream(buffer, writable: false))
|
||||
{
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>("", context).OrTimeout();
|
||||
}
|
||||
|
||||
while (state.Connection.Transport.Input.TryRead(out var message))
|
||||
{
|
||||
messages.Add(message);
|
||||
}
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
private static DefaultHttpContext MakeRequest<TEndPoint>(string path, ConnectionState state, string format = null) where TEndPoint : EndPoint
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
|
|
@ -323,6 +393,10 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
context.Request.Path = path;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
if (format != null)
|
||||
{
|
||||
values["format"] = format;
|
||||
}
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
return context;
|
||||
|
|
|
|||
Loading…
Reference in New Issue