diff --git a/.gitignore b/.gitignore index d5c2d3b074..89168981c0 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,5 @@ runtimes/ launchSettings.json *.tmp *.nuget.props -*.nuget.targets \ No newline at end of file +*.nuget.targets +autobahnreports/ \ No newline at end of file diff --git a/Microsoft.AspNetCore.Sockets.sln b/Microsoft.AspNetCore.Sockets.sln index 7a41563b70..3202482228 100644 --- a/Microsoft.AspNetCore.Sockets.sln +++ b/Microsoft.AspNetCore.Sockets.sln @@ -30,6 +30,11 @@ EndProject Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.AspNetCore.SignalR", "src\Microsoft.AspNetCore.SignalR\Microsoft.AspNetCore.SignalR.xproj", "{42E76F87-92B6-45AB-BF07-6B811C0F2CAC}" EndProject Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.AspNetCore.SignalR.Redis", "src\Microsoft.AspNetCore.SignalR.Redis\Microsoft.AspNetCore.SignalR.Redis.xproj", "{59319B72-38BE-4041-8E5C-FF6938874CE8}" +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.AspNetCore.WebSockets.Internal", "src\Microsoft.AspNetCore.WebSockets.Internal\Microsoft.AspNetCore.WebSockets.Internal.xproj", "{FFFE71F8-E476-4BCD-9689-F106EE1C1497}" +EndProject +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest", "test\Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest\Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.xproj", "{8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}" +EndProject +Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "WebSocketsTestApp", "test\WebSocketsTestApp\WebSocketsTestApp.xproj", "{58E771EC-8454-4558-B61A-C9D049065911}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -69,6 +74,18 @@ Global {59319B72-38BE-4041-8E5C-FF6938874CE8}.Debug|Any CPU.Build.0 = Debug|Any CPU {59319B72-38BE-4041-8E5C-FF6938874CE8}.Release|Any CPU.ActiveCfg = Release|Any CPU {59319B72-38BE-4041-8E5C-FF6938874CE8}.Release|Any CPU.Build.0 = Release|Any CPU + {FFFE71F8-E476-4BCD-9689-F106EE1C1497}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {FFFE71F8-E476-4BCD-9689-F106EE1C1497}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FFFE71F8-E476-4BCD-9689-F106EE1C1497}.Release|Any CPU.ActiveCfg = Release|Any CPU + {FFFE71F8-E476-4BCD-9689-F106EE1C1497}.Release|Any CPU.Build.0 = Release|Any CPU + {8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}.Release|Any CPU.Build.0 = Release|Any CPU + {58E771EC-8454-4558-B61A-C9D049065911}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {58E771EC-8454-4558-B61A-C9D049065911}.Debug|Any CPU.Build.0 = Debug|Any CPU + {58E771EC-8454-4558-B61A-C9D049065911}.Release|Any CPU.ActiveCfg = Release|Any CPU + {58E771EC-8454-4558-B61A-C9D049065911}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -82,5 +99,8 @@ Global {A7050BAE-3DB9-4FB3-A49D-303201415B13} = {6A35B453-52EC-48AF-89CA-D4A69800F131} {42E76F87-92B6-45AB-BF07-6B811C0F2CAC} = {DA69F624-5398-4884-87E4-B816698CDE65} {59319B72-38BE-4041-8E5C-FF6938874CE8} = {DA69F624-5398-4884-87E4-B816698CDE65} + {FFFE71F8-E476-4BCD-9689-F106EE1C1497} = {DA69F624-5398-4884-87E4-B816698CDE65} + {8CBC1C71-AF0B-44E2-AEE9-D8024C07634D} = {6A35B453-52EC-48AF-89CA-D4A69800F131} + {58E771EC-8454-4558-B61A-C9D049065911} = {6A35B453-52EC-48AF-89CA-D4A69800F131} EndGlobalSection EndGlobal diff --git a/src/Microsoft.AspNetCore.Sockets/project.json b/src/Microsoft.AspNetCore.Sockets/project.json index d3d5bc09a2..273816fd7f 100644 --- a/src/Microsoft.AspNetCore.Sockets/project.json +++ b/src/Microsoft.AspNetCore.Sockets/project.json @@ -1,13 +1,13 @@ { - "version": "0.1.0-*", - "dependencies": { - "Channels": "0.2.0-beta-*", - "Microsoft.AspNetCore.Routing": "1.1.0-*", - "Microsoft.AspNetCore.WebSockets": "0.2.0-*" - }, - "frameworks": { - "netstandard1.3": { + "version": "0.1.0-*", + "dependencies": { + "Channels": "0.2.0-beta-*", + "Microsoft.AspNetCore.Routing": "1.1.0-*", + "Microsoft.AspNetCore.WebSockets": "0.2.0-*", + "NETStandard.Library": "1.6.0" }, - "net46": { } - } + "frameworks": { + "netstandard1.3": {}, + "net46": {} + } } diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/Constants.cs b/src/Microsoft.AspNetCore.WebSockets.Internal/Constants.cs new file mode 100644 index 0000000000..d71e1563de --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/Constants.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.WebSockets.Internal +{ + public static class Constants + { + public static class Headers + { + public const string Upgrade = "Upgrade"; + public const string UpgradeWebSocket = "websocket"; + public const string Connection = "Connection"; + public const string ConnectionUpgrade = "Upgrade"; + public const string SecWebSocketKey = "Sec-WebSocket-Key"; + public const string SecWebSocketVersion = "Sec-WebSocket-Version"; + public const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; + public const string SecWebSocketAccept = "Sec-WebSocket-Accept"; + public const string SupportedVersion = "13"; + } + } +} diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/HandshakeHelpers.cs b/src/Microsoft.AspNetCore.WebSockets.Internal/HandshakeHelpers.cs new file mode 100644 index 0000000000..a3476c8b6b --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/HandshakeHelpers.cs @@ -0,0 +1,108 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Security.Cryptography; +using System.Text; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.WebSockets.Internal +{ + public static class HandshakeHelpers + { + // Verify Method, Upgrade, Connection, version, key, etc.. + public static bool CheckSupportedWebSocketRequest(HttpRequest request) + { + bool validUpgrade = false, validConnection = false, validKey = false, validVersion = false; + + if (!string.Equals("GET", request.Method, StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + foreach (var pair in request.Headers) + { + if (string.Equals(Constants.Headers.Connection, pair.Key, StringComparison.OrdinalIgnoreCase)) + { + foreach (var value in pair.Value) + { + if (string.Equals(Constants.Headers.ConnectionUpgrade, value, StringComparison.OrdinalIgnoreCase)) + { + validConnection = true; + break; + } + } + } + else if (string.Equals(Constants.Headers.Upgrade, pair.Key, StringComparison.OrdinalIgnoreCase)) + { + if (string.Equals(Constants.Headers.UpgradeWebSocket, pair.Value, StringComparison.OrdinalIgnoreCase)) + { + validUpgrade = true; + } + } + else if (string.Equals(Constants.Headers.SecWebSocketVersion, pair.Key, StringComparison.OrdinalIgnoreCase)) + { + if (string.Equals(Constants.Headers.SupportedVersion, pair.Value, StringComparison.OrdinalIgnoreCase)) + { + validVersion = true; + } + } + else if (string.Equals(Constants.Headers.SecWebSocketKey, pair.Key, StringComparison.OrdinalIgnoreCase)) + { + validKey = IsRequestKeyValid(pair.Value); + } + } + + return validConnection && validUpgrade && validVersion && validKey; + } + + public static IEnumerable> GenerateResponseHeaders(string key, string subProtocol) + { + yield return new KeyValuePair(Constants.Headers.Connection, Constants.Headers.ConnectionUpgrade); + yield return new KeyValuePair(Constants.Headers.Upgrade, Constants.Headers.UpgradeWebSocket); + yield return new KeyValuePair(Constants.Headers.SecWebSocketAccept, CreateResponseKey(key)); + if (!string.IsNullOrWhiteSpace(subProtocol)) + { + yield return new KeyValuePair(Constants.Headers.SecWebSocketProtocol, subProtocol); + } + } + + /// + /// Validates the Sec-WebSocket-Key request header + /// "The value of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has been base64-encoded." + /// + /// + /// + public static bool IsRequestKeyValid(string value) + { + if (string.IsNullOrWhiteSpace(value)) + { + return false; + } + return value.Length == 24; + } + + /// + /// "...the base64-encoded SHA-1 of the concatenation of the |Sec-WebSocket-Key| (as a string, not base64-decoded) with the string + /// '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'" + /// + /// + /// + public static string CreateResponseKey(string requestKey) + { + if (requestKey == null) + { + throw new ArgumentNullException(nameof(requestKey)); + } + + using (var algorithm = SHA1.Create()) + { + string merged = requestKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + byte[] mergedBytes = Encoding.UTF8.GetBytes(merged); + byte[] hashedBytes = algorithm.ComputeHash(mergedBytes); + return Convert.ToBase64String(hashedBytes); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/IHttpWebSocketConnectionFeature.cs b/src/Microsoft.AspNetCore.WebSockets.Internal/IHttpWebSocketConnectionFeature.cs new file mode 100644 index 0000000000..80dbaddcec --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/IHttpWebSocketConnectionFeature.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.WebSockets.Internal; + +namespace Microsoft.AspNetCore.WebSockets.Internal +{ + public interface IHttpWebSocketConnectionFeature + { + bool IsWebSocketRequest { get; } + ValueTask AcceptWebSocketConnectionAsync(WebSocketAcceptContext context); + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/Microsoft.AspNetCore.WebSockets.Internal.xproj b/src/Microsoft.AspNetCore.WebSockets.Internal/Microsoft.AspNetCore.WebSockets.Internal.xproj new file mode 100644 index 0000000000..5678307e24 --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/Microsoft.AspNetCore.WebSockets.Internal.xproj @@ -0,0 +1,17 @@ + + + + 14.0 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + + fffe71f8-e476-4bcd-9689-f106ee1c1497 + + + + 2.0 + + + diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/Properties/AssemblyInfo.cs b/src/Microsoft.AspNetCore.WebSockets.Internal/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..a5b133fdca --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/Properties/AssemblyInfo.cs @@ -0,0 +1,19 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNetCore.WebSockets.Internal")] +[assembly: AssemblyTrademark("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("fffe71f8-e476-4bcd-9689-f106ee1c1497")] diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketAppBuilderExtensions.cs b/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketAppBuilderExtensions.cs new file mode 100644 index 0000000000..96ec2b3560 --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketAppBuilderExtensions.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Channels; +using Microsoft.AspNetCore.WebSockets.Internal; + +namespace Microsoft.AspNetCore.Builder +{ + public static class WebSocketAppBuilderExtensions + { + public static void UseWebSocketConnections(this IApplicationBuilder self) + { + // Only the GC can clean up this channel factory :( + self.UseWebSocketConnections(new ChannelFactory(), new WebSocketConnectionOptions()); + } + + public static void UseWebSocketConnections(this IApplicationBuilder self, ChannelFactory channelFactory) + { + if (channelFactory == null) + { + throw new ArgumentNullException(nameof(channelFactory)); + } + self.UseWebSocketConnections(channelFactory, new WebSocketConnectionOptions()); + } + + public static void UseWebSocketConnections(this IApplicationBuilder self, ChannelFactory channelFactory, WebSocketConnectionOptions options) + { + if (channelFactory == null) + { + throw new ArgumentNullException(nameof(channelFactory)); + } + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + self.UseWebSocketConnections(channelFactory, options); + self.UseMiddleware(channelFactory, options); + } + } +} diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionFeature.cs b/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionFeature.cs new file mode 100644 index 0000000000..83e2119da4 --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionFeature.cs @@ -0,0 +1,79 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Channels; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.WebSockets.Internal; + +namespace Microsoft.AspNetCore.WebSockets.Internal +{ + internal class WebSocketConnectionFeature : IHttpWebSocketConnectionFeature + { + private HttpContext _context; + private IHttpUpgradeFeature _upgradeFeature; + private ILogger _logger; + private readonly ChannelFactory _channelFactory; + + public bool IsWebSocketRequest + { + get + { + if (!_upgradeFeature.IsUpgradableRequest) + { + return false; + } + return HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request); + } + } + + public WebSocketConnectionFeature(HttpContext context, ChannelFactory channelFactory, IHttpUpgradeFeature upgradeFeature, ILoggerFactory loggerFactory) + { + _channelFactory = channelFactory; + _context = context; + _upgradeFeature = upgradeFeature; + _logger = loggerFactory.CreateLogger(); + } + + public ValueTask AcceptWebSocketConnectionAsync(WebSocketAcceptContext acceptContext) + { + if (!IsWebSocketRequest) + { + throw new InvalidOperationException("Not a WebSocket request."); // TODO: LOC + } + + string subProtocol = null; + if (acceptContext != null) + { + subProtocol = acceptContext.SubProtocol; + } + + _logger.LogDebug("WebSocket Handshake completed. SubProtocol: {0}", subProtocol); + + var key = string.Join(", ", _context.Request.Headers[Constants.Headers.SecWebSocketKey]); + + var responseHeaders = HandshakeHelpers.GenerateResponseHeaders(key, subProtocol); + foreach (var headerPair in responseHeaders) + { + _context.Response.Headers[headerPair.Key] = headerPair.Value; + } + + // TODO: Avoid task allocation if there's a ValueTask-based UpgradeAsync? + return new ValueTask(AcceptWebSocketConnectionCoreAsync(subProtocol)); + } + + private async Task AcceptWebSocketConnectionCoreAsync(string subProtocol) + { + _logger.LogDebug("Upgrading connection to WebSockets"); + var opaqueTransport = await _upgradeFeature.UpgradeAsync(); + var connection = new WebSocketConnection( + opaqueTransport.AsReadableChannel(), + _channelFactory.MakeWriteableChannel(opaqueTransport), + subProtocol: subProtocol); + return connection; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionMiddleware.cs b/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionMiddleware.cs new file mode 100644 index 0000000000..e416e25071 --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionMiddleware.cs @@ -0,0 +1,59 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Channels; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.WebSockets.Internal +{ + public class WebSocketConnectionMiddleware + { + private readonly ChannelFactory _channelFactory; + private readonly ILoggerFactory _loggerFactory; + private readonly RequestDelegate _next; + private readonly WebSocketConnectionOptions _options; + + public WebSocketConnectionMiddleware(RequestDelegate next, ChannelFactory channelFactory, WebSocketConnectionOptions options, ILoggerFactory loggerFactory) + { + if (next == null) + { + throw new ArgumentNullException(nameof(next)); + } + if (channelFactory == null) + { + throw new ArgumentNullException(nameof(channelFactory)); + } + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + + _next = next; + _loggerFactory = loggerFactory; + _channelFactory = channelFactory; + _options = options; + } + + public Task Invoke(HttpContext context) + { + var upgradeFeature = context.Features.Get(); + if (upgradeFeature != null) + { + if (_options.ReplaceFeature || context.Features.Get() == null) + { + context.Features.Set(new WebSocketConnectionFeature(context, _channelFactory, upgradeFeature, _loggerFactory)); + } + } + + return _next(context); + } + } +} diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionOptions.cs b/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionOptions.cs new file mode 100644 index 0000000000..3ffa8337ae --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/WebSocketConnectionOptions.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.WebSockets.Internal +{ + public class WebSocketConnectionOptions + { + public bool ReplaceFeature { get; set; } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.WebSockets.Internal/project.json b/src/Microsoft.AspNetCore.WebSockets.Internal/project.json new file mode 100644 index 0000000000..648fc40ba5 --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets.Internal/project.json @@ -0,0 +1,19 @@ +{ + "version": "0.1.0-*", + "buildOptions": { + "warningsAsErrors": true + }, + "description": "Low-allocation Push-oriented WebSockets Middleware based on Channels", + "dependencies": { + "Microsoft.AspNetCore.Http.Abstractions": "1.1.0-*", + "Microsoft.Extensions.Logging.Abstractions": "1.1.0-*", + "Microsoft.Extensions.WebSockets.Internal": "0.1.0-*", + "NETStandard.Library": "1.6.1-*" + }, + + "frameworks": { + "net46": {}, + "netstandard1.3": { + } + } +} diff --git a/src/Microsoft.Extensions.WebSockets.Internal/ChannelExtensions.cs b/src/Microsoft.Extensions.WebSockets.Internal/ChannelExtensions.cs index 8fa61fe247..9db07dbd9f 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/ChannelExtensions.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/ChannelExtensions.cs @@ -1,4 +1,7 @@ -using System.Threading; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading; using System.Threading.Tasks; using Channels; @@ -9,7 +12,7 @@ namespace Microsoft.Extensions.WebSockets.Internal public static ValueTask ReadAtLeastAsync(this IReadableChannel input, int minimumRequiredBytes) => ReadAtLeastAsync(input, minimumRequiredBytes, CancellationToken.None); // TODO: Pull this up to Channels. We should be able to do it there without allocating a Task in any case (rather than here where we can avoid allocation - // only if the buffer is already ready and has enough data. + // only if the buffer is already ready and has enough data) public static ValueTask ReadAtLeastAsync(this IReadableChannel input, int minimumRequiredBytes, CancellationToken cancellationToken) { var awaiter = input.ReadAsync(/* cancellationToken */); diff --git a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs index 0befc266e9..dfb65595d9 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using System.Threading; using System.Threading.Tasks; @@ -22,6 +25,14 @@ namespace Microsoft.Extensions.WebSockets.Internal /// public interface IWebSocketConnection : IDisposable { + /// + /// Gets the sub-protocol value configured during handshaking. + /// + string SubProtocol { get; } + + /// + /// Gets the current state of the connection + /// WebSocketConnectionState State { get; } /// @@ -45,7 +56,8 @@ namespace Microsoft.Extensions.WebSockets.Internal Task CloseAsync(WebSocketCloseResult result, CancellationToken cancellationToken); /// - /// Runs the WebSocket receive loop, using the provided message handler. + /// Runs the WebSocket receive loop, using the provided message handler. Note that and + /// frames will be passed to this handler for tracking/logging/monitoring, BUT will automatically be handled. /// /// The callback that will be invoked for each new frame /// A state parameter that will be passed to each invocation of @@ -65,7 +77,39 @@ namespace Microsoft.Extensions.WebSockets.Internal /// /// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame. /// - /// A with the payload for the close frame + /// A value to be sent to the client in the close frame. + /// A that completes when the close frame has been sent + public static Task CloseAsync(this IWebSocketConnection self, WebSocketCloseStatus status) => self.CloseAsync(new WebSocketCloseResult(status), CancellationToken.None); + + /// + /// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame. + /// + /// A value to be sent to the client in the close frame. + /// A textual description of the reason for closing the connection. + /// A that completes when the close frame has been sent + public static Task CloseAsync(this IWebSocketConnection self, WebSocketCloseStatus status, string description) => self.CloseAsync(new WebSocketCloseResult(status, description), CancellationToken.None); + + /// + /// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame. + /// + /// A value to be sent to the client in the close frame. + /// A that indicates when/if the send is cancelled. + /// A that completes when the close frame has been sent + public static Task CloseAsync(this IWebSocketConnection self, WebSocketCloseStatus status, CancellationToken cancellationToken) => self.CloseAsync(new WebSocketCloseResult(status), cancellationToken); + + /// + /// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame. + /// + /// A value to be sent to the client in the close frame. + /// A textual description of the reason for closing the connection. + /// A that indicates when/if the send is cancelled. + /// A that completes when the close frame has been sent + public static Task CloseAsync(this IWebSocketConnection self, WebSocketCloseStatus status, string description, CancellationToken cancellationToken) => self.CloseAsync(new WebSocketCloseResult(status, description), cancellationToken); + + /// + /// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame. + /// + /// A with the payload for the close frame. /// A that completes when the close frame has been sent public static Task CloseAsync(this IWebSocketConnection self, WebSocketCloseResult result) => self.CloseAsync(result, CancellationToken.None); diff --git a/src/Microsoft.Extensions.WebSockets.Internal/MaskingUtilities.cs b/src/Microsoft.Extensions.WebSockets.Internal/MaskingUtilities.cs index fea43ff40b..ad49d497a4 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/MaskingUtilities.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/MaskingUtilities.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using System.Binary; using Channels; @@ -30,7 +33,6 @@ namespace Microsoft.Extensions.WebSockets.Internal { var span = mem.Span; ApplyMask(span, maskingKey, ref offset); - offset += span.Length; } } diff --git a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.xproj b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.xproj index f2985d72e2..74f2bc05f2 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.xproj +++ b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.xproj @@ -4,18 +4,12 @@ 14.0 $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) - 5d9da986-2eab-4c6d-bf15-9a4bdd4de775 - Microsoft.Extensions.WebSockets - .\obj - .\bin\ - v4.6.1 - 2.0 - + \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets.Internal/Utf8Validator.cs b/src/Microsoft.Extensions.WebSockets.Internal/Utf8Validator.cs new file mode 100644 index 0000000000..9265db660d --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets.Internal/Utf8Validator.cs @@ -0,0 +1,140 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Channels; + +namespace Microsoft.Extensions.WebSockets.Internal +{ + /// + /// Stateful UTF-8 validator. + /// + public class Utf8Validator + { + // Table of UTF-8 code point widths. '0' indicates an invalid first byte. + private static readonly byte[] _utf8Width = new byte[256] + { + /* 0x00 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x0F */ + /* 0x10 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x1F */ + /* 0x20 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x2F */ + /* 0x30 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x3F */ + /* 0x40 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x4F */ + /* 0x50 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x5F */ + /* 0x60 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x6F */ + /* 0x70 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x7F */ + /* 0x80 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0x8F */ + /* 0x90 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0x9F */ + /* 0xA0 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0xAF */ + /* 0xB0 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0xBF */ + /* 0xC0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, /* 0xCF */ + /* 0xD0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, /* 0xDF */ + /* 0xE0 */ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, /* 0xEF */ + /* 0xF0 */ 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, /* 0xFF */ + }; + + // Table of masks used to extract the code point bits from the first byte. Indexed by (width - 1) + private static readonly byte[] _utf8Mask = new byte[4] { 0x7F, 0x1F, 0x0F, 0x07 }; + + // Table of minimum valid code-points based on the width. Indexed by (width - 1) + private static readonly int[] _utf8Min = new int[4] { 0x00000, 0x00080, 0x00800, 0x10000 }; + + private struct Utf8ValidatorState + { + public bool _withinSequence; + public int _remainingBytesInChar; + public int _currentDecodedValue; + public int _minCodePoint; + + public void Reset() + { + _withinSequence = false; + _remainingBytesInChar = 0; + _currentDecodedValue = 0; + _minCodePoint = 0; + } + } + + private Utf8ValidatorState _state; + + public void Reset() + { + _state.Reset(); + } + + public bool ValidateUtf8Frame(ReadableBuffer payload, bool fin) => ValidateUtf8(ref _state, payload, fin); + + public static bool ValidateUtf8(ReadableBuffer payload) + { + var state = new Utf8ValidatorState(); + return ValidateUtf8(ref state, payload, fin: true); + } + + private static bool ValidateUtf8(ref Utf8ValidatorState state, ReadableBuffer payload, bool fin) + { + // Walk through the payload verifying it + var offset = 0; + foreach (var mem in payload) + { + var span = mem.Span; + for (int i = 0; i < span.Length; i++) + { + var b = span[i]; + if (!state._withinSequence) + { + // This is the first byte of a char, so set things up + var width = _utf8Width[b]; + state._remainingBytesInChar = width - 1; + if (state._remainingBytesInChar < 0) + { + // Invalid first byte + return false; + } + + // Use the width (-1) to index into the mask and min tables. + state._currentDecodedValue = b & _utf8Mask[width - 1]; + state._minCodePoint = _utf8Min[width - 1]; + state._withinSequence = true; + } + else + { + // Add this byte to the value + state._currentDecodedValue = (state._currentDecodedValue << 6) | (b & 0x3F); + state._remainingBytesInChar--; + } + + // Fast invalid exits + if (state._remainingBytesInChar == 1 && state._currentDecodedValue >= 0x360 && state._currentDecodedValue <= 0x37F) + { + // This will be a UTF-16 surrogate: 0xD800-0xDFFF + return false; + } + if (state._remainingBytesInChar == 2 && state._currentDecodedValue >= 0x110) + { + // This will be above the maximum Unicode character (0x10FFFF). + return false; + } + + if (state._remainingBytesInChar == 0) + { + // Check the range of the final decoded value + if (state._currentDecodedValue < state._minCodePoint) + { + // This encoding is longer than it should be, which is not allowed. + return false; + } + + // Reset state + state._withinSequence = false; + } + offset++; + } + } + + // We're done. + // The value is valid if: + // 1. We haven't reached the end of the whole message yet (we'll be caching this state for the next message) + // 2. We aren't inside a character sequence (i.e. the last character isn't unterminated) + return !fin || !state._withinSequence; + } + } +} diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseResult.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseResult.cs index a44dd573d6..248bd993d4 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseResult.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseResult.cs @@ -1,4 +1,7 @@ -using System.Binary; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Binary; using System.Text; using Channels; using Channels.Text.Primitives; @@ -32,22 +35,25 @@ namespace Microsoft.Extensions.WebSockets.Internal public int GetSize() => Encoding.UTF8.GetByteCount(Description) + sizeof(ushort); - public static bool TryParse(ReadableBuffer payload, out WebSocketCloseResult result) + public static bool TryParse(ReadableBuffer payload, out WebSocketCloseResult result, out ushort? actualCloseCode) { if (payload.Length == 0) { // Empty payload is OK + actualCloseCode = null; result = new WebSocketCloseResult(WebSocketCloseStatus.Empty, string.Empty); return true; } else if (payload.Length < 2) { + actualCloseCode = null; result = default(WebSocketCloseResult); return false; } else { var status = payload.ReadBigEndian(); + actualCloseCode = status; var description = string.Empty; payload = payload.Slice(2); if (payload.Length > 0) diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseStatus.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseStatus.cs index 33fd343423..9b99a93d7c 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseStatus.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseStatus.cs @@ -1,4 +1,7 @@ -namespace Microsoft.Extensions.WebSockets.Internal +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.Extensions.WebSockets.Internal { /// /// Represents well-known WebSocket Close frame status codes. diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs index 237bcd2d6c..013288658e 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs @@ -1,7 +1,11 @@ -using System; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using System.Binary; using System.Diagnostics; -using System.Security.Cryptography; +using System.Globalization; +using System.Text; using System.Threading; using System.Threading.Tasks; using Channels; @@ -24,11 +28,20 @@ namespace Microsoft.Extensions.WebSockets.Internal /// public class WebSocketConnection : IWebSocketConnection { - private readonly RandomNumberGenerator _random; - private readonly byte[] _maskingKey; + private WebSocketOptions _options; + private readonly byte[] _maskingKeyBuffer; private readonly IReadableChannel _inbound; private readonly IWritableChannel _outbound; private readonly CancellationTokenSource _terminateReceiveCts = new CancellationTokenSource(); + private readonly Timer _pinger; + private readonly CancellationTokenSource _timerCts = new CancellationTokenSource(); + private Utf8Validator _validator = new Utf8Validator(); + private WebSocketOpcode _currentMessageType = WebSocketOpcode.Continuation; + + // Sends must be serialized between SendAsync, Pinger, and the Close frames sent when invalid messages are received. + private SemaphoreSlim _sendLock = new SemaphoreSlim(1, 1); + + public string SubProtocol { get; } public WebSocketConnectionState State { get; private set; } = WebSocketConnectionState.Created; @@ -37,45 +50,91 @@ namespace Microsoft.Extensions.WebSockets.Internal /// /// A from which frames will be read when receiving. /// A to which frame will be written when sending. - public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound) : this(inbound, outbound, masked: false) { } + public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound) : this(inbound, outbound, options: WebSocketOptions.DefaultUnmasked) { } /// - /// Constructs a new, optionally masked, from an and an that represents an established WebSocket connection (i.e. after handshaking) + /// Constructs a new, unmasked, from an and an that represents an established WebSocket connection (i.e. after handshaking) /// /// A from which frames will be read when receiving. /// A to which frame will be written when sending. - /// A boolean indicating if frames sent from this socket should be masked (the masking key is automatically generated) - public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound, bool masked) + /// The sub-protocol provided during handshaking + public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound, string subProtocol) : this(inbound, outbound, subProtocol, options: WebSocketOptions.DefaultUnmasked) { } + + /// + /// Constructs a new, from an and an that represents an established WebSocket connection (i.e. after handshaking) + /// + /// A from which frames will be read when receiving. + /// A to which frame will be written when sending. + /// A which provides the configuration options for the socket. + public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound, WebSocketOptions options) : this(inbound, outbound, subProtocol: string.Empty, options: options) { } + + /// + /// Constructs a new from an and an that represents an established WebSocket connection (i.e. after handshaking) + /// + /// A from which frames will be read when receiving. + /// A to which frame will be written when sending. + /// The sub-protocol provided during handshaking + /// A which provides the configuration options for the socket. + public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound, string subProtocol, WebSocketOptions options) { _inbound = inbound; _outbound = outbound; + _options = options; + SubProtocol = subProtocol; - if (masked) + if (_options.FixedMaskingKey != null) { - _maskingKey = new byte[4]; - _random = RandomNumberGenerator.Create(); + // Use the fixed key directly as the buffer. + _maskingKeyBuffer = _options.FixedMaskingKey; + + // Clear the MaskingKeyGenerator just to ensure that nobody set it. + _options.MaskingKeyGenerator = null; + } + else if (_options.MaskingKeyGenerator != null) + { + // Establish a buffer for the random generator to use + _maskingKeyBuffer = new byte[4]; + } + + if (_options.PingInterval > TimeSpan.Zero) + { + var pingIntervalMillis = (int)_options.PingInterval.TotalMilliseconds; + // Set up the pinger + _pinger = new Timer(Pinger, this, pingIntervalMillis, pingIntervalMillis); } } - /// - /// Constructs a new, fixed masking-key, from an and an that represents an established WebSocket connection (i.e. after handshaking) - /// - /// A from which frames will be read when receiving. - /// A to which frame will be written when sending. - /// The masking key to use for the connection. Must be exactly 4-bytes long. This is ONLY recommended for testing and development purposes. - public WebSocketConnection(IReadableChannel inbound, IWritableChannel outbound, byte[] fixedMaskingKey) + private static void Pinger(object state) { - _inbound = inbound; - _outbound = outbound; - _maskingKey = fixedMaskingKey; + var connection = (WebSocketConnection)state; + + // If we are cancelled, don't send the ping + // Also, if we can't immediately acquire the send lock, we're already sending something, so we don't need the ping. + if (!connection._timerCts.Token.IsCancellationRequested && connection._sendLock.Wait(0)) + { + // We don't need to wait for this task to complete, we're "tail calling" and + // we are in a Timer thread-pool thread. +#pragma warning disable 4014 + connection.SendCoreLockAcquiredAsync( + fin: true, + opcode: WebSocketOpcode.Ping, + payloadAllocLength: 28, + payloadLength: 28, + payloadWriter: PingPayloadWriter, + payload: DateTime.UtcNow, + cancellationToken: connection._timerCts.Token); +#pragma warning restore 4014 + } } public void Dispose() { State = WebSocketConnectionState.Closed; + _pinger?.Dispose(); + _timerCts.Cancel(); + _terminateReceiveCts.Cancel(); _inbound.Complete(); _outbound.Complete(); - _terminateReceiveCts.Cancel(); } public Task ExecuteAsync(Func messageHandler, object state) @@ -109,7 +168,7 @@ namespace Microsoft.Extensions.WebSockets.Internal // This clause is a bit of an artificial restriction to ensure people run "Execute". Maybe we don't care? else if (State == WebSocketConnectionState.Created) { - throw new InvalidOperationException("Cannot send until the connection is started using Execute"); + throw new InvalidOperationException($"Cannot send until the connection is started using {nameof(ExecuteAsync)}"); } else if (State == WebSocketConnectionState.CloseSent) { @@ -118,9 +177,16 @@ namespace Microsoft.Extensions.WebSockets.Internal if (frame.Opcode == WebSocketOpcode.Close) { - throw new InvalidOperationException("Cannot use SendAsync to send a Close frame, use CloseAsync instead."); + throw new InvalidOperationException($"Cannot use {nameof(SendAsync)} to send a Close frame, use {nameof(CloseAsync)} instead."); } - return SendCoreAsync(frame, null, cancellationToken); + return SendCoreAsync( + fin: frame.EndOfMessage, + opcode: frame.Opcode, + payloadAllocLength: 0, // We don't copy the payload, we append it, so we don't need any alloc for the payload + payloadLength: frame.Payload.Length, + payloadWriter: AppendPayloadWriter, + payload: frame.Payload, + cancellationToken: cancellationToken); } /// @@ -148,10 +214,18 @@ namespace Microsoft.Extensions.WebSockets.Internal throw new InvalidOperationException("Cannot send multiple close frames"); } - // When we pass a close result to SendCoreAsync, the frame is only used for the header and the payload is ignored - var frame = new WebSocketFrame(endOfMessage: true, opcode: WebSocketOpcode.Close, payload: default(ReadableBuffer)); + var payloadSize = result.GetSize(); + await SendCoreAsync( + fin: true, + opcode: WebSocketOpcode.Close, + payloadAllocLength: payloadSize, + payloadLength: payloadSize, + payloadWriter: CloseResultPayloadWriter, + payload: result, + cancellationToken: cancellationToken); - await SendCoreAsync(frame, result, cancellationToken); + _timerCts.Cancel(); + _pinger?.Dispose(); if (State == WebSocketConnectionState.CloseReceived) { @@ -165,15 +239,15 @@ namespace Microsoft.Extensions.WebSockets.Internal private void WriteMaskingKey(Span buffer) { - if (_random != null) + if (_options.MaskingKeyGenerator != null) { // Get a new random mask // Until https://github.com/dotnet/corefx/issues/12323 is fixed we need to use this shared buffer and copy model // Once we have that fix we should be able to generate the mask directly into the output buffer. - _random.GetBytes(_maskingKey); + _options.MaskingKeyGenerator.GetBytes(_maskingKeyBuffer); } - buffer.Set(_maskingKey); + buffer.Set(_maskingKeyBuffer); } private async Task ReceiveLoop(Func messageHandler, object state, CancellationToken cancellationToken) @@ -213,15 +287,27 @@ namespace Microsoft.Extensions.WebSockets.Internal var opcodeByte = buffer.ReadBigEndian(); buffer = buffer.Slice(1); - var fin = (opcodeByte & 0x01) != 0; - var opcode = (WebSocketOpcode)((opcodeByte & 0xF0) >> 4); + var fin = (opcodeByte & 0x80) != 0; + var opcodeNum = opcodeByte & 0x0F; + var opcode = (WebSocketOpcode)opcodeNum; + + if ((opcodeByte & 0x70) != 0) + { + // Reserved bits set, this frame is invalid, close our side and terminate immediately + return await CloseFromProtocolError(cancellationToken, 0, default(ReadableBuffer), "Reserved bits, which are required to be zero, were set."); + } + else if ((opcodeNum >= 0x03 && opcodeNum <= 0x07) || (opcodeNum >= 0x0B && opcodeNum <= 0x0F)) + { + // Reserved opcode + return await CloseFromProtocolError(cancellationToken, 0, default(ReadableBuffer), $"Received frame using reserved opcode: 0x{opcodeNum:X}"); + } // Read the first byte of the payload length var lenByte = buffer.ReadBigEndian(); buffer = buffer.Slice(1); - var masked = (lenByte & 0x01) != 0; - var payloadLen = (lenByte & 0xFE) >> 1; + var masked = (lenByte & 0x80) != 0; + var payloadLen = (lenByte & 0x7F); // Mark what we've got so far as consumed _inbound.Advance(buffer.Start); @@ -234,7 +320,7 @@ namespace Microsoft.Extensions.WebSockets.Internal } else if (payloadLen == 127) { - headerLength += 4; + headerLength += 8; } uint maskingKey = 0; @@ -302,13 +388,104 @@ namespace Microsoft.Extensions.WebSockets.Internal cancellationToken.ThrowIfCancellationRequested(); var frame = new WebSocketFrame(fin, opcode, payload); + + if (frame.Opcode.IsControl() && !frame.EndOfMessage) + { + // Control frames cannot be fragmented. + return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Control frames may not be fragmented"); + } + else if (_currentMessageType != WebSocketOpcode.Continuation && opcode.IsMessage() && opcode != 0) + { + return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Received non-continuation frame during a fragmented message"); + } + else if (_currentMessageType == WebSocketOpcode.Continuation && frame.Opcode == WebSocketOpcode.Continuation) + { + return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Continuation Frame was received when expecting a new message"); + } + if (frame.Opcode == WebSocketOpcode.Close) { - return HandleCloseFrame(payloadLen, payload, frame); + // Allowed frame lengths: + // 0 - No body + // 2 - Code with no reason phrase + // >2 - Code and reason phrase (must be valid UTF-8) + if (frame.Payload.Length > 125) + { + return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Close frame payload too long. Maximum size is 125 bytes"); + } + else if ((frame.Payload.Length == 1) || (frame.Payload.Length > 2 && !Utf8Validator.ValidateUtf8(payload.Slice(2)))) + { + return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Close frame payload invalid"); + } + + ushort? actualStatusCode; + var closeResult = HandleCloseFrame(payload, frame, out actualStatusCode); + + // Verify the close result + if (actualStatusCode != null) + { + var statusCode = actualStatusCode.Value; + if (statusCode < 1000 || statusCode == 1004 || statusCode == 1005 || statusCode == 1006 || (statusCode > 1011 && statusCode < 3000)) + { + return await CloseFromProtocolError(cancellationToken, payloadLen, payload, $"Invalid close status: {statusCode}."); + } + } + + // Make the payload as consumed + if (payloadLen > 0) + { + _inbound.Advance(payload.End); + } + + return closeResult; } else { - await messageHandler(frame, state); + if (frame.Opcode == WebSocketOpcode.Ping) + { + // Check the ping payload length + if (frame.Payload.Length > 125) + { + // Payload too long + return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Ping frame exceeded maximum size of 125 bytes"); + } + + await SendCoreAsync( + frame.EndOfMessage, + WebSocketOpcode.Pong, + payloadAllocLength: 0, + payloadLength: payload.Length, + payloadWriter: AppendPayloadWriter, + payload: payload, + cancellationToken: cancellationToken); + } + var effectiveOpcode = opcode == WebSocketOpcode.Continuation ? _currentMessageType : opcode; + if (effectiveOpcode == WebSocketOpcode.Text && !_validator.ValidateUtf8Frame(frame.Payload, frame.EndOfMessage)) + { + // Drop the frame and immediately close with InvalidPayload + return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "An invalid Text frame payload was received", statusCode: WebSocketCloseStatus.InvalidPayloadData); + } + else if (_options.PassAllFramesThrough || (frame.Opcode != WebSocketOpcode.Ping && frame.Opcode != WebSocketOpcode.Pong)) + { + await messageHandler(frame, state); + } + } + + if (fin) + { + // Reset the UTF8 validator + _validator.Reset(); + + // If it's a non-control frame, reset the message type tracker + if (opcode.IsMessage()) + { + _currentMessageType = WebSocketOpcode.Continuation; + } + } + // If there isn't a current message type, and this was a fragmented message frame, set the current message type + else if (!fin && _currentMessageType == WebSocketOpcode.Continuation && opcode.IsMessage()) + { + _currentMessageType = opcode; } // Mark the payload as consumed @@ -320,7 +497,22 @@ namespace Microsoft.Extensions.WebSockets.Internal return WebSocketCloseResult.AbnormalClosure; } - private WebSocketCloseResult HandleCloseFrame(int payloadLen, ReadableBuffer payload, WebSocketFrame frame) + private async Task CloseFromProtocolError(CancellationToken cancellationToken, int payloadLen, ReadableBuffer payload, string reason, WebSocketCloseStatus statusCode = WebSocketCloseStatus.ProtocolError) + { + // Non-continuation non-control message during fragmented message + if (payloadLen > 0) + { + _inbound.Advance(payload.End); + } + var closeResult = new WebSocketCloseResult( + statusCode, + reason); + await CloseAsync(closeResult, cancellationToken); + Dispose(); + return closeResult; + } + + private WebSocketCloseResult HandleCloseFrame(ReadableBuffer payload, WebSocketFrame frame, out ushort? actualStatusCode) { // Update state if (State == WebSocketConnectionState.CloseSent) @@ -334,24 +526,134 @@ namespace Microsoft.Extensions.WebSockets.Internal // Process the close frame WebSocketCloseResult closeResult; - if (!WebSocketCloseResult.TryParse(frame.Payload, out closeResult)) + if (!WebSocketCloseResult.TryParse(frame.Payload, out closeResult, out actualStatusCode)) { closeResult = WebSocketCloseResult.Empty; } - - // Make the payload as consumed - if (payloadLen > 0) - { - _inbound.Advance(payload.End); - } return closeResult; } - private Task SendCoreAsync(WebSocketFrame message, WebSocketCloseResult? closeResult, CancellationToken cancellationToken) + private static void PingPayloadWriter(WritableBuffer output, Span maskingKey, int payloadLength, DateTime timestamp) + { + var payload = output.Memory.Slice(0, payloadLength); + + // TODO: Don't put this string on the heap? Is there a way to do that without re-implementing ToString? + // Ideally we'd like to render the string directly to the output buffer. + var str = timestamp.ToString("O", CultureInfo.InvariantCulture); + + ArraySegment buffer; + if (payload.TryGetArray(out buffer)) + { + // Fast path - Write the encoded bytes directly out. + Encoding.UTF8.GetBytes(str, 0, str.Length, buffer.Array, buffer.Offset); + } + else + { + // TODO: Could use TryGetPointer, GetBytes does take a byte*, but it seems like just waiting until we have a version that uses Span is best. + // Slow path - Allocate a heap buffer for the encoded bytes before writing them out. + payload.Span.Set(Encoding.UTF8.GetBytes(str)); + } + + if (maskingKey.Length > 0) + { + MaskingUtilities.ApplyMask(payload.Span, maskingKey); + } + + output.Advance(payloadLength); + } + + private static void CloseResultPayloadWriter(WritableBuffer output, Span maskingKey, int payloadLength, WebSocketCloseResult result) + { + // Write the close payload out + var payload = output.Memory.Slice(0, payloadLength).Span; + result.WriteTo(ref output); + + if (maskingKey.Length > 0) + { + MaskingUtilities.ApplyMask(payload, maskingKey); + } + } + + private static void AppendPayloadWriter(WritableBuffer output, Span maskingKey, int payloadLength, ReadableBuffer payload) + { + if (maskingKey.Length > 0) + { + // Mask the payload in it's own buffer + MaskingUtilities.ApplyMask(ref payload, maskingKey); + } + + output.Append(payload); + } + + private Task SendCoreAsync(bool fin, WebSocketOpcode opcode, int payloadAllocLength, int payloadLength, Action, int, T> payloadWriter, T payload, CancellationToken cancellationToken) + { + if (_sendLock.Wait(0)) + { + return SendCoreLockAcquiredAsync(fin, opcode, payloadAllocLength, payloadLength, payloadWriter, payload, cancellationToken); + } + else + { + return SendCoreWaitForLockAsync(fin, opcode, payloadAllocLength, payloadLength, payloadWriter, payload, cancellationToken); + } + } + + private async Task SendCoreWaitForLockAsync(bool fin, WebSocketOpcode opcode, int payloadAllocLength, int payloadLength, Action, int, T> payloadWriter, T payload, CancellationToken cancellationToken) + { + await _sendLock.WaitAsync(cancellationToken); + await SendCoreLockAcquiredAsync(fin, opcode, payloadAllocLength, payloadLength, payloadWriter, payload, cancellationToken); + } + + private async Task SendCoreLockAcquiredAsync(bool fin, WebSocketOpcode opcode, int payloadAllocLength, int payloadLength, Action, int, T> payloadWriter, T payload, CancellationToken cancellationToken) + { + try + { + // Ensure the lock is held + Debug.Assert(_sendLock.CurrentCount == 0); + + // Base header size is 2 bytes. + WritableBuffer buffer; + var allocSize = CalculateAllocSize(payloadAllocLength, payloadLength); + + // Allocate a buffer + buffer = _outbound.Alloc(minimumSize: allocSize); + Debug.Assert(buffer.Memory.Length >= allocSize); + + // Write the opcode and FIN flag + var opcodeByte = (byte)opcode; + if (fin) + { + opcodeByte |= 0x80; + } + buffer.WriteBigEndian(opcodeByte); + + // Write the length and mask flag + WritePayloadLength(payloadLength, buffer); + + var maskingKey = Span.Empty; + if (_maskingKeyBuffer != null) + { + // Get a span of the output buffer for the masking key, write it there, then advance the write head. + maskingKey = buffer.Memory.Slice(0, 4).Span; + WriteMaskingKey(maskingKey); + buffer.Advance(4); + } + + // Write the payload + payloadWriter(buffer, maskingKey, payloadLength, payload); + + // Flush. + await buffer.FlushAsync(); + } + finally + { + // Unlock. + _sendLock.Release(); + } + } + + private int CalculateAllocSize(int payloadAllocLength, int payloadLength) { - // Base header size is 2 bytes. var allocSize = 2; - var payloadLength = closeResult == null ? message.Payload.Length : closeResult.Value.GetSize(); if (payloadLength > ushort.MaxValue) { // We're going to need an 8-byte length @@ -362,46 +664,30 @@ namespace Microsoft.Extensions.WebSockets.Internal // We're going to need a 2-byte length allocSize += 2; } - if (_maskingKey != null) + if (_maskingKeyBuffer != null) { // We need space for the masking key allocSize += 4; } - if (closeResult != null) - { - // We need space for the close result payload too - allocSize += payloadLength; - } - // Allocate a buffer - var buffer = _outbound.Alloc(minimumSize: allocSize); - Debug.Assert(buffer.Memory.Length >= allocSize); - if (buffer.Memory.Length < allocSize) - { - throw new InvalidOperationException("Couldn't allocate enough data from the channel to write the header"); - } + // We may need space for the payload too + return allocSize + payloadAllocLength; + } - // Write the opcode and FIN flag - var opcodeByte = (byte)((int)message.Opcode << 4); - if (message.EndOfMessage) - { - opcodeByte |= 1; - } - buffer.WriteBigEndian(opcodeByte); - - // Write the length and mask flag - var maskingByte = _maskingKey != null ? 0x01 : 0x00; // TODO: Masking flag goes here + private void WritePayloadLength(int payloadLength, WritableBuffer buffer) + { + var maskingByte = _maskingKeyBuffer != null ? 0x80 : 0x00; if (payloadLength > ushort.MaxValue) { - buffer.WriteBigEndian((byte)(0xFE | maskingByte)); + buffer.WriteBigEndian((byte)(0x7F | maskingByte)); // 8-byte length buffer.WriteBigEndian((ulong)payloadLength); } else if (payloadLength > 125) { - buffer.WriteBigEndian((byte)(0xFC | maskingByte)); + buffer.WriteBigEndian((byte)(0x7E | maskingByte)); // 2-byte length buffer.WriteBigEndian((ushort)payloadLength); @@ -409,48 +695,8 @@ namespace Microsoft.Extensions.WebSockets.Internal else { // 1-byte length - buffer.WriteBigEndian((byte)((payloadLength << 1) | maskingByte)); + buffer.WriteBigEndian((byte)(payloadLength | maskingByte)); } - - var maskingKey = Span.Empty; - if (_maskingKey != null) - { - // Get a span of the output buffer for the masking key, write it there, then advance the write head. - maskingKey = buffer.Memory.Slice(0, 4).Span; - WriteMaskingKey(maskingKey); - buffer.Advance(4); - } - - if (closeResult != null) - { - // Write the close payload out - var payload = buffer.Memory.Slice(0, payloadLength).Span; - closeResult.Value.WriteTo(ref buffer); - - if (_maskingKey != null) - { - MaskingUtilities.ApplyMask(payload, maskingKey); - } - } - else - { - // This will copy the actual buffer struct, but NOT the underlying data - // We need a field so we can by-ref it. - var payload = message.Payload; - - if (_maskingKey != null) - { - // Mask the payload in it's own buffer - MaskingUtilities.ApplyMask(ref payload, maskingKey); - } - - // Append the (masked) buffer to the output channel - buffer.Append(payload); - } - - - // Commit and Flush - return buffer.FlushAsync(); } } } diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnectionState.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnectionState.cs index e84cead104..d554c2d295 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnectionState.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnectionState.cs @@ -1,4 +1,7 @@ -namespace Microsoft.Extensions.WebSockets.Internal +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.Extensions.WebSockets.Internal { public enum WebSocketConnectionState { diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketException.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketException.cs index 70cdb7d951..dea59c6fa6 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketException.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketException.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; namespace Microsoft.Extensions.WebSockets.Internal { diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketFrame.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketFrame.cs index adc8721043..48add0aa3f 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketFrame.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketFrame.cs @@ -1,4 +1,7 @@ -using Channels; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Channels; namespace Microsoft.Extensions.WebSockets.Internal { @@ -28,5 +31,18 @@ namespace Microsoft.Extensions.WebSockets.Internal Opcode = opcode; Payload = payload; } + + /// + /// Creates a new containing the same information, but with all buffers + /// copied to new heap memory. + /// + /// + public WebSocketFrame Copy() + { + return new WebSocketFrame( + endOfMessage: EndOfMessage, + opcode: Opcode, + payload: ReadableBuffer.Create(Payload.ToArray())); + } } } \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketOpcode.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketOpcode.cs index 4840bbcb45..1dd7854b07 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketOpcode.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketOpcode.cs @@ -1,4 +1,9 @@ -namespace Microsoft.Extensions.WebSockets.Internal +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +namespace Microsoft.Extensions.WebSockets.Internal { /// /// Represents the possible values for the "opcode" field of a WebSocket frame. @@ -39,4 +44,19 @@ /* all opcodes above 0xF are invalid */ } + + public static class WebSocketOpcodeExtensions + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsControl(this WebSocketOpcode self) + { + return self >= WebSocketOpcode.Close; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool IsMessage(this WebSocketOpcode self) + { + return self < WebSocketOpcode.Close; + } + } } \ No newline at end of file diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketOptions.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketOptions.cs new file mode 100644 index 0000000000..82f9d1c513 --- /dev/null +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketOptions.cs @@ -0,0 +1,140 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Security.Cryptography; + +namespace Microsoft.Extensions.WebSockets.Internal +{ + public class WebSocketOptions + { + /// + /// Gets the default ping interval of 30 seconds. + /// + public static TimeSpan DefaultPingInterval = TimeSpan.FromSeconds(30); + + /// + /// Gets the default for an unmasked sender. + /// + /// + /// Uses the default ping interval defined in , no masking key, + /// and automatically responds to pings. + /// + public static readonly WebSocketOptions DefaultUnmasked = new WebSocketOptions() + { + PingInterval = DefaultPingInterval, + MaskingKeyGenerator = null, + FixedMaskingKey = null + }; + + /// + /// Gets the default for an unmasked sender. + /// + /// + /// Uses the default ping interval defined in , the system random + /// key generator, and automatically responds to pings. + /// + public static readonly WebSocketOptions DefaultMasked = new WebSocketOptions() + { + PingInterval = DefaultPingInterval, + MaskingKeyGenerator = RandomNumberGenerator.Create(), + FixedMaskingKey = null + }; + + /// + /// Gets or sets a boolean indicating if all frames, even those automatically handled ( and frames), + /// should be passed to the callback. NOTE: The frames will STILL be automatically handled, they are + /// only passed along for diagnostic purposes. + /// + public bool PassAllFramesThrough { get; private set; } + + /// + /// Gets or sets the time between pings sent from the local endpoint + /// + public TimeSpan PingInterval { get; private set; } + + /// + /// Gets or sets the used to generate masking keys used to mask outgoing frames. + /// If is set, this value is ignored. If neither this value nor + /// is set, no masking will be performed. + /// + public RandomNumberGenerator MaskingKeyGenerator { get; internal set; } + + /// + /// Gets or sets a fixed masking key used to mask outgoing frames. If this value is set, + /// is ignored. If neither this value nor is set, no masking will be performed. + /// + public byte[] FixedMaskingKey { get; private set; } + + /// + /// Sets the ping interval for this . + /// + /// The interval at which ping frames will be sent + /// A new with the specified ping interval + public WebSocketOptions WithPingInterval(TimeSpan pingInterval) + { + return new WebSocketOptions() + { + PingInterval = pingInterval, + FixedMaskingKey = FixedMaskingKey, + MaskingKeyGenerator = MaskingKeyGenerator + }; + } + + /// + /// Enables frame pass-through in this . Generally for diagnostic or testing purposes only. + /// + /// A new with set to true + public WebSocketOptions WithAllFramesPassedThrough() + { + return new WebSocketOptions() + { + PassAllFramesThrough = true, + PingInterval = PingInterval, + FixedMaskingKey = FixedMaskingKey, + MaskingKeyGenerator = MaskingKeyGenerator + }; + } + + /// + /// Enables random masking in this , using the system random number generator. + /// + /// A new with random masking enabled + public WebSocketOptions WithRandomMasking() => WithRandomMasking(RandomNumberGenerator.Create()); + + /// + /// Enables random masking in this , using the provided random number generator. + /// + /// The to use to generate masking keys + /// A new with random masking enabled + public WebSocketOptions WithRandomMasking(RandomNumberGenerator rng) + { + return new WebSocketOptions() + { + PingInterval = PingInterval, + FixedMaskingKey = null, + MaskingKeyGenerator = rng + }; + } + + /// + /// Enables fixed masking in this . FOR DEVELOPMENT PURPOSES ONLY. + /// + /// The masking key to use for all outgoing frames. + /// A new with fixed masking enabled + public WebSocketOptions WithFixedMaskingKey(byte[] maskingKey) + { + if (maskingKey.Length != 4) + { + throw new ArgumentException("Masking Key must be exactly 4 bytes", nameof(maskingKey)); + } + + return new WebSocketOptions() + { + PingInterval = PingInterval, + FixedMaskingKey = maskingKey, + MaskingKeyGenerator = null + }; + } + } +} diff --git a/src/Microsoft.Extensions.WebSockets.Internal/project.json b/src/Microsoft.Extensions.WebSockets.Internal/project.json index 57180dfa02..b444d29b9a 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/project.json +++ b/src/Microsoft.Extensions.WebSockets.Internal/project.json @@ -5,30 +5,14 @@ "allowUnsafe": true }, "description": "Low-allocation Push-oriented WebSockets based on Channels", - "packOptions": { - "repository": { - "type": "git", - "url": "git://github.com/aspnet/websockets" - } - }, "dependencies": { "Channels": "0.2.0-beta-*", - "Channels.Text.Primitives": "0.2.0-beta-*" + "Channels.Text.Primitives": "0.2.0-beta-*", + "NETStandard.Library": "1.6.0" }, "frameworks": { "net46": {}, - "netstandard1.3": { - "dependencies": { - "System.Collections": "4.0.11", - "System.Diagnostics.Debug": "4.0.11", - "System.IO": "4.1.0", - "System.Linq": "4.1.0", - "System.Runtime": "4.1.0", - "System.Runtime.Extensions": "4.1.0", - "System.Threading": "4.0.11", - "System.Threading.Tasks": "4.0.11" - } - } + "netstandard1.3": {} } } diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnCaseResult.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnCaseResult.cs new file mode 100644 index 0000000000..5152cc871a --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnCaseResult.cs @@ -0,0 +1,33 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public class AutobahnCaseResult + { + public string Name { get; } + public string ActualBehavior { get; } + + public AutobahnCaseResult(string name, string actualBehavior) + { + Name = name; + ActualBehavior = actualBehavior; + } + + public static AutobahnCaseResult FromJson(JProperty prop) + { + var caseObj = (JObject)prop.Value; + var actualBehavior = (string)caseObj["behavior"]; + return new AutobahnCaseResult(prop.Name, actualBehavior); + } + + public bool BehaviorIs(params string[] behaviors) + { + return behaviors.Any(b => string.Equals(b, ActualBehavior, StringComparison.Ordinal)); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnExpectations.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnExpectations.cs new file mode 100644 index 0000000000..f7438d7533 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnExpectations.cs @@ -0,0 +1,89 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Server.IntegrationTesting; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public class AutobahnExpectations + { + private Dictionary _expectations = new Dictionary(); + public bool Ssl { get; } + public ServerType Server { get; } + + public AutobahnExpectations(ServerType server, bool ssl) + { + Server = server; + Ssl = ssl; + } + + public AutobahnExpectations Fail(params string[] caseSpecs) => Expect(Expectation.Fail, caseSpecs); + public AutobahnExpectations NonStrict(params string[] caseSpecs) => Expect(Expectation.NonStrict, caseSpecs); + public AutobahnExpectations OkOrNonStrict(params string[] caseSpecs) => Expect(Expectation.OkOrNonStrict, caseSpecs); + public AutobahnExpectations OkOrFail(params string[] caseSpecs) => Expect(Expectation.OkOrFail, caseSpecs); + + public AutobahnExpectations Expect(Expectation expectation, params string[] caseSpecs) + { + foreach (var caseSpec in caseSpecs) + { + _expectations[caseSpec] = expectation; + } + return this; + } + + internal void Verify(AutobahnServerResult serverResult, StringBuilder failures) + { + foreach (var caseResult in serverResult.Cases) + { + // If this is an informational test result, we can't compare it to anything + if (!string.Equals(caseResult.ActualBehavior, "INFORMATIONAL", StringComparison.Ordinal)) + { + Expectation expectation; + if (!_expectations.TryGetValue(caseResult.Name, out expectation)) + { + expectation = Expectation.Ok; + } + + switch (expectation) + { + case Expectation.Fail: + if (!caseResult.BehaviorIs("FAILED")) + { + failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'FAILED', but got '{caseResult.ActualBehavior}'"); + } + break; + case Expectation.NonStrict: + if (!caseResult.BehaviorIs("NON-STRICT")) + { + failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'NON-STRICT', but got '{caseResult.ActualBehavior}'"); + } + break; + case Expectation.Ok: + if (!caseResult.BehaviorIs("OK")) + { + failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'OK', but got '{caseResult.ActualBehavior}'"); + } + break; + case Expectation.OkOrNonStrict: + if (!caseResult.BehaviorIs("NON-STRICT") && !caseResult.BehaviorIs("OK")) + { + failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'NON-STRICT' or 'OK', but got '{caseResult.ActualBehavior}'"); + } + break; + case Expectation.OkOrFail: + if (!caseResult.BehaviorIs("FAILED") && !caseResult.BehaviorIs("OK")) + { + failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'FAILED' or 'OK', but got '{caseResult.ActualBehavior}'"); + } + break; + default: + break; + } + } + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnResult.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnResult.cs new file mode 100644 index 0000000000..d2533e32c8 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnResult.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Linq; +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public class AutobahnResult + { + public IEnumerable Servers { get; } + + public AutobahnResult(IEnumerable servers) + { + Servers = servers; + } + + public static AutobahnResult FromReportJson(JObject indexJson) + { + // Load the report + return new AutobahnResult(indexJson.Properties().Select(AutobahnServerResult.FromJson)); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnServerResult.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnServerResult.cs new file mode 100644 index 0000000000..ee5f53501c --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnServerResult.cs @@ -0,0 +1,40 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Server.IntegrationTesting; +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public class AutobahnServerResult + { + public ServerType Server { get; } + public bool Ssl { get; } + public string Name { get; } + public IEnumerable Cases { get; } + + public AutobahnServerResult(string name, IEnumerable cases) + { + Name = name; + + var splat = name.Split('|'); + if (splat.Length < 2) + { + throw new FormatException("Results incorrectly formatted"); + } + + Server = (ServerType)Enum.Parse(typeof(ServerType), splat[0]); + Ssl = string.Equals(splat[1], "SSL", StringComparison.Ordinal); + Cases = cases; + } + + public static AutobahnServerResult FromJson(JProperty prop) + { + var valueObj = ((JObject)prop.Value); + return new AutobahnServerResult(prop.Name, valueObj.Properties().Select(AutobahnCaseResult.FromJson)); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnSpec.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnSpec.cs new file mode 100644 index 0000000000..2dbee13723 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnSpec.cs @@ -0,0 +1,62 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public class AutobahnSpec + { + public string OutputDirectory { get; } + public IList Servers { get; } = new List(); + public IList Cases { get; } = new List(); + public IList ExcludedCases { get; } = new List(); + + public AutobahnSpec(string outputDirectory) + { + OutputDirectory = outputDirectory; + } + + public AutobahnSpec WithServer(string name, string url) + { + Servers.Add(new ServerSpec(name, url)); + return this; + } + + public AutobahnSpec IncludeCase(params string[] caseSpecs) + { + foreach (var caseSpec in caseSpecs) + { + Cases.Add(caseSpec); + } + return this; + } + + public AutobahnSpec ExcludeCase(params string[] caseSpecs) + { + foreach (var caseSpec in caseSpecs) + { + ExcludedCases.Add(caseSpec); + } + return this; + } + + public void WriteJson(string file) + { + File.WriteAllText(file, GetJson().ToString(Formatting.Indented)); + } + + public JObject GetJson() => new JObject( + new JProperty("options", new JObject( + new JProperty("failByDrop", false))), + new JProperty("outdir", OutputDirectory), + new JProperty("servers", new JArray(Servers.Select(s => s.GetJson()).ToArray())), + new JProperty("cases", new JArray(Cases.ToArray())), + new JProperty("exclude-cases", new JArray(ExcludedCases.ToArray())), + new JProperty("exclude-agent-cases", new JObject())); + } +} diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnTester.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnTester.cs new file mode 100644 index 0000000000..56ce818d5b --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/AutobahnTester.cs @@ -0,0 +1,152 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.IntegrationTesting; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json.Linq; +using Xunit; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public class AutobahnTester : IDisposable + { + private int _nextPort; + private readonly List _deployers = new List(); + private readonly List _expectations = new List(); + private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; + + public AutobahnSpec Spec { get; } + + public AutobahnTester(ILoggerFactory loggerFactory, AutobahnSpec baseSpec) : this(7000, loggerFactory, baseSpec) { } + + public AutobahnTester(int startPort, ILoggerFactory loggerFactory, AutobahnSpec baseSpec) + { + _nextPort = startPort; + _loggerFactory = loggerFactory; + _logger = _loggerFactory.CreateLogger("AutobahnTester"); + + Spec = baseSpec; + } + + public async Task Run() + { + var specFile = Path.GetTempFileName(); + try + { + Spec.WriteJson(specFile); + + // Run the test (write something to the console so people know this will take a while...) + _logger.LogInformation("Now launching Autobahn Test Suite. This will take a while."); + var exitCode = await Wstest.Default.ExecAsync("-m fuzzingclient -s " + specFile); + if (exitCode != 0) + { + throw new Exception("wstest failed"); + } + } + finally + { + if (File.Exists(specFile)) + { + File.Delete(specFile); + } + } + + // Parse the output. + var outputFile = Path.Combine(Directory.GetCurrentDirectory(), Spec.OutputDirectory, "index.json"); + using (var reader = new StreamReader(File.OpenRead(outputFile))) + { + return AutobahnResult.FromReportJson(JObject.Parse(await reader.ReadToEndAsync())); + } + } + + public void Verify(AutobahnResult result) + { + var failures = new StringBuilder(); + foreach (var serverResult in result.Servers) + { + var serverExpectation = _expectations.FirstOrDefault(e => e.Server == serverResult.Server && e.Ssl == serverResult.Ssl); + if (serverExpectation == null) + { + failures.AppendLine($"Expected no results for server: {serverResult.Name} but found results!"); + } + else + { + serverExpectation.Verify(serverResult, failures); + } + } + + Assert.True(failures.Length == 0, "Autobahn results did not meet expectations:" + Environment.NewLine + failures.ToString()); + } + + public async Task DeployTestAndAddToSpec(ServerType server, bool ssl, Action expectationConfig = null) + { + var port = Interlocked.Increment(ref _nextPort); + var baseUrl = ssl ? $"https://localhost:{port}" : $"http://localhost:{port}"; + var sslNamePart = ssl ? "SSL" : "NoSSL"; + var name = $"{server}|{sslNamePart}"; + var logger = _loggerFactory.CreateLogger($"AutobahnTestApp:{server}:{sslNamePart}"); + + var appPath = Helpers.GetApplicationPath("WebSocketsTestApp"); + var parameters = new DeploymentParameters(appPath, server, RuntimeFlavor.CoreClr, RuntimeArchitecture.x64) + { + ApplicationBaseUriHint = baseUrl, + ApplicationType = ApplicationType.Portable, + TargetFramework = "netcoreapp1.1", + EnvironmentName = "Development" + }; + + var deployer = ApplicationDeployerFactory.Create(parameters, logger); + var result = deployer.Deploy(); + result.HostShutdownToken.ThrowIfCancellationRequested(); + +#if NET451 + System.Net.ServicePointManager.ServerCertificateValidationCallback = (_, __, ___, ____) => true; + var client = new HttpClient(); +#else + var handler = new HttpClientHandler(); + if (ssl) + { + // Don't take this out of the "if(ssl)". If we set it on some platforms, it crashes + // So we avoid running SSL tests on those platforms (for now). + // See https://github.com/dotnet/corefx/issues/9728 + handler.ServerCertificateCustomValidationCallback = (_, __, ___, ____) => true; + } + var client = new HttpClient(handler); +#endif + + // Make sure the server works + var resp = await RetryHelper.RetryRequest(() => + { + return client.GetAsync(result.ApplicationBaseUri); + }, logger, result.HostShutdownToken, retryCount: 5); + resp.EnsureSuccessStatusCode(); + + // Add to the current spec + var wsUrl = result.ApplicationBaseUri.Replace("https://", "wss://").Replace("http://", "ws://"); + Spec.WithServer(name, wsUrl); + + _deployers.Add(deployer); + + var expectations = new AutobahnExpectations(server, ssl); + expectationConfig?.Invoke(expectations); + _expectations.Add(expectations); + } + + public void Dispose() + { + foreach (var deployer in _deployers) + { + deployer.Dispose(); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Executable.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Executable.cs new file mode 100644 index 0000000000..f91240e575 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Executable.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public class Executable + { + private static readonly string _exeSuffix = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? ".exe" : string.Empty; + + private readonly string _path; + + protected Executable(string path) + { + _path = path; + } + + public static string Locate(string name) + { + foreach (var dir in Environment.GetEnvironmentVariable("PATH").Split(Path.PathSeparator)) + { + var candidate = Path.Combine(dir, name + _exeSuffix); + if (File.Exists(candidate)) + { + return candidate; + } + } + return null; + } + + public Task ExecAsync(string args) + { + var process = new Process() + { + StartInfo = new ProcessStartInfo() + { + FileName = _path, + Arguments = args, + UseShellExecute = false, + }, + EnableRaisingEvents = true + }; + var tcs = new TaskCompletionSource(); + + process.Exited += (_, __) => tcs.TrySetResult(process.ExitCode); + + process.Start(); + + return tcs.Task; + } + } +} diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Expectation.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Expectation.cs new file mode 100644 index 0000000000..02e7fe8d2f --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Expectation.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public enum Expectation + { + Fail, + NonStrict, + OkOrFail, + Ok, + OkOrNonStrict + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/ServerSpec.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/ServerSpec.cs new file mode 100644 index 0000000000..ce898bd19b --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/ServerSpec.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Newtonsoft.Json.Linq; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + public class ServerSpec + { + public string Name { get; } + public string Url { get; } + + public ServerSpec(string name, string url) + { + Name = name; + Url = url; + } + + public JObject GetJson() => new JObject( + new JProperty("agent", Name), + new JProperty("url", Url), + new JProperty("options", new JObject( + new JProperty("version", 18)))); + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Wstest.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Wstest.cs new file mode 100644 index 0000000000..34dc03e417 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Autobahn/Wstest.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn +{ + /// + /// Wrapper around the Autobahn Test Suite's "wstest" app. + /// + public class Wstest : Executable + { + private static Lazy _instance = new Lazy(Create); + + public static Wstest Default => _instance.Value; + + public Wstest(string path) : base(path) { } + + private static Wstest Create() + { + var location = Locate("wstest"); + return location == null ? null : new Wstest(location); + } + } +} diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/AutobahnTests.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/AutobahnTests.cs new file mode 100644 index 0000000000..e90384cff9 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/AutobahnTests.cs @@ -0,0 +1,83 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.IntegrationTesting; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.PlatformAbstractions; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest +{ + public class AutobahnTests + { + [ConditionalFact] + [SkipIfWsTestNotPresent] + public async Task AutobahnTestSuite() + { + var reportDir = Environment.GetEnvironmentVariable("AUTOBAHN_SUITES_REPORT_DIR"); + var outDir = !string.IsNullOrEmpty(reportDir) ? + reportDir : + Path.Combine(PlatformServices.Default.Application.ApplicationBasePath, "autobahnreports"); + + if (Directory.Exists(outDir)) + { + Directory.Delete(outDir, recursive: true); + } + + outDir = outDir.Replace("\\", "\\\\"); + + // 9.* is Limits/Performance which is VERY SLOW; 12.*/13.* are compression which we don't implement + var spec = new AutobahnSpec(outDir) + .IncludeCase("*") + .ExcludeCase("9.*", "12.*", "13.*"); + + var loggerFactory = new LoggerFactory(); // No logging by default! It's very loud... + + if (string.Equals(Environment.GetEnvironmentVariable("AUTOBAHN_SUITES_LOG"), "1", StringComparison.Ordinal)) + { + loggerFactory.AddConsole(); + } + + AutobahnResult result; + using (var tester = new AutobahnTester(loggerFactory, spec)) + { + await tester.DeployTestAndAddToSpec(ServerType.Kestrel, ssl: false, expectationConfig: expect => expect + .NonStrict("6.4.3", "6.4.4")); + + result = await tester.Run(); + tester.Verify(result); + } + } + + private bool IsWindows8OrHigher() + { + const string WindowsName = "Microsoft Windows "; + const int VersionOffset = 18; + + if (RuntimeInformation.OSDescription.StartsWith(WindowsName)) + { + var versionStr = RuntimeInformation.OSDescription.Substring(VersionOffset); + Version version; + if (Version.TryParse(versionStr, out version)) + { + return version.Major > 6 || (version.Major == 6 && version.Minor >= 2); + } + } + + return false; + } + + private bool IsIISExpress10Installed() + { + var pf = Environment.GetEnvironmentVariable("PROGRAMFILES"); + var iisExpressExe = Path.Combine(pf, "IIS Express", "iisexpress.exe"); + return File.Exists(iisExpressExe) && FileVersionInfo.GetVersionInfo(iisExpressExe).FileMajorPart >= 10; + } + } +} diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Helpers.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Helpers.cs new file mode 100644 index 0000000000..029e61616b --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Helpers.cs @@ -0,0 +1,32 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using Microsoft.Extensions.PlatformAbstractions; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest +{ + public class Helpers + { + public static string GetApplicationPath(string projectName) + { + var applicationBasePath = PlatformServices.Default.Application.ApplicationBasePath; + + var directoryInfo = new DirectoryInfo(applicationBasePath); + do + { + var solutionFileInfo = new FileInfo(Path.Combine(directoryInfo.FullName, "Microsoft.AspNetCore.Sockets.sln")); + if (solutionFileInfo.Exists) + { + return Path.GetFullPath(Path.Combine(directoryInfo.FullName, "test", projectName)); + } + + directoryInfo = directoryInfo.Parent; + } + while (directoryInfo.Parent != null); + + throw new Exception($"Solution root could not be found using {applicationBasePath}"); + } + } +} diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.xproj b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.xproj new file mode 100644 index 0000000000..a8a97e6a06 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.xproj @@ -0,0 +1,18 @@ + + + + 14.0.25420 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + 8cbc1c71-af0b-44e2-aee9-d8024c07634d + + + 2.0 + + + + + + \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Properties/AssemblyInfo.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..72534b1587 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/Properties/AssemblyInfo.cs @@ -0,0 +1,18 @@ +using System.Reflection; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNetCore.WebSockets.Server.Test")] +[assembly: AssemblyTrademark("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("e82d9f64-8afa-4dcb-a842-2283fda73be8")] diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/SkipIfWsTestNotPresentAttribute.cs b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/SkipIfWsTestNotPresentAttribute.cs new file mode 100644 index 0000000000..09a427b921 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/SkipIfWsTestNotPresentAttribute.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn; + +namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + public class SkipIfWsTestNotPresentAttribute : Attribute, ITestCondition + { + public bool IsMet => Wstest.Default != null; + public string SkipReason => "Autobahn Test Suite is not installed on the host machine."; + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/project.json b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/project.json new file mode 100644 index 0000000000..526a646659 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest/project.json @@ -0,0 +1,23 @@ +{ + "dependencies": { + "dotnet-test-xunit": "2.2.0-*", + "Microsoft.AspNetCore.Server.IntegrationTesting": "0.2.0-*", + "Microsoft.AspNetCore.Testing": "1.1.0-*", + "Microsoft.Extensions.Logging": "1.1.0-*", + "Microsoft.Extensions.Logging.Console": "1.1.0-*", + "Microsoft.Extensions.PlatformAbstractions": "1.1.0-*", + "System.Diagnostics.FileVersionInfo": "4.3.0-*", + "xunit": "2.2.0-*" + }, + "testRunner": "xunit", + "frameworks": { + "netcoreapp1.1": { + "dependencies": { + "Microsoft.NETCore.App": { + "version": "1.1.0-*", + "type": "platform" + } + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/Microsoft.Extensions.WebSockets.Internal.Tests.xproj b/test/Microsoft.Extensions.WebSockets.Internal.Tests/Microsoft.Extensions.WebSockets.Internal.Tests.xproj index 5739b5f12a..5cb69c8b33 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/Microsoft.Extensions.WebSockets.Internal.Tests.xproj +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/Microsoft.Extensions.WebSockets.Internal.Tests.xproj @@ -7,9 +7,6 @@ a7050bae-3db9-4fb3-a49d-303201415b13 - Microsoft.Extensions.WebSockets.Internal.Tests - .\obj - .\bin\ 2.0 diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/TestUtil.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/TestUtil.cs new file mode 100644 index 0000000000..0d934a0d98 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/TestUtil.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.WebSockets.Internal.Tests +{ + internal static class TestUtil + { + private static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(1); + + public static Task OrTimeout(this Task task) => OrTimeout(task, DefaultTimeout); + public static Task OrTimeout(this Task task) => OrTimeout(task, DefaultTimeout); + + public static async Task OrTimeout(this Task task, TimeSpan timeout) + { + var completed = await Task.WhenAny(task, CreateTimeoutTask()); + Assert.Same(completed, task); + } + + public static async Task OrTimeout(this Task task, TimeSpan timeout) + { + var completed = await Task.WhenAny(task, CreateTimeoutTask()); + Assert.Same(task, completed); + return task.Result; + } + + public static Task CreateTimeoutTask() => CreateTimeoutTask(DefaultTimeout); + + public static Task CreateTimeoutTask(TimeSpan timeout) + { + var tcs = new TaskCompletionSource(); + CreateTimeoutToken(timeout).Register(() => tcs.TrySetCanceled()); + return tcs.Task; + } + + public static CancellationToken CreateTimeoutToken() => CreateTimeoutToken(DefaultTimeout); + + public static CancellationToken CreateTimeoutToken(TimeSpan timeout) + { + if (Debugger.IsAttached) + { + return CancellationToken.None; + } + else + { + var cts = new CancellationTokenSource(); + cts.CancelAfter(timeout); + return cts.Token; + } + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/Utf8ValidatorTests.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/Utf8ValidatorTests.cs new file mode 100644 index 0000000000..1aa4a50dcc --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/Utf8ValidatorTests.cs @@ -0,0 +1,134 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq; +using System.Text; +using Channels; +using Xunit; + +namespace Microsoft.Extensions.WebSockets.Internal.Tests +{ + public class Utf8ValidatorTests + { + [Theory] + [InlineData(new byte[] { 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello")] + [InlineData(new byte[] { 0xC2, 0xA7, 0x31, 0x2C, 0x20, 0x39, 0x35, 0xC2, 0xA2 }, "§1, 95¢")] + [InlineData(new byte[] { 0xE0, 0xA0, 0x80, 0xE0, 0xA4, 0x80 }, "\u0800\u0900")] + [InlineData(new byte[] { 0xF0, 0x90, 0x80, 0x80 }, "\U00010000")] + public void ValidSingleFramePayloads(byte[] payload, string decoded) + { + var validator = new Utf8Validator(); + Assert.True(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload), fin: true)); + + // Not really part of the test, but it ensures that the "decoded" string matches the "payload", + // so that the "decoded" string can be used as a human-readable explanation of the string in question + Assert.Equal(decoded, Encoding.UTF8.GetString(payload)); + } + + [Theory] + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x6C, 0x6C, 0x6F }, "Hello")] + + [InlineData(new byte[0], new byte[] { 0xC2, 0xA7 }, "§")] + [InlineData(new byte[] { 0xC2 }, new byte[] { 0xA7 }, "§")] + [InlineData(new byte[] { 0xC2, 0xA7 }, new byte[0], "§")] + + [InlineData(new byte[0], new byte[] { 0xC2, 0xA2 }, "¢")] + [InlineData(new byte[] { 0xC2 }, new byte[] { 0xA2 }, "¢")] + [InlineData(new byte[] { 0xC2, 0xA2 }, new byte[0], "¢")] + + [InlineData(new byte[0], new byte[] { 0xE0, 0xA0, 0x80 }, "\u0800")] + [InlineData(new byte[] { 0xE0 }, new byte[] { 0xA0, 0x80 }, "\u0800")] + [InlineData(new byte[] { 0xE0, 0xA0 }, new byte[] { 0x80 }, "\u0800")] + [InlineData(new byte[] { 0xE0, 0xA0, 0x80 }, new byte[0], "\u0800")] + + [InlineData(new byte[0], new byte[] { 0xE0, 0xA4, 0x80 }, "\u0900")] + [InlineData(new byte[] { 0xE0 }, new byte[] { 0xA4, 0x80 }, "\u0900")] + [InlineData(new byte[] { 0xE0, 0xA4 }, new byte[] { 0x80 }, "\u0900")] + [InlineData(new byte[] { 0xE0, 0xA4, 0x80 }, new byte[0], "\u0900")] + + [InlineData(new byte[0], new byte[] { 0xF0, 0x90, 0x80, 0x80 }, "\U00010000")] + [InlineData(new byte[] { 0xF0 }, new byte[] { 0x90, 0x80, 0x80 }, "\U00010000")] + [InlineData(new byte[] { 0xF0, 0x90 }, new byte[] { 0x80, 0x80 }, "\U00010000")] + [InlineData(new byte[] { 0xF0, 0x90, 0x80 }, new byte[] { 0x80 }, "\U00010000")] + [InlineData(new byte[] { 0xF0, 0x90, 0x80, 0x80 }, new byte[0], "\U00010000")] + public void ValidMultiFramePayloads(byte[] payload1, byte[] payload2, string decoded) + { + var validator = new Utf8Validator(); + Assert.True(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload1), fin: false)); + Assert.True(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload2), fin: true)); + + // Not really part of the test, but it ensures that the "decoded" string matches the "payload", + // so that the "decoded" string can be used as a human-readable explanation of the string in question + Assert.Equal(decoded, Encoding.UTF8.GetString(Enumerable.Concat(payload1, payload2).ToArray())); + } + + [Theory] + + // Continuation byte as first byte of code point + [InlineData(new byte[] { 0x48, 0x65, 0x80, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65, 0x99, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65, 0xAB, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65, 0xB0, 0x6C, 0x6F })] + + // Incomplete Code Point + [InlineData(new byte[] { 0xC2 })] + [InlineData(new byte[] { 0xE0 })] + [InlineData(new byte[] { 0xE0, 0xA0 })] + [InlineData(new byte[] { 0xE0, 0xA4 })] + [InlineData(new byte[] { 0xF0, 0x90, 0x80 })] + + // Overlong Encoding + + // 'H' (1 byte char) encoded with 2, 3 and 4 bytes + [InlineData(new byte[] { 0xC1, 0x88 })] + [InlineData(new byte[] { 0xE0, 0x81, 0x88 })] + [InlineData(new byte[] { 0xF0, 0x80, 0x81, 0x88 })] + + // '§' (2 byte char) encoded with 3 and 4 bytes + [InlineData(new byte[] { 0xE0, 0x82, 0xA7 })] + [InlineData(new byte[] { 0xF0, 0x80, 0x82, 0xA7 })] + + // '\u0800' (3 byte char) encoded with 4 bytes + [InlineData(new byte[] { 0xF0, 0x80, 0xA0, 0x80 })] + public void InvalidSingleFramePayloads(byte[] payload) + { + var validator = new Utf8Validator(); + Assert.False(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload), fin: true)); + } + + [Theory] + + // Continuation byte as first byte of code point + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x80, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x99, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0xAB, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0xB0, 0x6C, 0x6F })] + + // Incomplete Code Point + [InlineData(new byte[] { 0xC2 }, new byte[0])] + [InlineData(new byte[] { 0xE0 }, new byte[0])] + [InlineData(new byte[] { 0xE0, 0xA0 }, new byte[0])] + [InlineData(new byte[] { 0xE0, 0xA4 }, new byte[0])] + [InlineData(new byte[] { 0xF0, 0x90, 0x80 }, new byte[0])] + + // Overlong Encoding + + // 'H' (1 byte char) encoded with 2, 3 and 4 bytes + [InlineData(new byte[] { 0xC1 }, new byte[] { 0x88 })] + [InlineData(new byte[] { 0xE0 }, new byte[] { 0x81, 0x88 })] + [InlineData(new byte[] { 0xF0 }, new byte[] { 0x80, 0x81, 0x88 })] + + // '§' (2 byte char) encoded with 3 and 4 bytes + [InlineData(new byte[] { 0xE0, 0x82 }, new byte[] { 0xA7 })] + [InlineData(new byte[] { 0xF0, 0x80 }, new byte[] { 0x82, 0xA7 })] + + // '\u0800' (3 byte char) encoded with 4 bytes + [InlineData(new byte[] { 0xF0, 0x80 }, new byte[] { 0xA0, 0x80 })] + public void InvalidMultiFramePayloads(byte[] payload1, byte[] payload2) + { + var validator = new Utf8Validator(); + Assert.True(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload1), fin: false)); + Assert.False(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload2), fin: true)); + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionExtensions.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionExtensions.cs index 199a5b6dae..d67930de12 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionExtensions.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionExtensions.cs @@ -1,4 +1,7 @@ -using System.Collections.Generic; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; using System.Threading.Tasks; using Channels; @@ -9,15 +12,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests public static async Task ExecuteAndCaptureFramesAsync(this IWebSocketConnection self) { var frames = new List(); - var closeResult = await self.ExecuteAsync(frame => - { - var buffer = new byte[frame.Payload.Length]; - frame.Payload.CopyTo(buffer); - frames.Add(new WebSocketFrame( - frame.EndOfMessage, - frame.Opcode, - ReadableBuffer.Create(buffer, 0, buffer.Length))); - }); + var closeResult = await self.ExecuteAsync(frame => frames.Add(frame.Copy())); return new WebSocketConnectionSummary(frames, closeResult); } } diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionSummary.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionSummary.cs index a488a51a27..31e0241114 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionSummary.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionSummary.cs @@ -1,4 +1,7 @@ -using System.Collections.Generic; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; namespace Microsoft.Extensions.WebSockets.Internal.Tests { diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs index 2b76655fb4..4eae600cd8 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs @@ -1,25 +1,23 @@ -using System; -using System.Diagnostics; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; using System.Text; -using System.Threading; using System.Threading.Tasks; +using Channels; using Xunit; namespace Microsoft.Extensions.WebSockets.Internal.Tests { public partial class WebSocketConnectionTests { - [Fact] - public async Task SendReceiveFrames() + public class ConnectionLifecycle { - using (var pair = WebSocketPair.Create()) + [Fact] + public async Task SendReceiveFrames() { - var cts = new CancellationTokenSource(); - if (!Debugger.IsAttached) - { - cts.CancelAfter(TimeSpan.FromSeconds(5)); - } - using (cts.Token.Register(() => pair.Dispose())) + using (var pair = WebSocketPair.Create()) { var client = pair.ClientSocket.ExecuteAsync(_ => { @@ -28,103 +26,144 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests }); // Send Frames - await pair.ClientSocket.SendAsync(CreateTextFrame("Hello")); - await pair.ClientSocket.SendAsync(CreateTextFrame("World")); - await pair.ClientSocket.SendAsync(CreateBinaryFrame(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF })); - await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)); + await pair.ClientSocket.SendAsync(CreateTextFrame("Hello")).OrTimeout(); + await pair.ClientSocket.SendAsync(CreateTextFrame("World")).OrTimeout(); + await pair.ClientSocket.SendAsync(CreateBinaryFrame(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF })).OrTimeout(); + await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)).OrTimeout(); - var summary = await pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + var summary = await pair.ServerSocket.ExecuteAndCaptureFramesAsync().OrTimeout(); Assert.Equal(3, summary.Received.Count); Assert.Equal("Hello", Encoding.UTF8.GetString(summary.Received[0].Payload.ToArray())); Assert.Equal("World", Encoding.UTF8.GetString(summary.Received[1].Payload.ToArray())); Assert.Equal(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }, summary.Received[2].Payload.ToArray()); - await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)); - await client; + await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)).OrTimeout(); + await client.OrTimeout(); } } - } - [Fact] - public async Task ExecuteReturnsWhenCloseFrameReceived() - { - using (var pair = WebSocketPair.Create()) + [Fact] + public async Task ExecuteReturnsWhenCloseFrameReceived() { - var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); - await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.InvalidMessageType, "Abc")); - var serverSummary = await pair.ServerSocket.ExecuteAndCaptureFramesAsync(); - await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure, "Ok")); - var clientSummary = await client; + using (var pair = WebSocketPair.Create()) + { + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.InvalidMessageType, "Abc")).OrTimeout(); + var serverSummary = await pair.ServerSocket.ExecuteAndCaptureFramesAsync().OrTimeout(); + await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure, "Ok")).OrTimeout(); + var clientSummary = await client.OrTimeout(); - Assert.Equal(0, serverSummary.Received.Count); - Assert.Equal(WebSocketCloseStatus.InvalidMessageType, serverSummary.CloseResult.Status); - Assert.Equal("Abc", serverSummary.CloseResult.Description); + Assert.Equal(0, serverSummary.Received.Count); + Assert.Equal(WebSocketCloseStatus.InvalidMessageType, serverSummary.CloseResult.Status); + Assert.Equal("Abc", serverSummary.CloseResult.Description); - Assert.Equal(0, clientSummary.Received.Count); - Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status); - Assert.Equal("Ok", clientSummary.CloseResult.Description); + Assert.Equal(0, clientSummary.Received.Count); + Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status); + Assert.Equal("Ok", clientSummary.CloseResult.Description); + } } - } - [Fact] - public async Task AbnormalTerminationOfInboundChannelCausesExecuteToThrow() - { - using (var pair = WebSocketPair.Create()) + [Fact] + public async Task AbnormalTerminationOfInboundChannelCausesExecuteToThrow() { - var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); - var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); - pair.TerminateFromClient(new InvalidOperationException("It broke!")); + using (var pair = WebSocketPair.Create()) + { + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + pair.TerminateFromClient(new InvalidOperationException("It broke!")); - await Assert.ThrowsAsync(() => server); + await Assert.ThrowsAsync(() => server); + } } - } - [Fact] - public async Task StateTransitions() - { - using (var pair = WebSocketPair.Create()) + [Fact] + public async Task StateTransitions() { - // Initial State - Assert.Equal(WebSocketConnectionState.Created, pair.ServerSocket.State); - Assert.Equal(WebSocketConnectionState.Created, pair.ClientSocket.State); + using (var pair = WebSocketPair.Create()) + { + // Initial State + Assert.Equal(WebSocketConnectionState.Created, pair.ServerSocket.State); + Assert.Equal(WebSocketConnectionState.Created, pair.ClientSocket.State); - // Start the sockets - var serverReceiving = new TaskCompletionSource(); - var clientReceiving = new TaskCompletionSource(); - var server = pair.ServerSocket.ExecuteAsync(frame => serverReceiving.TrySetResult(null)); - var client = pair.ClientSocket.ExecuteAsync(frame => clientReceiving.TrySetResult(null)); + // Start the sockets + var serverReceiving = new TaskCompletionSource(); + var clientReceiving = new TaskCompletionSource(); + var server = pair.ServerSocket.ExecuteAsync(frame => serverReceiving.TrySetResult(null)); + var client = pair.ClientSocket.ExecuteAsync(frame => clientReceiving.TrySetResult(null)); - // Send a frame from each and verify that the state transitioned. - // We need to do this because it's the only way to correctly wait for the state transition (which happens asynchronously in ExecuteAsync) - await pair.ClientSocket.SendAsync(CreateTextFrame("Hello")); - await pair.ServerSocket.SendAsync(CreateTextFrame("Hello")); + // Send a frame from each and verify that the state transitioned. + // We need to do this because it's the only way to correctly wait for the state transition (which happens asynchronously in ExecuteAsync) + await pair.ClientSocket.SendAsync(CreateTextFrame("Hello")).OrTimeout(); + await pair.ServerSocket.SendAsync(CreateTextFrame("Hello")).OrTimeout(); - await Task.WhenAll(serverReceiving.Task, clientReceiving.Task); + await Task.WhenAll(serverReceiving.Task, clientReceiving.Task).OrTimeout(); - // Check state - Assert.Equal(WebSocketConnectionState.Connected, pair.ServerSocket.State); - Assert.Equal(WebSocketConnectionState.Connected, pair.ClientSocket.State); + // Check state + Assert.Equal(WebSocketConnectionState.Connected, pair.ServerSocket.State); + Assert.Equal(WebSocketConnectionState.Connected, pair.ClientSocket.State); - // Close the server socket - await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)); - await client; + // Close the server socket + await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)).OrTimeout(); + await client.OrTimeout(); - // Check state - Assert.Equal(WebSocketConnectionState.CloseSent, pair.ServerSocket.State); - Assert.Equal(WebSocketConnectionState.CloseReceived, pair.ClientSocket.State); + // Check state + Assert.Equal(WebSocketConnectionState.CloseSent, pair.ServerSocket.State); + Assert.Equal(WebSocketConnectionState.CloseReceived, pair.ClientSocket.State); - // Close the client socket - await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)); - await server; + // Close the client socket + await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)).OrTimeout(); + await server.OrTimeout(); - // Check state - Assert.Equal(WebSocketConnectionState.Closed, pair.ServerSocket.State); - Assert.Equal(WebSocketConnectionState.Closed, pair.ClientSocket.State); + // Check state + Assert.Equal(WebSocketConnectionState.Closed, pair.ServerSocket.State); + Assert.Equal(WebSocketConnectionState.Closed, pair.ClientSocket.State); - // Verify we can't restart the connection or send a message - await Assert.ThrowsAsync(async () => await pair.ServerSocket.ExecuteAsync(f => { })); - await Assert.ThrowsAsync(async () => await pair.ClientSocket.SendAsync(CreateTextFrame("Nope"))); - await Assert.ThrowsAsync(async () => await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure))); + // Verify we can't restart the connection or send a message + await Assert.ThrowsAsync(async () => await pair.ServerSocket.ExecuteAsync(f => { })); + await Assert.ThrowsAsync(async () => await pair.ClientSocket.SendAsync(CreateTextFrame("Nope"))); + await Assert.ThrowsAsync(async () => await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure))); + } + } + + [Fact] + public async Task CanReceiveControlFrameInTheMiddleOfFragmentedMessage() + { + using (var pair = WebSocketPair.Create()) + { + // Start the sockets + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + + // Send (Fin=false, "Hello"), (Ping), (Fin=true, "World") + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: false, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello")))); + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Ping, + payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("ping")))); + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Continuation, + payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("World")))); + + // Close the socket + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); + var serverSummary = await server; + await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); + var clientSummary = await client; + + // Assert + var nonControlFrames = serverSummary.Received.Where(f => f.Opcode < WebSocketOpcode.Close).ToList(); + Assert.Equal(2, nonControlFrames.Count); + Assert.False(nonControlFrames[0].EndOfMessage); + Assert.True(nonControlFrames[1].EndOfMessage); + Assert.Equal(WebSocketOpcode.Text, nonControlFrames[0].Opcode); + Assert.Equal(WebSocketOpcode.Continuation, nonControlFrames[1].Opcode); + Assert.Equal("Hello", Encoding.UTF8.GetString(nonControlFrames[0].Payload.ToArray())); + Assert.Equal("World", Encoding.UTF8.GetString(nonControlFrames[1].Payload.ToArray())); + } } } } diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.PingPong.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.PingPong.cs new file mode 100644 index 0000000000..7934ea81e7 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.PingPong.cs @@ -0,0 +1,105 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Globalization; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Channels; +using Xunit; + +namespace Microsoft.Extensions.WebSockets.Internal.Tests +{ + public partial class WebSocketConnectionTests + { + public class PingPongBehavior + { + [Fact] + public async Task AutomaticPingTransmission() + { + var startTime = DateTime.UtcNow; + // Arrange + using (var pair = WebSocketPair.Create( + serverOptions: new WebSocketOptions().WithAllFramesPassedThrough().WithPingInterval(TimeSpan.FromMilliseconds(100)), + clientOptions: new WebSocketOptions().WithAllFramesPassedThrough())) + { + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + + // Act + // Wait for pings to be sent + await Task.Delay(200); + + await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout(); + var clientSummary = await client.OrTimeout(); + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout(); + var serverSummary = await server.OrTimeout(); + + // Assert + Assert.NotEqual(0, clientSummary.Received.Count); + + Assert.True(clientSummary.Received.All(f => f.EndOfMessage)); + Assert.True(clientSummary.Received.All(f => f.Opcode == WebSocketOpcode.Ping)); + Assert.True(clientSummary.Received.All(f => + { + var str = Encoding.UTF8.GetString(f.Payload.ToArray()); + + // We can't verify the exact timestamp, but we can verify that it is a timestamp created after we started. + DateTime dt; + if (DateTime.TryParseExact(str, "O", CultureInfo.InvariantCulture, DateTimeStyles.AdjustToUniversal, out dt)) + { + return dt >= startTime; + } + return false; + })); + } + } + + [Fact] + public async Task AutomaticPingResponse() + { + // Arrange + using (var pair = WebSocketPair.Create( + serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(), + clientOptions: new WebSocketOptions().WithAllFramesPassedThrough())) + { + var payload = Encoding.UTF8.GetBytes("ping payload"); + + var pongTcs = new TaskCompletionSource(); + + var client = pair.ClientSocket.ExecuteAsync(f => + { + if (f.Opcode == WebSocketOpcode.Pong) + { + pongTcs.TrySetResult(f.Copy()); + } + else + { + Assert.False(true, "Received non-pong frame from server!"); + } + }); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + + // Act + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Ping, + payload: ReadableBuffer.Create(payload))); + + var pongFrame = await pongTcs.Task.OrTimeout(); + + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout(); + await server.OrTimeout(); + await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout(); + await client.OrTimeout(); + + // Assert + Assert.True(pongFrame.EndOfMessage); + Assert.Equal(WebSocketOpcode.Pong, pongFrame.Opcode); + Assert.Equal(payload, pongFrame.Payload.ToArray()); + } + } + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs new file mode 100644 index 0000000000..555aebb0fd --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs @@ -0,0 +1,248 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Text; +using System.Threading.Tasks; +using Channels; +using Xunit; + +namespace Microsoft.Extensions.WebSockets.Internal.Tests +{ + public partial class WebSocketConnectionTests + { + public class ProtocolErrors + { + [Theory] + [InlineData(new byte[] { 0x11, 0x00 })] + [InlineData(new byte[] { 0x21, 0x00 })] + [InlineData(new byte[] { 0x31, 0x00 })] + [InlineData(new byte[] { 0x41, 0x00 })] + [InlineData(new byte[] { 0x51, 0x00 })] + [InlineData(new byte[] { 0x61, 0x00 })] + [InlineData(new byte[] { 0x71, 0x00 })] + public Task TerminatesConnectionOnReservedBitSet(byte[] rawFrame) + { + return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, "Reserved bits, which are required to be zero, were set."); + } + + [Theory] + [InlineData(0x03)] + [InlineData(0x04)] + [InlineData(0x05)] + [InlineData(0x06)] + [InlineData(0x07)] + [InlineData(0x0B)] + [InlineData(0x0C)] + [InlineData(0x0D)] + [InlineData(0x0E)] + [InlineData(0x0F)] + public Task ReservedOpcodes(byte opcode) + { + var payload = Encoding.UTF8.GetBytes("hello"); + var frame = new WebSocketFrame( + endOfMessage: true, + opcode: (WebSocketOpcode)opcode, + payload: ReadableBuffer.Create(payload)); + return SendFrameAndExpectClose(frame, WebSocketCloseStatus.ProtocolError, $"Received frame using reserved opcode: 0x{opcode:X}"); + } + + [Theory] + [InlineData(new byte[] { 0x88, 0x01, 0xAB })] + + // Invalid UTF-8 reason + [InlineData(new byte[] { 0x88, 0x07, 0x03, 0xE8, 0x48, 0x65, 0x80, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x88, 0x07, 0x03, 0xE8, 0x48, 0x65, 0x99, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x88, 0x07, 0x03, 0xE8, 0x48, 0x65, 0xAB, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x88, 0x07, 0x03, 0xE8, 0x48, 0x65, 0xB0, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x88, 0x03, 0x03, 0xE8, 0xC2 })] + [InlineData(new byte[] { 0x88, 0x03, 0x03, 0xE8, 0xE0 })] + [InlineData(new byte[] { 0x88, 0x04, 0x03, 0xE8, 0xE0, 0xA0 })] + [InlineData(new byte[] { 0x88, 0x04, 0x03, 0xE8, 0xE0, 0xA4 })] + [InlineData(new byte[] { 0x88, 0x05, 0x03, 0xE8, 0xF0, 0x90, 0x80 })] + [InlineData(new byte[] { 0x88, 0x04, 0x03, 0xE8, 0xC1, 0x88 })] + [InlineData(new byte[] { 0x88, 0x05, 0x03, 0xE8, 0xE0, 0x81, 0x88 })] + [InlineData(new byte[] { 0x88, 0x06, 0x03, 0xE8, 0xF0, 0x80, 0x81, 0x88 })] + [InlineData(new byte[] { 0x88, 0x05, 0x03, 0xE8, 0xE0, 0x82, 0xA7 })] + [InlineData(new byte[] { 0x88, 0x06, 0x03, 0xE8, 0xF0, 0x80, 0x82, 0xA7 })] + [InlineData(new byte[] { 0x88, 0x06, 0x03, 0xE8, 0xF0, 0x80, 0xA0, 0x80 })] + public Task InvalidCloseFrames(byte[] rawFrame) + { + return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, "Close frame payload invalid"); + } + + [Fact] + public Task CloseFrameTooLong() + { + var rawFrame = new byte[256]; + new Random().NextBytes(rawFrame); + + // Put a WebSocket frame header in front + rawFrame[0] = 0x88; // Close frame, FIN=true + rawFrame[1] = 0x7E; // Mask=false, LEN=126 + rawFrame[2] = 0x00; // Extended Len = 252 (256 - 4 bytes for header) + rawFrame[3] = 0xFC; + + return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, "Close frame payload too long. Maximum size is 125 bytes"); + } + + [Theory] + // 0-999 reserved + [InlineData(0)] + [InlineData(999)] + // Specifically reserved status codes, or codes that should not be sent in frames. + [InlineData(1004)] + [InlineData(1005)] + [InlineData(1006)] + [InlineData(1012)] + [InlineData(1013)] + [InlineData(1014)] + [InlineData(1015)] + // Undefined status codes + [InlineData(1016)] + [InlineData(1100)] + [InlineData(2000)] + [InlineData(2999)] + public Task InvalidCloseStatuses(ushort status) + { + var rawFrame = new byte[] { 0x88, 0x02, (byte)(status >> 8), (byte)(status) }; + return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, $"Invalid close status: {status}."); + } + + [Theory] + [InlineData(new byte[] { 0x08, 0x00 })] + [InlineData(new byte[] { 0x09, 0x00 })] + [InlineData(new byte[] { 0x0A, 0x00 })] + public Task TerminatesConnectionOnFragmentedControlFrame(byte[] rawFrame) + { + return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, "Control frames may not be fragmented"); + } + + [Fact] + public async Task TerminatesConnectionOnNonContinuationFrameFollowingFragmentedMessageStart() + { + // Arrange + using (var pair = WebSocketPair.Create( + serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(), + clientOptions: new WebSocketOptions().WithAllFramesPassedThrough())) + { + var payload = Encoding.UTF8.GetBytes("hello"); + + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + + // Act + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: false, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(payload))); + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(payload))); + + // Server should terminate + var clientSummary = await client.OrTimeout(); + + Assert.Equal(WebSocketCloseStatus.ProtocolError, clientSummary.CloseResult.Status); + Assert.Equal("Received non-continuation frame during a fragmented message", clientSummary.CloseResult.Description); + + await server.OrTimeout(); + } + } + + [Fact] + public async Task TerminatesConnectionOnUnsolicitedContinuationFrame() + { + // Arrange + using (var pair = WebSocketPair.Create( + serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(), + clientOptions: new WebSocketOptions().WithAllFramesPassedThrough())) + { + var payload = Encoding.UTF8.GetBytes("hello"); + + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + + // Act + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(payload))); + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Continuation, + payload: ReadableBuffer.Create(payload))); + + // Server should terminate + var clientSummary = await client.OrTimeout(); + + Assert.Equal(WebSocketCloseStatus.ProtocolError, clientSummary.CloseResult.Status); + Assert.Equal("Continuation Frame was received when expecting a new message", clientSummary.CloseResult.Description); + + await server.OrTimeout(); + } + } + + [Fact] + public Task TerminatesConnectionOnPingFrameLargerThan125Bytes() + { + var payload = new byte[126]; + new Random().NextBytes(payload); + return SendFrameAndExpectClose( + new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Ping, + payload: ReadableBuffer.Create(payload)), + WebSocketCloseStatus.ProtocolError, + "Ping frame exceeded maximum size of 125 bytes"); + } + + private static async Task SendFrameAndExpectClose(WebSocketFrame frame, WebSocketCloseStatus closeStatus, string closeReason) + { + // Arrange + using (var pair = WebSocketPair.Create( + serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(), + clientOptions: new WebSocketOptions().WithAllFramesPassedThrough())) + { + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + + // Act + await pair.ClientSocket.SendAsync(frame); + + // Server should terminate + var clientSummary = await client.OrTimeout(); + + Assert.Equal(closeStatus, clientSummary.CloseResult.Status); + Assert.Equal(closeReason, clientSummary.CloseResult.Description); + + await server.OrTimeout(); + } + } + + private static async Task WriteFrameAndExpectClose(byte[] rawFrame, WebSocketCloseStatus closeStatus, string closeReason) + { + // Arrange + using (var pair = WebSocketPair.Create( + serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(), + clientOptions: new WebSocketOptions().WithAllFramesPassedThrough())) + { + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + + // Act + await pair.ClientToServer.WriteAsync(rawFrame); + + // Server should terminate + var clientSummary = await client.OrTimeout(); + + Assert.Equal(closeStatus, clientSummary.CloseResult.Status); + Assert.Equal(closeReason, clientSummary.CloseResult.Description); + + await server.OrTimeout(); + } + } + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs index cb9fa5f599..888ac53260 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs @@ -1,5 +1,7 @@ -using System; -using System.Diagnostics; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -13,12 +15,12 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests public class TheReceiveAsyncMethod { [Theory] - [InlineData(new byte[] { 0x11, 0x00 }, "", true)] - [InlineData(new byte[] { 0x11, 0x0A, 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello", true)] - [InlineData(new byte[] { 0x11, 0x0B, 0x1, 0x2, 0x3, 0x4, 0x48 ^ 0x1, 0x65 ^ 0x2, 0x6C ^ 0x3, 0x6C ^ 0x4, 0x6F ^ 0x1 }, "Hello", true)] - [InlineData(new byte[] { 0x10, 0x00 }, "", false)] - [InlineData(new byte[] { 0x10, 0x0A, 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello", false)] - [InlineData(new byte[] { 0x10, 0x0B, 0x1, 0x2, 0x3, 0x4, 0x48 ^ 0x1, 0x65 ^ 0x2, 0x6C ^ 0x3, 0x6C ^ 0x4, 0x6F ^ 0x1 }, "Hello", false)] + [InlineData(new byte[] { 0x81, 0x00 }, "", true)] + [InlineData(new byte[] { 0x81, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello", true)] + [InlineData(new byte[] { 0x81, 0x85, 0x1, 0x2, 0x3, 0x4, 0x48 ^ 0x1, 0x65 ^ 0x2, 0x6C ^ 0x3, 0x6C ^ 0x4, 0x6F ^ 0x1 }, "Hello", true)] + [InlineData(new byte[] { 0x01, 0x00 }, "", false)] + [InlineData(new byte[] { 0x01, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello", false)] + [InlineData(new byte[] { 0x01, 0x85, 0x1, 0x2, 0x3, 0x4, 0x48 ^ 0x1, 0x65 ^ 0x2, 0x6C ^ 0x3, 0x6C ^ 0x4, 0x6F ^ 0x1 }, "Hello", false)] public Task ReadTextFrames(byte[] rawFrame, string message, bool endOfMessage) { return RunSingleFrameTest( @@ -30,36 +32,24 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests [Theory] // Opcode = Binary - [InlineData(new byte[] { 0x21, 0x00 }, new byte[0], WebSocketOpcode.Binary, true)] - [InlineData(new byte[] { 0x21, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, true)] - [InlineData(new byte[] { 0x21, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, true)] - [InlineData(new byte[] { 0x20, 0x00 }, new byte[0], WebSocketOpcode.Binary, false)] - [InlineData(new byte[] { 0x20, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, false)] - [InlineData(new byte[] { 0x20, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, false)] - - // Opcode = Continuation - [InlineData(new byte[] { 0x01, 0x00 }, new byte[0], WebSocketOpcode.Continuation, true)] - [InlineData(new byte[] { 0x01, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Continuation, true)] - [InlineData(new byte[] { 0x01, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Continuation, true)] - [InlineData(new byte[] { 0x00, 0x00 }, new byte[0], WebSocketOpcode.Continuation, false)] - [InlineData(new byte[] { 0x00, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Continuation, false)] - [InlineData(new byte[] { 0x00, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Continuation, false)] + [InlineData(new byte[] { 0x82, 0x00 }, new byte[0], WebSocketOpcode.Binary, true)] + [InlineData(new byte[] { 0x82, 0x05, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, true)] + [InlineData(new byte[] { 0x82, 0x85, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, true)] + [InlineData(new byte[] { 0x02, 0x00 }, new byte[0], WebSocketOpcode.Binary, false)] + [InlineData(new byte[] { 0x02, 0x05, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, false)] + [InlineData(new byte[] { 0x02, 0x85, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, false)] // Opcode = Ping - [InlineData(new byte[] { 0x91, 0x00 }, new byte[0], WebSocketOpcode.Ping, true)] - [InlineData(new byte[] { 0x91, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, true)] - [InlineData(new byte[] { 0x91, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, true)] - [InlineData(new byte[] { 0x90, 0x00 }, new byte[0], WebSocketOpcode.Ping, false)] - [InlineData(new byte[] { 0x90, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, false)] - [InlineData(new byte[] { 0x90, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, false)] + [InlineData(new byte[] { 0x89, 0x00 }, new byte[0], WebSocketOpcode.Ping, true)] + [InlineData(new byte[] { 0x89, 0x05, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, true)] + [InlineData(new byte[] { 0x89, 0x85, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, true)] + // Control frames can't have fin=false // Opcode = Pong - [InlineData(new byte[] { 0xA1, 0x00 }, new byte[0], WebSocketOpcode.Pong, true)] - [InlineData(new byte[] { 0xA1, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, true)] - [InlineData(new byte[] { 0xA1, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, true)] - [InlineData(new byte[] { 0xA0, 0x00 }, new byte[0], WebSocketOpcode.Pong, false)] - [InlineData(new byte[] { 0xA0, 0x0A, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, false)] - [InlineData(new byte[] { 0xA0, 0x0B, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, false)] + [InlineData(new byte[] { 0x8A, 0x00 }, new byte[0], WebSocketOpcode.Pong, true)] + [InlineData(new byte[] { 0x8A, 0x05, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, true)] + [InlineData(new byte[] { 0x8A, 0x85, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, true)] + // Control frames can't have fin=false public Task ReadBinaryFormattedFrames(byte[] rawFrame, byte[] payload, WebSocketOpcode opcode, bool endOfMessage) { return RunSingleFrameTest( @@ -75,10 +65,14 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests var result = await RunReceiveTest( producer: async (channel, cancellationToken) => { - await channel.WriteAsync(new byte[] { 0x20, 0x0A }.Slice()); - await channel.WriteAsync(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB, 0x01, 0x0A }.Slice()); - await channel.WriteAsync(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }.Slice()); - await channel.WriteAsync(new byte[] { 0xAB }.Slice()); + await channel.WriteAsync(new byte[] { 0x02, 0x05 }.Slice()).OrTimeout(); + await Task.Yield(); + await channel.WriteAsync(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB, 0x80, 0x05 }.Slice()).OrTimeout(); + await Task.Yield(); + await channel.WriteAsync(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }.Slice()).OrTimeout(); + await Task.Yield(); + await channel.WriteAsync(new byte[] { 0xAB }.Slice()).OrTimeout(); + await Task.Yield(); }); Assert.Equal(2, result.Received.Count); @@ -92,6 +86,47 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests Assert.Equal(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, result.Received[1].Payload.ToArray()); } + [Fact] + public async Task ReadLargeMaskedPayload() + { + // This test was added to ensure we don't break a behavior discovered while running the Autobahn test suite. + + // Larger than one page, which means it will span blocks in the memory pool. + var expectedPayload = new byte[4192]; + for (int i = 0; i < expectedPayload.Length; i++) + { + expectedPayload[i] = (byte)(i % byte.MaxValue); + } + var maskingKey = new byte[] { 0x01, 0x02, 0x03, 0x04 }; + var sendPayload = new byte[4192]; + for (int i = 0; i < expectedPayload.Length; i++) + { + sendPayload[i] = (byte)(expectedPayload[i] ^ maskingKey[i % 4]); + } + + var result = await RunReceiveTest( + producer: async (channel, cancellationToken) => + { + // We use a 64-bit length because we want to ensure that the first page of data ends at an + // offset within the frame that is NOT divisible by 4. This ensures that when the unmasking + // moves from one buffer to the other, we are at a non-zero position within the masking key. + // This ensures that we're tracking the masking key offset properly. + + // Header: (Opcode=Binary, Fin=true), (Mask=false, Len=126), (64-bit big endian length) + await channel.WriteAsync(new byte[] { 0x82, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x60 }).OrTimeout(); + await channel.WriteAsync(maskingKey).OrTimeout(); + await Task.Yield(); + await channel.WriteAsync(sendPayload).OrTimeout(); + }); + + Assert.Equal(1, result.Received.Count); + + var frame = result.Received[0]; + Assert.True(frame.EndOfMessage); + Assert.Equal(WebSocketOpcode.Binary, frame.Opcode); + Assert.Equal(expectedPayload, frame.Payload.ToArray()); + } + [Fact] public async Task Read16BitPayloadLength() { @@ -102,8 +137,9 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests producer: async (channel, cancellationToken) => { // Header: (Opcode=Binary, Fin=true), (Mask=false, Len=126), (16-bit big endian length) - await channel.WriteAsync(new byte[] { 0x21, 0xFC, 0x04, 0x00 }); - await channel.WriteAsync(expectedPayload); + await channel.WriteAsync(new byte[] { 0x82, 0x7E, 0x04, 0x00 }).OrTimeout(); + await Task.Yield(); + await channel.WriteAsync(expectedPayload).OrTimeout(); }); Assert.Equal(1, result.Received.Count); @@ -125,8 +161,9 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests producer: async (channel, cancellationToken) => { // Header: (Opcode=Binary, Fin=true), (Mask=false, Len=127), (64-bit big endian length) - await channel.WriteAsync(new byte[] { 0x21, 0xFE, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00 }); - await channel.WriteAsync(expectedPayload); + await channel.WriteAsync(new byte[] { 0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00 }).OrTimeout(); + await Task.Yield(); + await channel.WriteAsync(expectedPayload).OrTimeout(); }); Assert.Equal(1, result.Received.Count); @@ -142,7 +179,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests var result = await RunReceiveTest( producer: async (channel, cancellationToken) => { - await channel.WriteAsync(rawFrame.Slice()); + await channel.WriteAsync(rawFrame.Slice()).OrTimeout(); }); var frames = result.Received; Assert.Equal(1, frames.Count); @@ -153,43 +190,36 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests Assert.Equal(expectedOpcode, frame.Opcode); payloadAssert(frame.Payload.ToArray()); } + } - private static async Task RunReceiveTest(Func producer) + private static async Task RunReceiveTest(Func producer) + { + using (var factory = new ChannelFactory()) { - using (var factory = new ChannelFactory()) + var outbound = factory.CreateChannel(); + var inbound = factory.CreateChannel(); + + var timeoutToken = TestUtil.CreateTimeoutToken(); + + var producerTask = Task.Run(async () => { - var outbound = factory.CreateChannel(); - var inbound = factory.CreateChannel(); + await producer(inbound, timeoutToken).OrTimeout(); + inbound.CompleteWriter(); + }, timeoutToken); - var cts = new CancellationTokenSource(); - var cancellationToken = cts.Token; - - // Timeout for the test, but only if the debugger is not attached. - if (!Debugger.IsAttached) + var consumerTask = Task.Run(async () => + { + var connection = new WebSocketConnection(inbound, outbound, options: new WebSocketOptions().WithAllFramesPassedThrough()); + using (timeoutToken.Register(() => connection.Dispose())) + using (connection) { - cts.CancelAfter(TimeSpan.FromSeconds(5)); + // Receive frames until we're closed + return await connection.ExecuteAndCaptureFramesAsync().OrTimeout(); } + }, timeoutToken); - var producerTask = Task.Run(async () => - { - await producer(inbound, cancellationToken); - inbound.CompleteWriter(); - }, cancellationToken); - - var consumerTask = Task.Run(async () => - { - var connection = new WebSocketConnection(inbound, outbound); - using (cancellationToken.Register(() => connection.Dispose())) - using (connection) - { - // Receive frames until we're closed - return await connection.ExecuteAndCaptureFramesAsync(); - } - }, cancellationToken); - - await Task.WhenAll(producerTask, consumerTask); - return consumerTask.Result; - } + await Task.WhenAll(producerTask, consumerTask); + return consumerTask.Result; } } diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs index 0d5723de69..f376183124 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs @@ -1,7 +1,8 @@ -using System; -using System.Diagnostics; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using System.Text; -using System.Threading; using System.Threading.Tasks; using Channels; using Xunit; @@ -12,113 +13,116 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests { public class TheSendAsyncMethod { + // No auto-pinging for us! + private readonly static WebSocketOptions DefaultTestOptions = new WebSocketOptions().WithAllFramesPassedThrough(); + [Theory] - [InlineData("", true, new byte[] { 0x11, 0x00 })] - [InlineData("Hello", true, new byte[] { 0x11, 0x0A, 0x48, 0x65, 0x6C, 0x6C, 0x6F })] - [InlineData("", false, new byte[] { 0x10, 0x00 })] - [InlineData("Hello", false, new byte[] { 0x10, 0x0A, 0x48, 0x65, 0x6C, 0x6C, 0x6F })] + [InlineData("", true, new byte[] { 0x81, 0x00 })] + [InlineData("Hello", true, new byte[] { 0x81, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F })] + [InlineData("", false, new byte[] { 0x01, 0x00 })] + [InlineData("Hello", false, new byte[] { 0x01, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F })] public async Task WriteTextFrames(string message, bool endOfMessage, byte[] expectedRawFrame) { var data = await RunSendTest( - producer: async (socket, cancellationToken) => + producer: async (socket) => { var payload = Encoding.UTF8.GetBytes(message); await socket.SendAsync(CreateFrame( endOfMessage, opcode: WebSocketOpcode.Text, - payload: payload)); - }, masked: false); + payload: payload)).OrTimeout(); + }, options: DefaultTestOptions); Assert.Equal(expectedRawFrame, data); } [Theory] // Opcode = Binary - [InlineData(new byte[0], WebSocketOpcode.Binary, true, new byte[] { 0x21, 0x00 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, true, new byte[] { 0x21, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] - [InlineData(new byte[0], WebSocketOpcode.Binary, false, new byte[] { 0x20, 0x00 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, false, new byte[] { 0x20, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Binary, true, new byte[] { 0x82, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, true, new byte[] { 0x82, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Binary, false, new byte[] { 0x02, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, false, new byte[] { 0x02, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })] // Opcode = Continuation - [InlineData(new byte[0], WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x00 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Continuation, true, new byte[] { 0x80, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, true, new byte[] { 0x80, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })] [InlineData(new byte[0], WebSocketOpcode.Continuation, false, new byte[] { 0x00, 0x00 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, false, new byte[] { 0x00, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, false, new byte[] { 0x00, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })] // Opcode = Ping - [InlineData(new byte[0], WebSocketOpcode.Ping, true, new byte[] { 0x91, 0x00 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, true, new byte[] { 0x91, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] - [InlineData(new byte[0], WebSocketOpcode.Ping, false, new byte[] { 0x90, 0x00 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, false, new byte[] { 0x90, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Ping, true, new byte[] { 0x89, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, true, new byte[] { 0x89, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Ping, false, new byte[] { 0x09, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, false, new byte[] { 0x09, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })] // Opcode = Pong - [InlineData(new byte[0], WebSocketOpcode.Pong, true, new byte[] { 0xA1, 0x00 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, true, new byte[] { 0xA1, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] - [InlineData(new byte[0], WebSocketOpcode.Pong, false, new byte[] { 0xA0, 0x00 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, false, new byte[] { 0xA0, 0x0A, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Pong, true, new byte[] { 0x8A, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, true, new byte[] { 0x8A, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })] + [InlineData(new byte[0], WebSocketOpcode.Pong, false, new byte[] { 0x0A, 0x00 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, false, new byte[] { 0x0A, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })] public async Task WriteBinaryFormattedFrames(byte[] payload, WebSocketOpcode opcode, bool endOfMessage, byte[] expectedRawFrame) { var data = await RunSendTest( - producer: async (socket, cancellationToken) => + producer: async (socket) => { await socket.SendAsync(CreateFrame( endOfMessage, opcode, - payload: payload)); - }, masked: false); + payload: payload)).OrTimeout(); + }, options: DefaultTestOptions); Assert.Equal(expectedRawFrame, data); } [Theory] - [InlineData("", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x11, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData("Hello", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x11, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x48 ^ 0x01, 0x65 ^ 0x02, 0x6C ^ 0x03, 0x6C ^ 0x04, 0x6F ^ 0x01 })] + [InlineData("", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x81, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData("Hello", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x81, 0x85, 0x01, 0x02, 0x03, 0x04, 0x48 ^ 0x01, 0x65 ^ 0x02, 0x6C ^ 0x03, 0x6C ^ 0x04, 0x6F ^ 0x01 })] public async Task WriteMaskedTextFrames(string message, byte[] maskingKey, byte[] expectedRawFrame) { var data = await RunSendTest( - producer: async (socket, cancellationToken) => + producer: async (socket) => { var payload = Encoding.UTF8.GetBytes(message); await socket.SendAsync(CreateFrame( endOfMessage: true, opcode: WebSocketOpcode.Text, - payload: payload)); - }, maskingKey: maskingKey); + payload: payload)).OrTimeout(); + }, options: DefaultTestOptions.WithFixedMaskingKey(maskingKey)); Assert.Equal(expectedRawFrame, data); } [Theory] // Opcode = Binary - [InlineData(new byte[0], WebSocketOpcode.Binary, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x21, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x21, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] - [InlineData(new byte[0], WebSocketOpcode.Binary, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x20, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x20, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Binary, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x82, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Binary, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x02, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x02, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] // Opcode = Continuation - [InlineData(new byte[0], WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x01, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x01, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] - [InlineData(new byte[0], WebSocketOpcode.Continuation, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x00, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x00, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x80, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x80, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Continuation, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x00, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x00, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] // Opcode = Ping - [InlineData(new byte[0], WebSocketOpcode.Ping, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x91, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x91, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] - [InlineData(new byte[0], WebSocketOpcode.Ping, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x90, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x90, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Ping, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x89, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x89, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Ping, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x09, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x09, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] // Opcode = Pong - [InlineData(new byte[0], WebSocketOpcode.Pong, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0xA1, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0xA1, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] - [InlineData(new byte[0], WebSocketOpcode.Pong, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0xA0, 0x01, 0x01, 0x02, 0x03, 0x04 })] - [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0xA0, 0x0B, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Pong, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x8A, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x8A, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] + [InlineData(new byte[0], WebSocketOpcode.Pong, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x0A, 0x80, 0x01, 0x02, 0x03, 0x04 })] + [InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x0A, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })] public async Task WriteMaskedBinaryFormattedFrames(byte[] payload, WebSocketOpcode opcode, bool endOfMessage, byte[] maskingKey, byte[] expectedRawFrame) { var data = await RunSendTest( - producer: async (socket, cancellationToken) => + producer: async (socket) => { await socket.SendAsync(CreateFrame( endOfMessage, opcode, - payload: payload)); - }, maskingKey: maskingKey); + payload: payload)).OrTimeout(); + }, options: DefaultTestOptions.WithFixedMaskingKey(maskingKey)); Assert.Equal(expectedRawFrame, data); } @@ -126,17 +130,17 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests public async Task WriteRandomMaskedFrame() { var data = await RunSendTest( - producer: async (socket, cancellationToken) => + producer: async (socket) => { await socket.SendAsync(CreateFrame( endOfMessage: true, opcode: WebSocketOpcode.Binary, - payload: new byte[] { 0x0A, 0x0B, 0x0C, 0x0D, 0x0E })); - }, masked: true); + payload: new byte[] { 0x0A, 0x0B, 0x0C, 0x0D, 0x0E })).OrTimeout(); + }, options: DefaultTestOptions.WithRandomMasking()); // Verify the header - Assert.Equal(0x21, data[0]); - Assert.Equal(0x0B, data[1]); + Assert.Equal(0x82, data[0]); + Assert.Equal(0x85, data[1]); // We don't know the mask, so we have to read it in order to verify this frame var mask = data.Slice(2, 4); @@ -151,57 +155,44 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests } [Theory] - [InlineData(WebSocketCloseStatus.MandatoryExtension, "Hi", null, new byte[] { 0x81, 0x08, 0x03, 0xF2, (byte)'H', (byte)'i' })] - [InlineData(WebSocketCloseStatus.PolicyViolation, "", null, new byte[] { 0x81, 0x04, 0x03, 0xF0 })] - [InlineData(WebSocketCloseStatus.MandatoryExtension, "Hi", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x81, 0x09, 0x01, 0x02, 0x03, 0x04, 0x03 ^ 0x01, 0xF2 ^ 0x02, (byte)'H' ^ 0x03, (byte)'i' ^ 0x04 })] - [InlineData(WebSocketCloseStatus.PolicyViolation, "", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x81, 0x05, 0x01, 0x02, 0x03, 0x04, 0x03 ^ 0x01, 0xF0 ^ 0x02 })] + [InlineData(WebSocketCloseStatus.MandatoryExtension, "Hi", null, new byte[] { 0x88, 0x04, 0x03, 0xF2, (byte)'H', (byte)'i' })] + [InlineData(WebSocketCloseStatus.PolicyViolation, "", null, new byte[] { 0x88, 0x02, 0x03, 0xF0 })] + [InlineData(WebSocketCloseStatus.MandatoryExtension, "Hi", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x88, 0x84, 0x01, 0x02, 0x03, 0x04, 0x03 ^ 0x01, 0xF2 ^ 0x02, (byte)'H' ^ 0x03, (byte)'i' ^ 0x04 })] + [InlineData(WebSocketCloseStatus.PolicyViolation, "", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x88, 0x82, 0x01, 0x02, 0x03, 0x04, 0x03 ^ 0x01, 0xF0 ^ 0x02 })] public async Task WriteCloseFrames(WebSocketCloseStatus status, string description, byte[] maskingKey, byte[] expectedRawFrame) { var data = await RunSendTest( - producer: async (socket, cancellationToken) => + producer: async (socket) => { - await socket.CloseAsync(new WebSocketCloseResult(status, description)); - }, maskingKey: maskingKey); + await socket.CloseAsync(new WebSocketCloseResult(status, description)).OrTimeout(); + }, options: maskingKey == null ? DefaultTestOptions : DefaultTestOptions.WithFixedMaskingKey(maskingKey)); Assert.Equal(expectedRawFrame, data); } - private static async Task RunSendTest(Func producer, bool masked = false, byte[] maskingKey = null) + private static async Task RunSendTest(Func producer, WebSocketOptions options) { using (var factory = new ChannelFactory()) { var outbound = factory.CreateChannel(); var inbound = factory.CreateChannel(); - var cts = new CancellationTokenSource(); - - // Timeout for the test, but only if the debugger is not attached. - if (!Debugger.IsAttached) + Task executeTask; + using (var connection = new WebSocketConnection(inbound, outbound, options)) { - cts.CancelAfter(TimeSpan.FromSeconds(5)); - } - - var cancellationToken = cts.Token; - using (cancellationToken.Register(() => CompleteChannels(inbound, outbound))) - { - - Task executeTask; - using (var connection = CreateConnection(inbound, outbound, masked, maskingKey)) + executeTask = connection.ExecuteAsync(f => { - executeTask = connection.ExecuteAsync(f => - { - Assert.False(true, "Did not expect to receive any messages"); - return Task.CompletedTask; - }); - await producer(connection, cancellationToken); - inbound.CompleteWriter(); - await executeTask; - } - - var data = (await outbound.ReadToEndAsync()).ToArray(); - inbound.CompleteReader(); - CompleteChannels(outbound); - return data; + Assert.False(true, "Did not expect to receive any messages"); + return Task.CompletedTask; + }); + await producer(connection).OrTimeout(); + inbound.CompleteWriter(); + await executeTask.OrTimeout(); } + + var data = (await outbound.ReadToEndAsync()).ToArray(); + inbound.CompleteReader(); + CompleteChannels(outbound); + return data; } } @@ -213,13 +204,6 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests channel.CompleteWriter(); } } - - private static WebSocketConnection CreateConnection(Channel inbound, Channel outbound, bool masked, byte[] maskingKey) - { - return (maskingKey != null) ? - new WebSocketConnection(inbound, outbound, fixedMaskingKey: maskingKey) : - new WebSocketConnection(inbound, outbound, masked); - } } } } diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.Utf8Validation.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.Utf8Validation.cs new file mode 100644 index 0000000000..3caa3984d5 --- /dev/null +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.Utf8Validation.cs @@ -0,0 +1,226 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Channels; +using Xunit; + +namespace Microsoft.Extensions.WebSockets.Internal.Tests +{ + public partial class WebSocketConnectionTests + { + public class Utf8Validation + { + [Theory] + [InlineData(new byte[] { 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello")] + [InlineData(new byte[] { 0xC2, 0xA7, 0x31, 0x2C, 0x20, 0x39, 0x35, 0xC2, 0xA2 }, "§1, 95¢")] + [InlineData(new byte[] { 0xE0, 0xA0, 0x80, 0xE0, 0xA4, 0x80 }, "\u0800\u0900")] + [InlineData(new byte[] { 0xF0, 0x90, 0x80, 0x80 }, "\U00010000")] + public async Task ValidSingleFramePayloads(byte[] payload, string decoded) + { + using (var pair = WebSocketPair.Create()) + { + var timeoutToken = TestUtil.CreateTimeoutToken(); + using (timeoutToken.Register(() => pair.Dispose())) + { + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + var frame = new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(payload)); + await pair.ClientSocket.SendAsync(frame).OrTimeout(); + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout(); + var serverSummary = await server.OrTimeout(); + await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout(); + var clientSummary = await client.OrTimeout(); + + Assert.Equal(0, clientSummary.Received.Count); + + Assert.Equal(1, serverSummary.Received.Count); + Assert.True(serverSummary.Received[0].EndOfMessage); + Assert.Equal(WebSocketOpcode.Text, serverSummary.Received[0].Opcode); + Assert.Equal(decoded, Encoding.UTF8.GetString(serverSummary.Received[0].Payload.ToArray())); + } + } + } + + [Theory] + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x6C, 0x6C, 0x6F }, "Hello")] + + [InlineData(new byte[0], new byte[] { 0xC2, 0xA7 }, "§")] + [InlineData(new byte[] { 0xC2 }, new byte[] { 0xA7 }, "§")] + [InlineData(new byte[] { 0xC2, 0xA7 }, new byte[0], "§")] + + [InlineData(new byte[0], new byte[] { 0xC2, 0xA2 }, "¢")] + [InlineData(new byte[] { 0xC2 }, new byte[] { 0xA2 }, "¢")] + [InlineData(new byte[] { 0xC2, 0xA2 }, new byte[0], "¢")] + + [InlineData(new byte[0], new byte[] { 0xE0, 0xA0, 0x80 }, "\u0800")] + [InlineData(new byte[] { 0xE0 }, new byte[] { 0xA0, 0x80 }, "\u0800")] + [InlineData(new byte[] { 0xE0, 0xA0 }, new byte[] { 0x80 }, "\u0800")] + [InlineData(new byte[] { 0xE0, 0xA0, 0x80 }, new byte[0], "\u0800")] + + [InlineData(new byte[0], new byte[] { 0xE0, 0xA4, 0x80 }, "\u0900")] + [InlineData(new byte[] { 0xE0 }, new byte[] { 0xA4, 0x80 }, "\u0900")] + [InlineData(new byte[] { 0xE0, 0xA4 }, new byte[] { 0x80 }, "\u0900")] + [InlineData(new byte[] { 0xE0, 0xA4, 0x80 }, new byte[0], "\u0900")] + + [InlineData(new byte[0], new byte[] { 0xF0, 0x90, 0x80, 0x80 }, "\U00010000")] + [InlineData(new byte[] { 0xF0 }, new byte[] { 0x90, 0x80, 0x80 }, "\U00010000")] + [InlineData(new byte[] { 0xF0, 0x90 }, new byte[] { 0x80, 0x80 }, "\U00010000")] + [InlineData(new byte[] { 0xF0, 0x90, 0x80 }, new byte[] { 0x80 }, "\U00010000")] + [InlineData(new byte[] { 0xF0, 0x90, 0x80, 0x80 }, new byte[0], "\U00010000")] + public async Task ValidMultiFramePayloads(byte[] payload1, byte[] payload2, string decoded) + { + using (var pair = WebSocketPair.Create()) + { + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + var frame = new WebSocketFrame( + endOfMessage: false, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(payload1)); + await pair.ClientSocket.SendAsync(frame).OrTimeout(); + frame = new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Continuation, + payload: ReadableBuffer.Create(payload2)); + await pair.ClientSocket.SendAsync(frame).OrTimeout(); + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout(); + var serverSummary = await server.OrTimeout(); + await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout(); + var clientSummary = await client.OrTimeout(); + + Assert.Equal(0, clientSummary.Received.Count); + + Assert.Equal(2, serverSummary.Received.Count); + Assert.False(serverSummary.Received[0].EndOfMessage); + Assert.Equal(WebSocketOpcode.Text, serverSummary.Received[0].Opcode); + Assert.True(serverSummary.Received[1].EndOfMessage); + Assert.Equal(WebSocketOpcode.Continuation, serverSummary.Received[1].Opcode); + + var finalPayload = serverSummary.Received.SelectMany(f => f.Payload.ToArray()).ToArray(); + Assert.Equal(decoded, Encoding.UTF8.GetString(finalPayload)); + } + } + + [Theory] + + // Continuation byte as first byte of code point + [InlineData(new byte[] { 0x48, 0x65, 0x80, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65, 0x99, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65, 0xAB, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65, 0xB0, 0x6C, 0x6F })] + + // Incomplete Code Point + [InlineData(new byte[] { 0xC2 })] + [InlineData(new byte[] { 0xE0 })] + [InlineData(new byte[] { 0xE0, 0xA0 })] + [InlineData(new byte[] { 0xE0, 0xA4 })] + [InlineData(new byte[] { 0xF0, 0x90, 0x80 })] + + // Overlong Encoding + + // 'H' (1 byte char) encoded with 2, 3 and 4 bytes + [InlineData(new byte[] { 0xC1, 0x88 })] + [InlineData(new byte[] { 0xE0, 0x81, 0x88 })] + [InlineData(new byte[] { 0xF0, 0x80, 0x81, 0x88 })] + + // '§' (2 byte char) encoded with 3 and 4 bytes + [InlineData(new byte[] { 0xE0, 0x82, 0xA7 })] + [InlineData(new byte[] { 0xF0, 0x80, 0x82, 0xA7 })] + + // '\u0800' (3 byte char) encoded with 4 bytes + [InlineData(new byte[] { 0xF0, 0x80, 0xA0, 0x80 })] + public async Task InvalidSingleFramePayloads(byte[] payload) + { + using (var pair = WebSocketPair.Create()) + { + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + var frame = new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(payload)); + await pair.ClientSocket.SendAsync(frame).OrTimeout(); + var clientSummary = await client.OrTimeout(); + var serverSummary = await server.OrTimeout(); + + Assert.Equal(0, serverSummary.Received.Count); + Assert.Equal(0, clientSummary.Received.Count); + Assert.Equal(WebSocketCloseStatus.InvalidPayloadData, clientSummary.CloseResult.Status); + Assert.Equal("An invalid Text frame payload was received", clientSummary.CloseResult.Description); + } + } + + [Theory] + + // Continuation byte as first byte of code point + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x80, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x99, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0xAB, 0x6C, 0x6F })] + [InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0xB0, 0x6C, 0x6F })] + + // Incomplete Code Point + [InlineData(new byte[] { 0xC2 }, new byte[0])] + [InlineData(new byte[] { 0xE0 }, new byte[0])] + [InlineData(new byte[] { 0xE0, 0xA0 }, new byte[0])] + [InlineData(new byte[] { 0xE0, 0xA4 }, new byte[0])] + [InlineData(new byte[] { 0xF0, 0x90, 0x80 }, new byte[0])] + + // Overlong Encoding + + // 'H' (1 byte char) encoded with 2, 3 and 4 bytes + [InlineData(new byte[] { 0xC1 }, new byte[] { 0x88 })] + [InlineData(new byte[] { 0xE0 }, new byte[] { 0x81, 0x88 })] + [InlineData(new byte[] { 0xF0 }, new byte[] { 0x80, 0x81, 0x88 })] + + // '§' (2 byte char) encoded with 3 and 4 bytes + [InlineData(new byte[] { 0xE0, 0x82 }, new byte[] { 0xA7 })] + [InlineData(new byte[] { 0xF0, 0x80 }, new byte[] { 0x82, 0xA7 })] + + // '\u0800' (3 byte char) encoded with 4 bytes + [InlineData(new byte[] { 0xF0, 0x80 }, new byte[] { 0xA0, 0x80 })] + public async Task InvalidMultiFramePayloads(byte[] payload1, byte[] payload2) + { + using (var pair = WebSocketPair.Create()) + { + var timeoutToken = TestUtil.CreateTimeoutToken(); + using (timeoutToken.Register(() => pair.Dispose())) + { + var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync(); + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + var frame = new WebSocketFrame( + endOfMessage: false, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(payload1)); + await pair.ClientSocket.SendAsync(frame).OrTimeout(); + frame = new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Continuation, + payload: ReadableBuffer.Create(payload2)); + await pair.ClientSocket.SendAsync(frame).OrTimeout(); + var clientSummary = await client.OrTimeout(); + var serverSummary = await server.OrTimeout(); + + Assert.Equal(1, serverSummary.Received.Count); + Assert.False(serverSummary.Received[0].EndOfMessage); + Assert.Equal(WebSocketOpcode.Text, serverSummary.Received[0].Opcode); + Assert.Equal(payload1, serverSummary.Received[0].Payload.ToArray()); + + Assert.Equal(0, clientSummary.Received.Count); + Assert.Equal(WebSocketCloseStatus.InvalidPayloadData, clientSummary.CloseResult.Status); + Assert.Equal("An invalid Text frame payload was received", clientSummary.CloseResult.Description); + } + } + } + } + } +} diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs index 1c22596f38..a37ae40271 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using Channels; namespace Microsoft.Extensions.WebSockets.Internal.Tests @@ -7,8 +10,8 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests { private ChannelFactory _factory; - private Channel _serverToClient; - private Channel _clientToServer; + public Channel ServerToClient { get; } + public Channel ClientToServer { get; } public IWebSocketConnection ClientSocket { get; } public IWebSocketConnection ServerSocket { get; } @@ -16,35 +19,37 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests public WebSocketPair(ChannelFactory factory, Channel serverToClient, Channel clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket) { _factory = factory; - _serverToClient = serverToClient; - _clientToServer = clientToServer; + ServerToClient = serverToClient; + ClientToServer = clientToServer; ClientSocket = clientSocket; ServerSocket = serverSocket; } - public static WebSocketPair Create() + public static WebSocketPair Create() => Create(new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking(), new WebSocketOptions().WithAllFramesPassedThrough()); + + public static WebSocketPair Create(WebSocketOptions serverOptions, WebSocketOptions clientOptions) { // Create channels var factory = new ChannelFactory(); var serverToClient = factory.CreateChannel(); var clientToServer = factory.CreateChannel(); - var serverSocket = new WebSocketConnection(clientToServer, serverToClient, masked: true); - var clientSocket = new WebSocketConnection(serverToClient, clientToServer, masked: false); + var serverSocket = new WebSocketConnection(clientToServer, serverToClient, options: serverOptions); + var clientSocket = new WebSocketConnection(serverToClient, clientToServer, options: clientOptions); return new WebSocketPair(factory, serverToClient, clientToServer, clientSocket, serverSocket); } public void Dispose() { - _factory.Dispose(); ServerSocket.Dispose(); ClientSocket.Dispose(); + _factory.Dispose(); } public void TerminateFromClient(Exception ex = null) { - _clientToServer.CompleteWriter(ex); + ClientToServer.CompleteWriter(ex); } } } \ No newline at end of file diff --git a/test/WebSocketsTestApp/Program.cs b/test/WebSocketsTestApp/Program.cs new file mode 100644 index 0000000000..c0e1548136 --- /dev/null +++ b/test/WebSocketsTestApp/Program.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.Configuration; + +namespace WebSocketsTestApp +{ + public class Program + { + public static void Main(string[] args) + { + var config = new ConfigurationBuilder() + .AddCommandLine(args) + .Build(); + + var host = new WebHostBuilder() + .UseConfiguration(config) + .UseKestrel() + .UseContentRoot(Directory.GetCurrentDirectory()) + .UseIISIntegration() + .UseStartup() + .Build(); + + host.Run(); + } + } +} diff --git a/test/WebSocketsTestApp/Startup.cs b/test/WebSocketsTestApp/Startup.cs new file mode 100644 index 0000000000..66374eccd0 --- /dev/null +++ b/test/WebSocketsTestApp/Startup.cs @@ -0,0 +1,118 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Text; +using System.Threading.Tasks; +using Channels; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.WebSockets.Internal; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.WebSockets.Internal; + +namespace WebSocketsTestApp +{ + public class Startup + { + // This method gets called by the runtime. Use this method to add services to the container. + // For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940 + public void ConfigureServices(IServiceCollection services) + { + services.AddSingleton(); + } + + // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. + public void Configure(IApplicationBuilder app, IHostingEnvironment env, ILoggerFactory loggerFactory, ChannelFactory channelFactory) + { + loggerFactory.AddConsole(LogLevel.Debug); + + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.UseWebSocketConnections(new ChannelFactory()); + + app.Use(async (context, next) => + { + var webSocketConnectionFeature = context.Features.Get(); + if (webSocketConnectionFeature != null && webSocketConnectionFeature.IsWebSocketRequest) + { + using (var webSocket = await webSocketConnectionFeature.AcceptWebSocketConnectionAsync(new WebSocketAcceptContext())) + { + await Echo(context, webSocket, loggerFactory.CreateLogger("Echo")); + } + } + else + { + await next(); + } + }); + + app.UseFileServer(); + } + + private async Task Echo(HttpContext context, IWebSocketConnection webSocket, ILogger logger) + { + var lastFrameOpcode = WebSocketOpcode.Continuation; + var closeResult = await webSocket.ExecuteAsync(frame => + { + if (frame.Opcode == WebSocketOpcode.Ping || frame.Opcode == WebSocketOpcode.Pong) + { + // Already handled + return Task.CompletedTask; + } + + LogFrame(logger, lastFrameOpcode, ref frame); + + // If the client send "ServerClose", then they want a server-originated close to occur + string content = "<>"; + if (frame.Opcode == WebSocketOpcode.Text) + { + // Slooooow + content = Encoding.UTF8.GetString(frame.Payload.ToArray()); + if (content.Equals("ServerClose")) + { + logger.LogDebug($"Sending Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server"); + return webSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure, "Closing from Server")); + } + else if (content.Equals("ServerAbort")) + { + context.Abort(); + } + } + + if (frame.Opcode != WebSocketOpcode.Continuation) + { + lastFrameOpcode = frame.Opcode; + } + logger.LogDebug($"Sending {frame.Opcode}: Len={frame.Payload.Length}, Fin={frame.EndOfMessage}: {content}"); + return webSocket.SendAsync(frame); + }); + + if (webSocket.State == WebSocketConnectionState.CloseReceived) + { + // Close the connection from our end + await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); + logger.LogDebug("Socket closed"); + } + else if (webSocket.State != WebSocketConnectionState.Closed) + { + logger.LogError("WebSocket closed but not closed?"); + } + } + + private void LogFrame(ILogger logger, WebSocketOpcode lastFrameOpcode, ref WebSocketFrame frame) + { + var opcode = frame.Opcode; + if (opcode == WebSocketOpcode.Continuation) + { + opcode = lastFrameOpcode; + } + + logger.LogDebug($"Received {frame.Opcode} frame (FIN={frame.EndOfMessage}, LEN={frame.Payload.Length})"); + } + } +} diff --git a/test/WebSocketsTestApp/WebSocketsTestApp.xproj b/test/WebSocketsTestApp/WebSocketsTestApp.xproj new file mode 100644 index 0000000000..18006ae239 --- /dev/null +++ b/test/WebSocketsTestApp/WebSocketsTestApp.xproj @@ -0,0 +1,16 @@ + + + + 14.0.25420 + $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) + + + + 58e771ec-8454-4558-b61a-c9d049065911 + + + + 2.0 + + + \ No newline at end of file diff --git a/test/WebSocketsTestApp/project.json b/test/WebSocketsTestApp/project.json new file mode 100644 index 0000000000..dd7831c3de --- /dev/null +++ b/test/WebSocketsTestApp/project.json @@ -0,0 +1,47 @@ +{ + "dependencies": { + "Microsoft.AspNetCore.Diagnostics": "1.1.0-*", + "Microsoft.AspNetCore.Server.IISIntegration": "1.1.0-*", + "Microsoft.AspNetCore.Server.Kestrel": "1.1.0-*", + "Microsoft.AspNetCore.StaticFiles": "1.1.0-*", + "Microsoft.AspNetCore.WebSockets.Internal": "0.1.0-*", + "Microsoft.Extensions.Configuration": "1.1.0-*", + "Microsoft.Extensions.Configuration.CommandLine": "1.1.0-*", + "Microsoft.Extensions.Logging.Console": "1.1.0-*", + "Microsoft.NETCore.App": { + "version": "1.1.0-*", + "type": "platform" + } + }, + "tools": { + "Microsoft.AspNetCore.Server.IISIntegration.Tools": "1.0.0-*" + }, + "frameworks": { + "netcoreapp1.1": { + "imports": [ + "dotnet5.6", + "portable-net45+win8" + ] + } + }, + "buildOptions": { + "emitEntryPoint": true, + "preserveCompilationContext": true + }, + "runtimeOptions": { + "configProperties": { + "System.GC.Server": true + } + }, + "publishOptions": { + "include": [ + "wwwroot", + "web.config" + ] + }, + "scripts": { + "postpublish": [ + "dotnet publish-iis --publish-folder %publish:OutputPath% --framework %publish:FullTargetFramework%" + ] + } +} \ No newline at end of file diff --git a/test/WebSocketsTestApp/scripts/RunAutobahnTests.ps1 b/test/WebSocketsTestApp/scripts/RunAutobahnTests.ps1 new file mode 100644 index 0000000000..d109182eac --- /dev/null +++ b/test/WebSocketsTestApp/scripts/RunAutobahnTests.ps1 @@ -0,0 +1,43 @@ +# +# RunAutobahnTests.ps1 +# +param([Parameter(Mandatory=$true)][string]$ServerUrl, [string[]]$Cases = @("*"), [string]$OutputDir, [int]$Iterations = 1) + +if(!(Get-Command wstest -ErrorAction SilentlyContinue)) { + throw "Missing required command 'wstest'. See README.md in Microsoft.AspNetCore.WebSockets.Server.Test project for information on installing Autobahn Test Suite." +} + +if(!$OutputDir) { + $OutputDir = Convert-Path "." + $OutputDir = Join-Path $OutputDir "autobahnreports" +} + +Write-Host "Launching Autobahn Test Suite ($Iterations iteration(s))..." + +0..($Iterations-1) | % { + $iteration = $_ + + $Spec = Convert-Path (Join-Path $PSScriptRoot "autobahn.spec.json") + + $CasesArray = [string]::Join(",", @($Cases | ForEach-Object { "`"$_`"" })) + + $SpecJson = [IO.File]::ReadAllText($Spec).Replace("OUTPUTDIR", $OutputDir.Replace("\", "\\")).Replace("WEBSOCKETURL", $ServerUrl).Replace("`"CASES`"", $CasesArray) + + $TempFile = [IO.Path]::GetTempFileName() + + try { + [IO.File]::WriteAllText($TempFile, $SpecJson) + $wstestOutput = & wstest -m fuzzingclient -s $TempFile + } finally { + if(Test-Path $TempFile) { + rm $TempFile + } + } + + $report = ConvertFrom-Json ([IO.File]::ReadAllText((Convert-Path (Join-Path $OutputDir "index.json")))) + + $report.Server | gm | ? { $_.MemberType -eq "NoteProperty" } | % { + $case = $report.Server."$($_.Name)" + Write-Host "[#$($iteration.ToString().PadRight(2))] [$($case.behavior.PadRight(6))] Case $($_.Name)" + } +} \ No newline at end of file diff --git a/test/WebSocketsTestApp/scripts/autobahn.spec.json b/test/WebSocketsTestApp/scripts/autobahn.spec.json new file mode 100644 index 0000000000..aa6d841167 --- /dev/null +++ b/test/WebSocketsTestApp/scripts/autobahn.spec.json @@ -0,0 +1,14 @@ +{ + "options": { "failByDrop": false }, + "outdir": "OUTPUTDIR", + "servers": [ + { + "agent": "Server", + "url": "WEBSOCKETURL", + "options": { "version": 18 } + } + ], + "cases": ["CASES"], + "exclude-cases": ["12.*", "13.*"], + "exclude-agent-cases": {} +} diff --git a/test/WebSocketsTestApp/web.config b/test/WebSocketsTestApp/web.config new file mode 100644 index 0000000000..dc0514fca5 --- /dev/null +++ b/test/WebSocketsTestApp/web.config @@ -0,0 +1,14 @@ + + + + + + + + + + + + diff --git a/test/WebSocketsTestApp/wwwroot/index.html b/test/WebSocketsTestApp/wwwroot/index.html new file mode 100644 index 0000000000..1663600a5e --- /dev/null +++ b/test/WebSocketsTestApp/wwwroot/index.html @@ -0,0 +1,151 @@ + + + + + + + + +

WebSocket Test Page

+

Ready to connect...

+
+ + + +
+
+ + + + +
+ +

Note: When connected to the default server (i.e. the server in the address bar ;)), the message "ServerClose" will cause the server to close the connection. Similarly, the message "ServerAbort" will cause the server to forcibly terminate the connection without a closing handshake

+ +

Communication Log

+ + + + + + + + + + +
FromToData
+ + + + \ No newline at end of file