From a1c0970222b3863283a9442f41345148541a0cf2 Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Mon, 3 Oct 2016 15:47:20 -0700 Subject: [PATCH] move prototype WebSockets-over-Channels code in nothing is wired up to anything in Sockets yet, it's just a copy of the code --- .gitignore | 2 + Microsoft.AspNetCore.Sockets.sln | 14 + .../IWebSocketConnection.cs | 82 ++++ .../Internal/ChannelExtensions.cs | 52 ++ .../MaskingUtilities.cs | 52 ++ .../Microsoft.Extensions.WebSockets.xproj | 21 + .../Properties/AssemblyInfo.cs | 12 + .../WebSocketCloseResult.cs | 71 +++ .../WebSocketCloseStatus.cs | 74 +++ .../WebSocketConnection.cs | 456 ++++++++++++++++++ .../WebSocketConnectionState.cs | 16 + .../WebSocketException.cs | 19 + .../WebSocketFrame.cs | 34 ++ .../WebSocketOpcode.cs | 42 ++ .../project.json | 34 ++ .../Internal/WebSocketPair.cs | 50 ++ ...icrosoft.Extensions.WebSockets.Tests.xproj | 21 + .../Properties/AssemblyInfo.cs | 19 + .../WebSocketConnectionExtensions.cs | 24 + .../WebSocketConnectionSummary.cs | 16 + ...cketConnectionTests.ConnectionLifecycle.cs | 133 +++++ .../WebSocketConnectionTests.ReceiveAsync.cs | 213 ++++++++ .../WebSocketConnectionTests.SendAsync.cs | 225 +++++++++ .../project.json | 30 ++ 24 files changed, 1712 insertions(+) create mode 100644 src/Microsoft.Extensions.WebSockets/IWebSocketConnection.cs create mode 100644 src/Microsoft.Extensions.WebSockets/Internal/ChannelExtensions.cs create mode 100644 src/Microsoft.Extensions.WebSockets/MaskingUtilities.cs create mode 100644 src/Microsoft.Extensions.WebSockets/Microsoft.Extensions.WebSockets.xproj create mode 100644 src/Microsoft.Extensions.WebSockets/Properties/AssemblyInfo.cs create mode 100644 src/Microsoft.Extensions.WebSockets/WebSocketCloseResult.cs create mode 100644 src/Microsoft.Extensions.WebSockets/WebSocketCloseStatus.cs create mode 100644 src/Microsoft.Extensions.WebSockets/WebSocketConnection.cs create mode 100644 src/Microsoft.Extensions.WebSockets/WebSocketConnectionState.cs create mode 100644 src/Microsoft.Extensions.WebSockets/WebSocketException.cs create mode 100644 src/Microsoft.Extensions.WebSockets/WebSocketFrame.cs create mode 100644 src/Microsoft.Extensions.WebSockets/WebSocketOpcode.cs create mode 100644 src/Microsoft.Extensions.WebSockets/project.json create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/Internal/WebSocketPair.cs create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/Microsoft.Extensions.WebSockets.Tests.xproj create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/Properties/AssemblyInfo.cs create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionExtensions.cs create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionSummary.cs create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.ReceiveAsync.cs create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.SendAsync.cs create mode 100644 test/Microsoft.Extensions.WebSockets.Tests/project.json diff --git a/.gitignore b/.gitignore index 9c43d9d9a3..d5c2d3b074 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,5 @@ runtimes/ .testPublish/ launchSettings.json *.tmp +*.nuget.props +*.nuget.targets \ No newline at end of file diff --git a/Microsoft.AspNetCore.Sockets.sln b/Microsoft.AspNetCore.Sockets.sln index 61f6f31626..e8d27466ec 100644 --- a/Microsoft.AspNetCore.Sockets.sln +++ b/Microsoft.AspNetCore.Sockets.sln @@ -23,6 +23,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{6A35B453-5 EndProject Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.AspNetCore.Sockets.Tests", "test\Microsoft.AspNetCore.Sockets.Tests\Microsoft.AspNetCore.Sockets.Tests.xproj", "{AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9}" EndProject +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.Extensions.WebSockets", "src\Microsoft.Extensions.WebSockets\Microsoft.Extensions.WebSockets.xproj", "{5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}" +EndProject +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.Extensions.WebSockets.Tests", "test\Microsoft.Extensions.WebSockets.Tests\Microsoft.Extensions.WebSockets.Tests.xproj", "{8FA6BE8F-B5EB-42F9-9B16-101917CC45E2}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -45,6 +49,14 @@ Global {AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9}.Debug|Any CPU.Build.0 = Debug|Any CPU {AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9}.Release|Any CPU.ActiveCfg = Release|Any CPU {AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9}.Release|Any CPU.Build.0 = Release|Any CPU + {5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}.Release|Any CPU.Build.0 = Release|Any CPU + {8FA6BE8F-B5EB-42F9-9B16-101917CC45E2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8FA6BE8F-B5EB-42F9-9B16-101917CC45E2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8FA6BE8F-B5EB-42F9-9B16-101917CC45E2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8FA6BE8F-B5EB-42F9-9B16-101917CC45E2}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -54,5 +66,7 @@ Global {1715EA8D-8E13-4ACF-8BCA-57D048E55ED8} = {DA69F624-5398-4884-87E4-B816698CDE65} {BA99C2A1-48F9-4FA5-B95A-9687A73B7CC9} = {C4BC9889-B49F-41B6-806B-F84941B2549B} {AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9} = {6A35B453-52EC-48AF-89CA-D4A69800F131} + {5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775} = {DA69F624-5398-4884-87E4-B816698CDE65} + {8FA6BE8F-B5EB-42F9-9B16-101917CC45E2} = {6A35B453-52EC-48AF-89CA-D4A69800F131} EndGlobalSection EndGlobal diff --git a/src/Microsoft.Extensions.WebSockets/IWebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets/IWebSocketConnection.cs new file mode 100644 index 0000000000..fa53a74d49 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/IWebSocketConnection.cs @@ -0,0 +1,82 @@ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.WebSockets +{ + /// + /// Represents a connection to a WebSocket endpoint. + /// + /// + /// + /// Implementors of this type are generally considered thread-safe under the following condition: No two threads attempt to call either + /// or simultaneously. Different threads may call each method, but the same method + /// cannot be re-entered while it is being run in a different thread. However, ensure you verify that the specific implementor is + /// thread-safe in this way. For example, (including the implementations returned by the + /// static factory methods on that type) is thread-safe in this way. + /// + /// + /// The general pattern of having a single thread running and a separate thread running will + /// be thread-safe, as each method interacts with completely separate state. + /// + /// + public interface IWebSocketConnection : IDisposable + { + WebSocketConnectionState State { get; } + + /// + /// Sends the specified frame. + /// + /// The message to send. + /// A that indicates when/if the send is cancelled. + /// A that completes when the message has been written to the outbound stream. + Task SendAsync(WebSocketFrame message, CancellationToken cancellationToken); + + /// + /// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame. + /// + /// + /// If the other party does not respond with a close frame, the connection will remain open and the + /// will remain active. Call the method on this instance to forcibly terminate the connection. + /// + /// A with the payload for the close frame + /// A that indicates when/if the send is cancelled. + /// A that completes when the close frame has been sent + Task CloseAsync(WebSocketCloseResult result, CancellationToken cancellationToken); + + /// + /// Runs the WebSocket receive loop, using the provided message handler. + /// + /// The callback that will be invoked for each new frame + /// A that will complete when the client has sent a close frame, or the connection has been terminated + Task ExecuteAsync(Func messageHandler); + } + + public static class WebSocketConnectionExtensions + { + /// + /// Sends the specified frame. + /// + /// The message to send. + /// A that completes when the message has been written to the outbound stream. + public static Task SendAsync(this IWebSocketConnection self, WebSocketFrame message) => self.SendAsync(message, CancellationToken.None); + + /// + /// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame. + /// + /// A with the payload for the close frame + /// A that completes when the close frame has been sent + public static Task CloseAsync(this IWebSocketConnection self, WebSocketCloseResult result) => self.CloseAsync(result, CancellationToken.None); + + /// + /// Runs the WebSocket receive loop, using the provided message handler. + /// + /// The callback that will be invoked for each new frame + /// A that will complete when the client has sent a close frame, or the connection has been terminated + public static Task ExecuteAsync(this IWebSocketConnection self, Action messageHandler) => + self.ExecuteAsync(frame => { + messageHandler(frame); + return Task.CompletedTask; + }); + } +} diff --git a/src/Microsoft.Extensions.WebSockets/Internal/ChannelExtensions.cs b/src/Microsoft.Extensions.WebSockets/Internal/ChannelExtensions.cs new file mode 100644 index 0000000000..41c5790955 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/Internal/ChannelExtensions.cs @@ -0,0 +1,52 @@ +using System; +using System.Buffers; +using System.Threading; +using System.Threading.Tasks; +using Channels; + +namespace Microsoft.Extensions.WebSockets.Internal +{ + public static class ChannelExtensions + { + public static ValueTask ReadAtLeastAsync(this IReadableChannel input, int minimumRequiredBytes) => ReadAtLeastAsync(input, minimumRequiredBytes, CancellationToken.None); + + // TODO: Pull this up to Channels. We should be able to do it there without allocating a Task in any case (rather than here where we can avoid allocation + // only if the buffer is already ready and has enough data. + public static ValueTask ReadAtLeastAsync(this IReadableChannel input, int minimumRequiredBytes, CancellationToken cancellationToken) + { + var awaiter = input.ReadAsync(/* cancellationToken */); + + // Short-cut path! + if (awaiter.IsCompleted) + { + // We have a buffer, is it big enough? + var result = awaiter.GetResult(); + + if (result.IsCompleted || result.Buffer.Length >= minimumRequiredBytes) + { + return new ValueTask(result); + } + + // Buffer wasn't big enough, mark it as examined and continue to the "slow" path below + input.Advance( + consumed: result.Buffer.Start, + examined: result.Buffer.End); + } + return new ValueTask(ReadAtLeastSlowAsync(awaiter, input, minimumRequiredBytes, cancellationToken)); + } + + private static async Task ReadAtLeastSlowAsync(ReadableChannelAwaitable awaitable, IReadableChannel input, int minimumRequiredBytes, CancellationToken cancellationToken) + { + var result = await awaitable; + while (!result.IsCompleted && result.Buffer.Length < minimumRequiredBytes) + { + cancellationToken.ThrowIfCancellationRequested(); + input.Advance( + consumed: result.Buffer.Start, + examined: result.Buffer.End); + result = await input.ReadAsync(/* cancelToken */); + } + return result; + } + } +} diff --git a/src/Microsoft.Extensions.WebSockets/MaskingUtilities.cs b/src/Microsoft.Extensions.WebSockets/MaskingUtilities.cs new file mode 100644 index 0000000000..83db09c3dc --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/MaskingUtilities.cs @@ -0,0 +1,52 @@ +using System; +using System.Binary; +using Channels; + +namespace Microsoft.Extensions.WebSockets +{ + internal static class MaskingUtilities + { + // Plenty of optimization to be done here but not our immediate priority right now. + // Including: Vectorization, striding by uints (even when not vectorized; we'd probably flip the + // overload that does the implementation in that case and do it in the uint version). + + public static void ApplyMask(ref ReadableBuffer payload, uint maskingKey) + { + unsafe + { + // Write the masking key as bytes to simplify access. Use a stackalloc buffer because it's fixed-size + var maskingKeyBytes = stackalloc byte[4]; + var maskingKeySpan = new Span(maskingKeyBytes, 4); + maskingKeySpan.WriteBigEndian(maskingKey); + + ApplyMask(ref payload, maskingKeySpan); + } + } + + public static void ApplyMask(ref ReadableBuffer payload, Span maskingKey) + { + var offset = 0; + foreach (var mem in payload) + { + var span = mem.Span; + ApplyMask(span, maskingKey, ref offset); + offset += span.Length; + } + } + + public static void ApplyMask(Span payload, Span maskingKey) + { + var i = 0; + ApplyMask(payload, maskingKey, ref i); + } + + private static void ApplyMask(Span payload, Span maskingKey, ref int maskingKeyOffset) + { + for (int i = 0; i < payload.Length; i++) + { + payload[i] = (byte)(payload[i] ^ maskingKey[maskingKeyOffset % 4]); + maskingKeyOffset++; + } + } + } +} diff --git a/src/Microsoft.Extensions.WebSockets/Microsoft.Extensions.WebSockets.xproj b/src/Microsoft.Extensions.WebSockets/Microsoft.Extensions.WebSockets.xproj new file mode 100644 index 0000000000..f2985d72e2 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/Microsoft.Extensions.WebSockets.xproj @@ -0,0 +1,21 @@ + + + + 14.0 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + + 5d9da986-2eab-4c6d-bf15-9a4bdd4de775 + Microsoft.Extensions.WebSockets + .\obj + .\bin\ + v4.6.1 + + + + 2.0 + + + diff --git a/src/Microsoft.Extensions.WebSockets/Properties/AssemblyInfo.cs b/src/Microsoft.Extensions.WebSockets/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..c76663bf35 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/Properties/AssemblyInfo.cs @@ -0,0 +1,12 @@ +using System.Reflection; +using System.Runtime.CompilerServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.Extensions.WebSockets")] +[assembly: AssemblyTrademark("")] + +[assembly: InternalsVisibleTo("Microsoft.Extensions.WebSockets.Tests")] \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets/WebSocketCloseResult.cs b/src/Microsoft.Extensions.WebSockets/WebSocketCloseResult.cs new file mode 100644 index 0000000000..6aa21bca3b --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/WebSocketCloseResult.cs @@ -0,0 +1,71 @@ +using System.Binary; +using System.Text; +using Channels; +using Channels.Text.Primitives; + +namespace Microsoft.Extensions.WebSockets +{ + /// + /// Represents the payload of a Close frame (i.e. a with an of ). + /// + public struct WebSocketCloseResult + { + internal static WebSocketCloseResult AbnormalClosure = new WebSocketCloseResult(WebSocketCloseStatus.AbnormalClosure, "Underlying transport connection was terminated"); + internal static WebSocketCloseResult Empty = new WebSocketCloseResult(WebSocketCloseStatus.Empty); + + /// + /// Gets the close status code specified in the frame. + /// + public WebSocketCloseStatus Status { get; } + + /// + /// Gets the close status description specified in the frame. + /// + public string Description { get; } + + public WebSocketCloseResult(WebSocketCloseStatus status) : this(status, string.Empty) { } + public WebSocketCloseResult(WebSocketCloseStatus status, string description) + { + Status = status; + Description = description; + } + + public int GetSize() => Encoding.UTF8.GetByteCount(Description) + sizeof(ushort); + + public static bool TryParse(ReadableBuffer payload, out WebSocketCloseResult result) + { + if(payload.Length == 0) + { + // Empty payload is OK + result = new WebSocketCloseResult(WebSocketCloseStatus.Empty, string.Empty); + return true; + } + else if(payload.Length < 2) + { + result = default(WebSocketCloseResult); + return false; + } + else + { + var status = payload.ReadBigEndian(); + var description = string.Empty; + payload = payload.Slice(2); + if(payload.Length > 0) + { + description = payload.GetUtf8String(); + } + result = new WebSocketCloseResult((WebSocketCloseStatus)status, description); + return true; + } + } + + public void WriteTo(ref WritableBuffer buffer) + { + buffer.WriteBigEndian((ushort)Status); + if (!string.IsNullOrEmpty(Description)) + { + buffer.WriteUtf8String(Description); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets/WebSocketCloseStatus.cs b/src/Microsoft.Extensions.WebSockets/WebSocketCloseStatus.cs new file mode 100644 index 0000000000..2ee3e322e9 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/WebSocketCloseStatus.cs @@ -0,0 +1,74 @@ +namespace Microsoft.Extensions.WebSockets +{ + /// + /// Represents well-known WebSocket Close frame status codes. + /// + /// + /// See https://tools.ietf.org/html/rfc6455#section-7.4 for details + /// + public enum WebSocketCloseStatus : ushort + { + /// + /// Indicates that the purpose for the connection was fulfilled and thus the connection was closed normally. + /// + NormalClosure = 1000, + + /// + /// Indicates that the other endpoint is going away, such as a server shutting down or a browser navigating to a new page. + /// + EndpointUnavailable = 1001, + + /// + /// Indicates that a protocol error has occurred, causing the connection to be terminated. + /// + ProtocolError = 1002, + + /// + /// Indicates an invalid message type was received. For example, if the end point only supports messages + /// but received a message. + /// + InvalidMessageType = 1003, + + /// + /// Indicates that the Close frame did not have a status code. Not used in actual transmission. + /// + Empty = 1005, + + /// + /// Indicates that the underlying transport connection was terminated without a proper close handshake. Not used in actual transmission. + /// + AbnormalClosure = 1006, + + /// + /// Indicates that an invalid payload was encountered. For example, a frame of type contained non-UTF-8 data. + /// + InvalidPayloadData = 1007, + + /// + /// Indicates that the connection is being terminated due to a violation of policy. This is a generic error code used whenever a party needs to terminate + /// a connection without disclosing the specific reason. + /// + PolicyViolation = 1008, + + /// + /// Indicates that the connection is being terminated due to an endpoint receiving a message that is too large. + /// + MessageTooBig = 1009, + + /// + /// Indicates that the connection is being terminated due to being unable to negotiate a mandatory extension with the other party. Usually sent + /// from the client to the server after the client finishes handshaking without negotiating the extension. + /// + MandatoryExtension = 1010, + + /// + /// Indicates that a server is terminating the connection due to an internal error. + /// + InternalServerError = 1011, + + /// + /// Indicates that the connection failed to establish because the TLS handshake failed. Not used in actual transmission. + /// + TLSHandshakeFailed = 1015 + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets/WebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets/WebSocketConnection.cs new file mode 100644 index 0000000000..1501da1dfc --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/WebSocketConnection.cs @@ -0,0 +1,456 @@ +using System; +using System.Binary; +using System.Diagnostics; +using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; +using Channels; +using Microsoft.Extensions.WebSockets.Internal; + +namespace Microsoft.Extensions.WebSockets +{ + /// + /// Provides the default implementation of . + /// + /// + /// + /// This type is thread-safe under the following condition: No two threads attempt to call either + /// or simultaneously. Different threads may call each method, but the same method + /// cannot be re-entered while it is being run in a different thread. + /// + /// + /// The general pattern of having a single thread running and a separate thread running will + /// be thread-safe, as each method interacts with completely separate state. + /// + /// + public class WebSocketConnection : IWebSocketConnection + { + private readonly RandomNumberGenerator _random; + private readonly byte[] _maskingKey; + private readonly IReadableChannel _inbound; + private readonly IWritableChannel _outbound; + private readonly CancellationTokenSource _terminateReceiveCts = new CancellationTokenSource(); + + public WebSocketConnectionState State { get; private set; } = WebSocketConnectionState.Created; + + /// + /// Constructs a new, unmasked, from an and an that represents an established WebSocket connection (i.e. after handshaking) + /// + /// A from which frames will be read when receiving. + /// A to which frame will be written when sending. + public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound) : this(inbound, outbound, masked: false) { } + + /// + /// Constructs a new, optionally masked, from an and an that represents an established WebSocket connection (i.e. after handshaking) + /// + /// A from which frames will be read when receiving. + /// A to which frame will be written when sending. + /// A boolean indicating if frames sent from this socket should be masked (the masking key is automatically generated) + public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound, bool masked) + { + _inbound = inbound; + _outbound = outbound; + + if (masked) + { + _maskingKey = new byte[4]; + _random = RandomNumberGenerator.Create(); + } + } + + /// + /// Constructs a new, fixed masking-key, from an and an that represents an established WebSocket connection (i.e. after handshaking) + /// + /// A from which frames will be read when receiving. + /// A to which frame will be written when sending. + /// The masking key to use for the connection. Must be exactly 4-bytes long. This is ONLY recommended for testing and development purposes. + public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound, byte[] fixedMaskingKey) + { + _inbound = inbound; + _outbound = outbound; + _maskingKey = fixedMaskingKey; + } + + public void Dispose() + { + State = WebSocketConnectionState.Closed; + _inbound.Complete(); + _outbound.Complete(); + _terminateReceiveCts.Cancel(); + } + + public Task ExecuteAsync(Func messageHandler) + { + if (State == WebSocketConnectionState.Closed) + { + throw new ObjectDisposedException(nameof(WebSocketConnection)); + } + + if (State != WebSocketConnectionState.Created) + { + throw new InvalidOperationException("Connection is already running."); + } + State = WebSocketConnectionState.Connected; + return Task.Run(() => ReceiveLoop(messageHandler, _terminateReceiveCts.Token)); + } + + /// + /// Sends the specified frame. + /// + /// The frame to send. + /// A that indicates when/if the send is cancelled. + /// A that completes when the message has been written to the outbound stream. + // TODO: De-taskify this to allow consumers to create their own awaiter. + public Task SendAsync(WebSocketFrame frame, CancellationToken cancellationToken) + { + if (State == WebSocketConnectionState.Closed) + { + throw new ObjectDisposedException(nameof(WebSocketConnection)); + } + // This clause is a bit of an artificial restriction to ensure people run "Execute". Maybe we don't care? + else if (State == WebSocketConnectionState.Created) + { + throw new InvalidOperationException("Cannot send until the connection is started using Execute"); + } + else if (State == WebSocketConnectionState.CloseSent) + { + throw new InvalidOperationException("Cannot send after sending a Close frame"); + } + + if (frame.Opcode == WebSocketOpcode.Close) + { + throw new InvalidOperationException("Cannot use SendAsync to send a Close frame, use CloseAsync instead."); + } + return SendCoreAsync(frame, null, cancellationToken); + } + + /// + /// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame. + /// + /// + /// If the other party does not respond with a close frame, the connection will remain open and the + /// will remain active. Call the method on this instance to forcibly terminate the connection. + /// + /// A with the payload for the close frame + /// A that indicates when/if the send is cancelled. + /// A that completes when the close frame has been sent + public async Task CloseAsync(WebSocketCloseResult result, CancellationToken cancellationToken) + { + if (State == WebSocketConnectionState.Closed) + { + throw new ObjectDisposedException(nameof(WebSocketConnection)); + } + else if (State == WebSocketConnectionState.Created) + { + throw new InvalidOperationException("Cannot send close frame when the connection hasn't been started"); + } + else if (State == WebSocketConnectionState.CloseSent) + { + throw new InvalidOperationException("Cannot send multiple close frames"); + } + + // When we pass a close result to SendCoreAsync, the frame is only used for the header and the payload is ignored + var frame = new WebSocketFrame(endOfMessage: true, opcode: WebSocketOpcode.Close, payload: default(ReadableBuffer)); + + await SendCoreAsync(frame, result, cancellationToken); + + if (State == WebSocketConnectionState.CloseReceived) + { + State = WebSocketConnectionState.Closed; + } + else + { + State = WebSocketConnectionState.CloseSent; + } + } + + private void WriteMaskingKey(Span buffer) + { + if (_random != null) + { + // Get a new random mask + // Until https://github.com/dotnet/corefx/issues/12323 is fixed we need to use this shared buffer and copy model + // Once we have that fix we should be able to generate the mask directly into the output buffer. + _random.GetBytes(_maskingKey); + } + + buffer.Set(_maskingKey); + } + + private async Task ReceiveLoop(Func messageHandler, CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + // WebSocket Frame layout (https://tools.ietf.org/html/rfc6455#section-5.2): + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + // | Payload Data continued ... | + // +---------------------------------------------------------------+ + + // Read at least 2 bytes + var result = await _inbound.ReadAtLeastAsync(2, cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); + if (result.IsCompleted && result.Buffer.Length < 2) + { + return WebSocketCloseResult.AbnormalClosure; + } + var buffer = result.Buffer; + + // Read the opcode + var opcodeByte = buffer.ReadBigEndian(); + buffer = buffer.Slice(1); + + var fin = (opcodeByte & 0x01) != 0; + var opcode = (WebSocketOpcode)((opcodeByte & 0xF0) >> 4); + + // Read the first byte of the payload length + var lenByte = buffer.ReadBigEndian(); + buffer = buffer.Slice(1); + + var masked = (lenByte & 0x01) != 0; + var payloadLen = (lenByte & 0xFE) >> 1; + + // Mark what we've got so far as consumed + _inbound.Advance(buffer.Start); + + // Calculate the rest of the header length + var headerLength = masked ? 4 : 0; + if (payloadLen == 126) + { + headerLength += 2; + } + else if (payloadLen == 127) + { + headerLength += 4; + } + + uint maskingKey = 0; + + if (headerLength > 0) + { + result = await _inbound.ReadAtLeastAsync(headerLength, cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); + if (result.IsCompleted && result.Buffer.Length < headerLength) + { + return WebSocketCloseResult.AbnormalClosure; + } + buffer = result.Buffer; + + // Read extended payload length (if any) + if (payloadLen == 126) + { + payloadLen = buffer.ReadBigEndian(); + buffer = buffer.Slice(sizeof(ushort)); + } + else if (payloadLen == 127) + { + var longLen = buffer.ReadBigEndian(); + buffer = buffer.Slice(sizeof(ulong)); + if (longLen > int.MaxValue) + { + throw new WebSocketException($"Frame is too large. Maximum frame size is {int.MaxValue} bytes"); + } + payloadLen = (int)longLen; + } + + // Read masking key + if (masked) + { + var maskingKeyStart = buffer.Start; + maskingKey = buffer.Slice(0, 4).ReadBigEndian(); + buffer = buffer.Slice(4); + } + + // Mark the length and masking key consumed + _inbound.Advance(buffer.Start); + } + + var payload = default(ReadableBuffer); + if (payloadLen > 0) + { + result = await _inbound.ReadAtLeastAsync(payloadLen, cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); + if (result.IsCompleted && result.Buffer.Length < payloadLen) + { + return WebSocketCloseResult.AbnormalClosure; + } + buffer = result.Buffer; + + payload = buffer.Slice(0, payloadLen); + + if (masked) + { + // Unmask + MaskingUtilities.ApplyMask(ref payload, maskingKey); + } + } + + // Run the callback, if we're not cancelled. + cancellationToken.ThrowIfCancellationRequested(); + + var frame = new WebSocketFrame(fin, opcode, payload); + if (frame.Opcode == WebSocketOpcode.Close) + { + return HandleCloseFrame(payloadLen, payload, frame); + } + else + { + await messageHandler(frame); + } + + // Mark the payload as consumed + if (payloadLen > 0) + { + _inbound.Advance(payload.End); + } + } + return WebSocketCloseResult.AbnormalClosure; + } + + private WebSocketCloseResult HandleCloseFrame(int payloadLen, ReadableBuffer payload, WebSocketFrame frame) + { + // Update state + if (State == WebSocketConnectionState.CloseSent) + { + State = WebSocketConnectionState.Closed; + } + else + { + State = WebSocketConnectionState.CloseReceived; + } + + // Process the close frame + WebSocketCloseResult closeResult; + if (!WebSocketCloseResult.TryParse(frame.Payload, out closeResult)) + { + closeResult = WebSocketCloseResult.Empty; + } + + // Make the payload as consumed + if (payloadLen > 0) + { + _inbound.Advance(payload.End); + } + return closeResult; + } + + private Task SendCoreAsync(WebSocketFrame message, WebSocketCloseResult? closeResult, CancellationToken cancellationToken) + { + // Base header size is 2 bytes. + var allocSize = 2; + var payloadLength = closeResult == null ? message.Payload.Length : closeResult.Value.GetSize(); + if (payloadLength > ushort.MaxValue) + { + // We're going to need an 8-byte length + allocSize += 8; + } + else if (payloadLength > 125) + { + // We're going to need a 2-byte length + allocSize += 2; + } + if (_maskingKey != null) + { + // We need space for the masking key + allocSize += 4; + } + if (closeResult != null) + { + // We need space for the close result payload too + allocSize += payloadLength; + } + + // Allocate a buffer + var buffer = _outbound.Alloc(minimumSize: allocSize); + if (buffer.Memory.Length < allocSize) + { + throw new InvalidOperationException("Couldn't allocate enough data from the channel to write the header"); + } + + // Write the opcode and FIN flag + var opcodeByte = (byte)((int)message.Opcode << 4); + if (message.EndOfMessage) + { + opcodeByte |= 1; + } + buffer.WriteBigEndian(opcodeByte); + + // Write the length and mask flag + var maskingByte = _maskingKey != null ? 0x01 : 0x00; // TODO: Masking flag goes here + + if (payloadLength > ushort.MaxValue) + { + buffer.WriteBigEndian((byte)(0xFE | maskingByte)); + + // 8-byte length + buffer.WriteBigEndian((ulong)payloadLength); + } + else if (payloadLength > 125) + { + buffer.WriteBigEndian((byte)(0xFC | maskingByte)); + + // 2-byte length + buffer.WriteBigEndian((ushort)payloadLength); + } + else + { + // 1-byte length + buffer.WriteBigEndian((byte)((payloadLength << 1) | maskingByte)); + } + + var maskingKey = Span.Empty; + if (_maskingKey != null) + { + // Get a span of the output buffer for the masking key, write it there, then advance the write head. + maskingKey = buffer.Memory.Slice(0, 4).Span; + WriteMaskingKey(maskingKey); + buffer.Advance(4); + } + + if (closeResult != null) + { + // Write the close payload out + var payload = buffer.Memory.Slice(0, payloadLength).Span; + closeResult.Value.WriteTo(ref buffer); + + if (_maskingKey != null) + { + MaskingUtilities.ApplyMask(payload, maskingKey); + } + } + else + { + // This will copy the actual buffer struct, but NOT the underlying data + // We need a field so we can by-ref it. + var payload = message.Payload; + + if (_maskingKey != null) + { + // Mask the payload in it's own buffer + MaskingUtilities.ApplyMask(ref payload, maskingKey); + } + + // Append the (masked) buffer to the output channel + buffer.Append(payload); + } + + + // Commit and Flush + return buffer.FlushAsync(); + } + } +} diff --git a/src/Microsoft.Extensions.WebSockets/WebSocketConnectionState.cs b/src/Microsoft.Extensions.WebSockets/WebSocketConnectionState.cs new file mode 100644 index 0000000000..90b77296d3 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/WebSocketConnectionState.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.WebSockets +{ + public enum WebSocketConnectionState + { + Created, + Connected, + CloseSent, + CloseReceived, + Closed + } +} diff --git a/src/Microsoft.Extensions.WebSockets/WebSocketException.cs b/src/Microsoft.Extensions.WebSockets/WebSocketException.cs new file mode 100644 index 0000000000..70cdb7d951 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/WebSocketException.cs @@ -0,0 +1,19 @@ +using System; + +namespace Microsoft.Extensions.WebSockets.Internal +{ + public class WebSocketException : Exception + { + public WebSocketException() + { + } + + public WebSocketException(string message) : base(message) + { + } + + public WebSocketException(string message, Exception innerException) : base(message, innerException) + { + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets/WebSocketFrame.cs b/src/Microsoft.Extensions.WebSockets/WebSocketFrame.cs new file mode 100644 index 0000000000..4a55af6a73 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/WebSocketFrame.cs @@ -0,0 +1,34 @@ +using System; +using System.Text; +using Channels; + +namespace Microsoft.Extensions.WebSockets +{ + /// + /// Represents a single Frame received or sent on a . + /// + public struct WebSocketFrame + { + /// + /// Indicates if the "FIN" flag is set on this frame, which indicates it is the final frame of a message. + /// + public bool EndOfMessage { get; } + + /// + /// Gets the value describing the opcode of the WebSocket frame. + /// + public WebSocketOpcode Opcode { get; } + + /// + /// Gets the payload of the WebSocket frame. + /// + public ReadableBuffer Payload { get; } + + public WebSocketFrame(bool endOfMessage, WebSocketOpcode opcode, ReadableBuffer payload) + { + EndOfMessage = endOfMessage; + Opcode = opcode; + Payload = payload; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets/WebSocketOpcode.cs b/src/Microsoft.Extensions.WebSockets/WebSocketOpcode.cs new file mode 100644 index 0000000000..6e59e2264c --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/WebSocketOpcode.cs @@ -0,0 +1,42 @@ +namespace Microsoft.Extensions.WebSockets +{ + /// + /// Represents the possible values for the "opcode" field of a WebSocket frame. + /// + public enum WebSocketOpcode + { + /// + /// Indicates that the frame is a continuation of the previous or frame. + /// + Continuation = 0x0, + + /// + /// Indicates that the frame is the first frame of a new Text message, formatted in UTF-8. + /// + Text = 0x1, + + /// + /// Indicates that the frame is the first frame of a new Binary message. + /// + Binary = 0x2, + /* 0x3 - 0x7 are reserved */ + + /// + /// Indicates that the frame is a notification that the sender is closing their end of the connection + /// + Close = 0x8, + + /// + /// Indicates a request from the sender to receive a , in order to maintain the connection. + /// + Ping = 0x9, + + /// + /// Indicates a response to a , in order to maintain the connection. + /// + Pong = 0xA, + /* 0xB-0xF are reserved */ + + /* all opcodes above 0xF are invalid */ + } +} \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets/project.json b/src/Microsoft.Extensions.WebSockets/project.json new file mode 100644 index 0000000000..57180dfa02 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets/project.json @@ -0,0 +1,34 @@ +{ + "version": "0.1.0-*", + "buildOptions": { + "warningsAsErrors": true, + "allowUnsafe": true + }, + "description": "Low-allocation Push-oriented WebSockets based on Channels", + "packOptions": { + "repository": { + "type": "git", + "url": "git://github.com/aspnet/websockets" + } + }, + "dependencies": { + "Channels": "0.2.0-beta-*", + "Channels.Text.Primitives": "0.2.0-beta-*" + }, + + "frameworks": { + "net46": {}, + "netstandard1.3": { + "dependencies": { + "System.Collections": "4.0.11", + "System.Diagnostics.Debug": "4.0.11", + "System.IO": "4.1.0", + "System.Linq": "4.1.0", + "System.Runtime": "4.1.0", + "System.Runtime.Extensions": "4.1.0", + "System.Threading": "4.0.11", + "System.Threading.Tasks": "4.0.11" + } + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Tests/Internal/WebSocketPair.cs b/test/Microsoft.Extensions.WebSockets.Tests/Internal/WebSocketPair.cs new file mode 100644 index 0000000000..8c5ef2d8b5 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/Internal/WebSocketPair.cs @@ -0,0 +1,50 @@ +using System; +using Channels; + +namespace Microsoft.Extensions.WebSockets.Test +{ + internal class WebSocketPair : IDisposable + { + private ChannelFactory _factory; + + private Channel _serverToClient; + private Channel _clientToServer; + + public IWebSocketConnection ClientSocket { get; } + public IWebSocketConnection ServerSocket { get; } + + public WebSocketPair(ChannelFactory factory, Channel serverToClient, Channel clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket) + { + _factory = factory; + _serverToClient = serverToClient; + _clientToServer = clientToServer; + ClientSocket = clientSocket; + ServerSocket = serverSocket; + } + + public static WebSocketPair Create() + { + // Create channels + var factory = new ChannelFactory(); + var serverToClient = factory.CreateChannel(); + var clientToServer = factory.CreateChannel(); + + var serverSocket = new WebSocketConnection(clientToServer, serverToClient, masked: true); + var clientSocket = new WebSocketConnection(serverToClient, clientToServer, masked: false); + + return new WebSocketPair(factory, serverToClient, clientToServer, clientSocket, serverSocket); + } + + public void Dispose() + { + _factory.Dispose(); + ServerSocket.Dispose(); + ClientSocket.Dispose(); + } + + public void TerminateFromClient(Exception ex = null) + { + _clientToServer.CompleteWriter(ex); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.Extensions.WebSockets.Tests/Microsoft.Extensions.WebSockets.Tests.xproj b/test/Microsoft.Extensions.WebSockets.Tests/Microsoft.Extensions.WebSockets.Tests.xproj new file mode 100644 index 0000000000..33a56191b0 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/Microsoft.Extensions.WebSockets.Tests.xproj @@ -0,0 +1,21 @@ + + + + 14.0.25420 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + 8fa6be8f-b5eb-42f9-9b16-101917cc45e2 + Microsoft.Extensions.WebSockets.Tests + .\obj + .\bin\ + + + 2.0 + + + + + + \ No newline at end of file diff --git a/test/Microsoft.Extensions.WebSockets.Tests/Properties/AssemblyInfo.cs b/test/Microsoft.Extensions.WebSockets.Tests/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..f562f703c0 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/Properties/AssemblyInfo.cs @@ -0,0 +1,19 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.Extensions.WebSockets.Test")] +[assembly: AssemblyTrademark("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("216f6739-da4d-4371-8393-739a90826c29")] diff --git a/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionExtensions.cs b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionExtensions.cs new file mode 100644 index 0000000000..7e69797eac --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionExtensions.cs @@ -0,0 +1,24 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using Channels; + +namespace Microsoft.Extensions.WebSockets.Tests +{ + public static class WebSocketConnectionExtensions + { + public static async Task ExecuteAndCaptureFramesAsync(this IWebSocketConnection self) + { + var frames = new List(); + var closeResult = await self.ExecuteAsync(frame => + { + var buffer = new byte[frame.Payload.Length]; + frame.Payload.CopyTo(buffer); + frames.Add(new WebSocketFrame( + frame.EndOfMessage, + frame.Opcode, + ReadableBuffer.Create(buffer, 0, buffer.Length))); + }); + return new WebSocketConnectionSummary(frames, closeResult); + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionSummary.cs b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionSummary.cs new file mode 100644 index 0000000000..675aa72499 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionSummary.cs @@ -0,0 +1,16 @@ +using System.Collections.Generic; + +namespace Microsoft.Extensions.WebSockets.Tests +{ + public class WebSocketConnectionSummary + { + public IList Received { get; } + public WebSocketCloseResult CloseResult { get; } + + public WebSocketConnectionSummary(IList received, WebSocketCloseResult closeResult) + { + Received = received; + CloseResult = closeResult; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs new file mode 100644 index 0000000000..7a8b24ce2e --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs @@ -0,0 +1,133 @@ +using System; +using System.Diagnostics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Channels; +using Microsoft.Extensions.WebSockets.Test; +using Xunit; + +namespace Microsoft.Extensions.WebSockets.Tests +{ + public partial class WebSocketConnectionTests + { + [Fact] + public async Task SendReceiveFrames() + { + using (var pair = WebSocketPair.Create()) + { + var cts = new CancellationTokenSource(); + if (!Debugger.IsAttached) + { + cts.CancelAfter(TimeSpan.FromSeconds(5)); + } + using (cts.Token.Register(() => pair.Dispose())) + { + var client = pair.ClientSocket.ExecuteAsync(_ => + { + Assert.False(true, "did not expect the client to receive any frames!"); + return Task.CompletedTask; + }); + + // Send Frames + await pair.ClientSocket.SendAsync(CreateTextFrame("Hello")); + await pair.ClientSocket.SendAsync(CreateTextFrame("World")); + await pair.ClientSocket.SendAsync(CreateBinaryFrame(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF })); + await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)); + + var summary = await pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + Assert.Equal(3, summary.Received.Count); + Assert.Equal("Hello", Encoding.UTF8.GetString(summary.Received[0].Payload.ToArray())); + Assert.Equal("World", Encoding.UTF8.GetString(summary.Received[1].Payload.ToArray())); + Assert.Equal(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }, summary.Received[2].Payload.ToArray()); + + await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)); + await client; + } + } + } + + [Fact] + public async Task ExecuteReturnsWhenCloseFrameReceived() + { + using(var pair = WebSocketPair.Create()) + { + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.InvalidMessageType, "Abc")); + var serverSummary = await pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure, "Ok")); + var clientSummary = await client; + + Assert.Equal(0, serverSummary.Received.Count); + Assert.Equal(WebSocketCloseStatus.InvalidMessageType, serverSummary.CloseResult.Status); + Assert.Equal("Abc", serverSummary.CloseResult.Description); + + Assert.Equal(0, clientSummary.Received.Count); + Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status); + Assert.Equal("Ok", clientSummary.CloseResult.Description); + } + } + + [Fact] + public async Task AbnormalTerminationOfInboundChannelCausesExecuteToThrow() + { + using(var pair = WebSocketPair.Create()) + { + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + pair.TerminateFromClient(new InvalidOperationException("It broke!")); + + await Assert.ThrowsAsync(() => server); + } + } + + [Fact] + public async Task StateTransitions() + { + using (var pair = WebSocketPair.Create()) + { + // Initial State + Assert.Equal(WebSocketConnectionState.Created, pair.ServerSocket.State); + Assert.Equal(WebSocketConnectionState.Created, pair.ClientSocket.State); + + // Start the sockets + var serverReceiving = new TaskCompletionSource(); + var clientReceiving = new TaskCompletionSource(); + var server = pair.ServerSocket.ExecuteAsync(frame => serverReceiving.TrySetResult(null)); + var client = pair.ClientSocket.ExecuteAsync(frame => clientReceiving.TrySetResult(null)); + + // Send a frame from each and verify that the state transitioned. + // We need to do this because it's the only way to correctly wait for the state transition (which happens asynchronously in ExecuteAsync) + await pair.ClientSocket.SendAsync(CreateTextFrame("Hello")); + await pair.ServerSocket.SendAsync(CreateTextFrame("Hello")); + + await Task.WhenAll(serverReceiving.Task, clientReceiving.Task); + + // Check state + Assert.Equal(WebSocketConnectionState.Connected, pair.ServerSocket.State); + Assert.Equal(WebSocketConnectionState.Connected, pair.ClientSocket.State); + + // Close the server socket + await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)); + await client; + + // Check state + Assert.Equal(WebSocketConnectionState.CloseSent, pair.ServerSocket.State); + Assert.Equal(WebSocketConnectionState.CloseReceived, pair.ClientSocket.State); + + // Close the client socket + await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)); + await server; + + // Check state + Assert.Equal(WebSocketConnectionState.Closed, pair.ServerSocket.State); + Assert.Equal(WebSocketConnectionState.Closed, pair.ClientSocket.State); + + // Verify we can't restart the connection or send a message + await Assert.ThrowsAsync(async () => await pair.ServerSocket.ExecuteAsync(f => { })); + await Assert.ThrowsAsync(async () => await pair.ClientSocket.SendAsync(CreateTextFrame("Nope"))); + await Assert.ThrowsAsync(async () => await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure))); + } + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.ReceiveAsync.cs b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.ReceiveAsync.cs new file mode 100644 index 0000000000..6eac3e1481 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.ReceiveAsync.cs @@ -0,0 +1,213 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Channels; +using Xunit; + +namespace Microsoft.Extensions.WebSockets.Tests +{ + public partial class WebSocketConnectionTests + { + public class TheReceiveAsyncMethod + { + [Theory] + [InlineData(new byte[] { 0x11, 0x00 }, "", true)] + [InlineData(new byte[] { 0x11, 0x0A, 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello", true)] + [InlineData(new byte[] { 0x11, 0x0B, 0x1, 0x2, 0x3, 0x4, 0x48 ^ 0x1, 0x65 ^ 0x2, 0x6C ^ 0x3, 0x6C ^ 0x4, 0x6F ^ 0x1 }, "Hello", true)] + [InlineData(new byte[] { 0x10, 0x00 }, "", false)] + [InlineData(new byte[] { 0x10, 0x0A, 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello", false)] + [InlineData(new byte[] { 0x10, 0x0B, 0x1, 0x2, 0x3, 0x4, 0x48 ^ 0x1, 0x65 ^ 0x2, 0x6C ^ 0x3, 0x6C ^ 0x4, 0x6F ^ 0x1 }, "Hello", false)] + public Task ReadTextFrames(byte[] rawFrame, string message, bool endOfMessage) + { + return RunSingleFrameTest( + rawFrame, + endOfMessage, + WebSocketOpcode.Text, + b => Assert.Equal(message, Encoding.UTF8.GetString(b))); + } + + [Theory] + // Opcode = Binary + [InlineData(new byte[] { 0x21, 0x00 }, new byte[0], WebSocketOpcode.Binary, true)] + [InlineData(new byte[] { 0x21, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, true)] + [InlineData(new byte[] { 0x21, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, true)] + [InlineData(new byte[] { 0x20, 0x00 }, new byte[0], WebSocketOpcode.Binary, false)] + [InlineData(new byte[] { 0x20, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, false)] + [InlineData(new byte[] { 0x20, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, false)] + + // Opcode = Continuation + [InlineData(new byte[] { 0x01, 0x00 }, new byte[0], WebSocketOpcode.Continuation, true)] + [InlineData(new byte[] { 0x01, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Continuation, true)] + [InlineData(new byte[] { 0x01, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Continuation, true)] + [InlineData(new byte[] { 0x00, 0x00 }, new byte[0], WebSocketOpcode.Continuation, false)] + [InlineData(new byte[] { 0x00, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Continuation, false)] + [InlineData(new byte[] { 0x00, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Continuation, false)] + + // Opcode = Ping + [InlineData(new byte[] { 0x91, 0x00 }, new byte[0], WebSocketOpcode.Ping, true)] + [InlineData(new byte[] { 0x91, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, true)] + [InlineData(new byte[] { 0x91, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, true)] + [InlineData(new byte[] { 0x90, 0x00 }, new byte[0], WebSocketOpcode.Ping, false)] + [InlineData(new byte[] { 0x90, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, false)] + [InlineData(new byte[] { 0x90, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, false)] + + // Opcode = Pong + [InlineData(new byte[] { 0xA1, 0x00 }, new byte[0], WebSocketOpcode.Pong, true)] + [InlineData(new byte[] { 0xA1, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, true)] + [InlineData(new byte[] { 0xA1, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, true)] + [InlineData(new byte[] { 0xA0, 0x00 }, new byte[0], WebSocketOpcode.Pong, false)] + [InlineData(new byte[] { 0xA0, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, false)] + [InlineData(new byte[] { 0xA0, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, false)] + public Task ReadBinaryFormattedFrames(byte[] rawFrame, byte[] payload, WebSocketOpcode opcode, bool endOfMessage) + { + return RunSingleFrameTest( + rawFrame, + endOfMessage, + opcode, + b => Assert.Equal(payload, b)); + } + + [Fact] + public async Task ReadMultipleFramesAcrossMultipleBuffers() + { + var result = await RunReceiveTest( + producer: async (channel, cancellationToken) => + { + await channel.WriteAsync(new byte[] { 0x20, 0x0A }.Slice()); + await channel.WriteAsync(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB, 0x01, 0x0A }.Slice()); + await channel.WriteAsync(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }.Slice()); + await channel.WriteAsync(new byte[] { 0xAB }.Slice()); + }); + + Assert.Equal(2, result.Received.Count); + + Assert.False(result.Received[0].EndOfMessage); + Assert.Equal(WebSocketOpcode.Binary, result.Received[0].Opcode); + Assert.Equal(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, result.Received[0].Payload.ToArray()); + + Assert.True(result.Received[1].EndOfMessage); + Assert.Equal(WebSocketOpcode.Continuation, result.Received[1].Opcode); + Assert.Equal(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, result.Received[1].Payload.ToArray()); + } + + [Fact] + public async Task Read16BitPayloadLength() + { + var expectedPayload = new byte[1024]; + new Random().NextBytes(expectedPayload); + + var result = await RunReceiveTest( + producer: async (channel, cancellationToken) => + { + // Header: (Opcode=Binary, Fin=true), (Mask=false, Len=126), (16-bit big endian length) + await channel.WriteAsync(new byte[] { 0x21, 0xFC, 0x04, 0x00 }); + await channel.WriteAsync(expectedPayload); + }); + + Assert.Equal(1, result.Received.Count); + + var frame = result.Received[0]; + Assert.True(frame.EndOfMessage); + Assert.Equal(WebSocketOpcode.Binary, frame.Opcode); + Assert.Equal(expectedPayload, frame.Payload.ToArray()); + } + + [Fact] + public async Task Read64bitPayloadLength() + { + // Allocating an actual (2^32 + 1) byte payload is crazy for this test. We just need to test that we can USE a 64-bit length + var expectedPayload = new byte[1024]; + new Random().NextBytes(expectedPayload); + + var result = await RunReceiveTest( + producer: async (channel, cancellationToken) => + { + // Header: (Opcode=Binary, Fin=true), (Mask=false, Len=127), (64-bit big endian length) + await channel.WriteAsync(new byte[] { 0x21, 0xFE, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00 }); + await channel.WriteAsync(expectedPayload); + }); + + Assert.Equal(1, result.Received.Count); + + var frame = result.Received[0]; + Assert.True(frame.EndOfMessage); + Assert.Equal(WebSocketOpcode.Binary, frame.Opcode); + Assert.Equal(expectedPayload, frame.Payload.ToArray()); + } + + private static async Task RunSingleFrameTest(byte[] rawFrame, bool endOfMessage, WebSocketOpcode expectedOpcode, Action payloadAssert) + { + var result = await RunReceiveTest( + producer: async (channel, cancellationToken) => + { + await channel.WriteAsync(rawFrame.Slice()); + }); + var frames = result.Received; + Assert.Equal(1, frames.Count); + + var frame = frames[0]; + + Assert.Equal(endOfMessage, frame.EndOfMessage); + Assert.Equal(expectedOpcode, frame.Opcode); + payloadAssert(frame.Payload.ToArray()); + } + + private static async Task RunReceiveTest(Func producer) + { + using (var factory = new ChannelFactory()) + { + var outbound = factory.CreateChannel(); + var inbound = factory.CreateChannel(); + + var cts = new CancellationTokenSource(); + var cancellationToken = cts.Token; + + // Timeout for the test, but only if the debugger is not attached. + if (!Debugger.IsAttached) + { + cts.CancelAfter(TimeSpan.FromSeconds(5)); + } + + var producerTask = Task.Run(async () => + { + await producer(inbound, cancellationToken); + inbound.CompleteWriter(); + }, cancellationToken); + + var consumerTask = Task.Run(async () => + { + var connection = new WebSocketConnection(inbound, outbound); + using (cancellationToken.Register(() => connection.Dispose())) + using (connection) + { + // Receive frames until we're closed + return await connection.ExecuteAndCaptureFramesAsync(); + } + }, cancellationToken); + + await Task.WhenAll(producerTask, consumerTask); + return consumerTask.Result; + } + } + } + + private static WebSocketFrame CreateTextFrame(string message) + { + var payload = Encoding.UTF8.GetBytes(message); + return CreateFrame(endOfMessage: true, opcode: WebSocketOpcode.Text, payload: payload); + } + + private static WebSocketFrame CreateBinaryFrame(byte[] payload) + { + return CreateFrame(endOfMessage: true, opcode: WebSocketOpcode.Binary, payload: payload); + } + + private static WebSocketFrame CreateFrame(bool endOfMessage, WebSocketOpcode opcode, byte[] payload) + { + return new WebSocketFrame(endOfMessage, opcode, payload: ReadableBuffer.Create(payload, 0, payload.Length)); + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.SendAsync.cs b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.SendAsync.cs new file mode 100644 index 0000000000..cfdfe4d560 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/WebSocketConnectionTests.SendAsync.cs @@ -0,0 +1,225 @@ +using System; +using System.Diagnostics; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Channels; +using Xunit; + +namespace Microsoft.Extensions.WebSockets.Tests +{ + public partial class WebSocketConnectionTests + { + public class TheSendAsyncMethod + { + [Theory] + [InlineData("", true, new byte[] { 0x11, 0x00 })] + [InlineData("Hello", true, new byte[] { 0x11, 0x0A, 0x48, 0x65, 0x6C, 0x6C, 0x6F })] + [InlineData("", false, new byte[] { 0x10, 0x00 })] + [InlineData("Hello", false, new byte[] { 0x10, 0x0A, 0x48, 0x65, 0x6C, 0x6C, 0x6F })] + public async Task WriteTextFrames(string message, bool endOfMessage, byte[] expectedRawFrame) + { + var data = await RunSendTest( + producer: async (socket, cancellationToken) => + { + var payload = Encoding.UTF8.GetBytes(message); + await socket.SendAsync(CreateFrame( + endOfMessage, + opcode: WebSocketOpcode.Text, + payload: payload)); + }, masked: false); + Assert.Equal(expectedRawFrame, data); + } + + [Theory] + // Opcode = Binary + [InlineData(new byte[0], WebSocketOpcode.Binary, true, new byte[] { 0x21, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, true, new byte[] { 0x21, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Binary, false, new byte[] { 0x20, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, false, new byte[] { 0x20, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + + // Opcode = Continuation + [InlineData(new byte[0], WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Continuation, false, new byte[] { 0x00, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, false, new byte[] { 0x00, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + + // Opcode = Ping + [InlineData(new byte[0], WebSocketOpcode.Ping, true, new byte[] { 0x91, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, true, new byte[] { 0x91, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Ping, false, new byte[] { 0x90, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, false, new byte[] { 0x90, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + + // Opcode = Pong + [InlineData(new byte[0], WebSocketOpcode.Pong, true, new byte[] { 0xA1, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, true, new byte[] { 0xA1, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Pong, false, new byte[] { 0xA0, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, false, new byte[] { 0xA0, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + public async Task WriteBinaryFormattedFrames(byte[] payload, WebSocketOpcode opcode, bool endOfMessage, byte[] expectedRawFrame) + { + var data = await RunSendTest( + producer: async (socket, cancellationToken) => + { + await socket.SendAsync(CreateFrame( + endOfMessage, + opcode, + payload: payload)); + }, masked: false); + Assert.Equal(expectedRawFrame, data); + } + + [Theory] + [InlineData("", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x11, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData("Hello", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x11, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x48 ^ 0x01, 0x65 ^ 0x02, 0x6C ^ 0x03, 0x6C ^ 0x04, 0x6F ^ 0x01 })] + public async Task WriteMaskedTextFrames(string message, byte[] maskingKey, byte[] expectedRawFrame) + { + var data = await RunSendTest( + producer: async (socket, cancellationToken) => + { + var payload = Encoding.UTF8.GetBytes(message); + await socket.SendAsync(CreateFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Text, + payload: payload)); + }, maskingKey: maskingKey); + Assert.Equal(expectedRawFrame, data); + } + + [Theory] + // Opcode = Binary + [InlineData(new byte[0], WebSocketOpcode.Binary, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x21, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x21, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Binary, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x20, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x20, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + + // Opcode = Continuation + [InlineData(new byte[0], WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x01, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x01, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Continuation, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x00, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x00, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + + // Opcode = Ping + [InlineData(new byte[0], WebSocketOpcode.Ping, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x91, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x91, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Ping, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x90, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x90, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + + // Opcode = Pong + [InlineData(new byte[0], WebSocketOpcode.Pong, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0xA1, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0xA1, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Pong, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0xA0, 0x01, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0xA0, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + public async Task WriteMaskedBinaryFormattedFrames(byte[] payload, WebSocketOpcode opcode, bool endOfMessage, byte[] maskingKey, byte[] expectedRawFrame) + { + var data = await RunSendTest( + producer: async (socket, cancellationToken) => + { + await socket.SendAsync(CreateFrame( + endOfMessage, + opcode, + payload: payload)); + }, maskingKey: maskingKey); + Assert.Equal(expectedRawFrame, data); + } + + [Fact] + public async Task WriteRandomMaskedFrame() + { + var data = await RunSendTest( + producer: async (socket, cancellationToken) => + { + await socket.SendAsync(CreateFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Binary, + payload: new byte[] { 0x0A, 0x0B, 0x0C, 0x0D, 0x0E })); + }, masked: true); + + // Verify the header + Assert.Equal(0x21, data[0]); + Assert.Equal(0x0B, data[1]); + + // We don't know the mask, so we have to read it in order to verify this frame + var mask = data.Slice(2, 4); + var actualPayload = data.Slice(6); + + // Unmask the payload + for (int i = 0; i < actualPayload.Length; i++) + { + actualPayload[i] = (byte)(mask[i % 4] ^ actualPayload[i]); + } + Assert.Equal(new byte[] { 0x0A, 0x0B, 0x0C, 0x0D, 0x0E }, actualPayload.ToArray()); + } + + [Theory] + [InlineData(WebSocketCloseStatus.MandatoryExtension, "Hi", null, new byte[] { 0x81, 0x08, 0x03, 0xF2, (byte)'H', (byte)'i' })] + [InlineData(WebSocketCloseStatus.PolicyViolation, "", null, new byte[] { 0x81, 0x04, 0x03, 0xF0 })] + [InlineData(WebSocketCloseStatus.MandatoryExtension, "Hi", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x81, 0x09, 0x01, 0x02, 0x03, 0x04, 0x03 ^ 0x01, 0xF2 ^ 0x02, (byte)'H' ^ 0x03, (byte)'i' ^ 0x04 })] + [InlineData(WebSocketCloseStatus.PolicyViolation, "", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x81, 0x05, 0x01, 0x02, 0x03, 0x04, 0x03 ^ 0x01, 0xF0 ^ 0x02 })] + public async Task WriteCloseFrames(WebSocketCloseStatus status, string description, byte[] maskingKey, byte[] expectedRawFrame) + { + var data = await RunSendTest( + producer: async (socket, cancellationToken) => + { + await socket.CloseAsync(new WebSocketCloseResult(status, description)); + }, maskingKey: maskingKey); + Assert.Equal(expectedRawFrame, data); + } + + private static async Task RunSendTest(Func producer, bool masked = false, byte[] maskingKey = null) + { + using (var factory = new ChannelFactory()) + { + var outbound = factory.CreateChannel(); + var inbound = factory.CreateChannel(); + + var cts = new CancellationTokenSource(); + + // Timeout for the test, but only if the debugger is not attached. + if (!Debugger.IsAttached) + { + cts.CancelAfter(TimeSpan.FromSeconds(5)); + } + + var cancellationToken = cts.Token; + using (cancellationToken.Register(() => CompleteChannels(inbound, outbound))) + { + + Task executeTask; + using (var connection = CreateConnection(inbound, outbound, masked, maskingKey)) + { + executeTask = connection.ExecuteAsync(f => + { + Assert.False(true, "Did not expect to receive any messages"); + return Task.CompletedTask; + }); + await producer(connection, cancellationToken); + inbound.CompleteWriter(); + await executeTask; + } + + var data = (await outbound.ReadToEndAsync()).ToArray(); + inbound.CompleteReader(); + CompleteChannels(outbound); + return data; + } + } + } + + private static void CompleteChannels(params Channel[] channels) + { + foreach (var channel in channels) + { + channel.CompleteReader(); + channel.CompleteWriter(); + } + } + + private static WebSocketConnection CreateConnection(Channel inbound, Channel outbound, bool masked, byte[] maskingKey) + { + return (maskingKey != null) ? + new WebSocketConnection(inbound, outbound, fixedMaskingKey: maskingKey) : + new WebSocketConnection(inbound, outbound, masked); + } + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Tests/project.json b/test/Microsoft.Extensions.WebSockets.Tests/project.json new file mode 100644 index 0000000000..008c6d1ab9 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Tests/project.json @@ -0,0 +1,30 @@ +{ + "buildOptions": { + "warningsAsErrors": true + }, + "dependencies": { + "dotnet-test-xunit": "1.0.0-rc3-000000-01", + "Microsoft.Extensions.WebSockets": "0.1.0-*", + "xunit": "2.1.0" + }, + "testRunner": "xunit", + "frameworks": { + "netcoreapp1.0": { + "dependencies": { + "Microsoft.NETCore.App": { + "version": "1.0.0", + "type": "platform" + } + }, + "imports": [ + "dnxcore50", + "portable-net451+win8" + ] + }, + "net46": { + "dependencies": { + "xunit.runner.console": "2.1.0" + } + } + } +}