Removed custom websocket implementation (#507)

- Use the default websocket middleware
- Rewrote TestWebSocketConnectionFeature to use Channels instead of pipes
This commit is contained in:
David Fowler 2017-06-03 06:53:39 -10:00 committed by GitHub
parent db9712c3f2
commit 72423ee203
66 changed files with 313 additions and 4867 deletions

View File

@ -1,6 +1,6 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
VisualStudioVersion = 15.0.26526.1
VisualStudioVersion = 15.0.26510.0
MinimumVisualStudioVersion = 10.0.40219.1
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{DA69F624-5398-4884-87E4-B816698CDE65}"
EndProject
@ -24,20 +24,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{6A35B453-5
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.Sockets.Tests", "test\Microsoft.AspNetCore.Sockets.Tests\Microsoft.AspNetCore.Sockets.Tests.csproj", "{AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.Extensions.WebSockets.Internal", "src\Microsoft.Extensions.WebSockets.Internal\Microsoft.Extensions.WebSockets.Internal.csproj", "{5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.Extensions.WebSockets.Internal.Tests", "test\Microsoft.Extensions.WebSockets.Internal.Tests\Microsoft.Extensions.WebSockets.Internal.Tests.csproj", "{A7050BAE-3DB9-4FB3-A49D-303201415B13}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR", "src\Microsoft.AspNetCore.SignalR\Microsoft.AspNetCore.SignalR.csproj", "{42E76F87-92B6-45AB-BF07-6B811C0F2CAC}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR.Redis", "src\Microsoft.AspNetCore.SignalR.Redis\Microsoft.AspNetCore.SignalR.Redis.csproj", "{59319B72-38BE-4041-8E5C-FF6938874CE8}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.WebSockets.Internal", "src\Microsoft.AspNetCore.WebSockets.Internal\Microsoft.AspNetCore.WebSockets.Internal.csproj", "{FFFE71F8-E476-4BCD-9689-F106EE1C1497}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest", "test\Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest\Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.csproj", "{8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "WebSocketsTestApp", "test\WebSocketsTestApp\WebSocketsTestApp.csproj", "{58E771EC-8454-4558-B61A-C9D049065911}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ChatSample", "samples\ChatSample\ChatSample.csproj", "{300979F6-A02E-407A-B8DF-F6200806C18D}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SocialWeather", "samples\SocialWeather\SocialWeather.csproj", "{8D789F94-CB74-45FD-ACE7-92AF6E55042E}"
@ -107,14 +97,6 @@ Global
{AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9}.Release|Any CPU.Build.0 = Release|Any CPU
{5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775}.Release|Any CPU.Build.0 = Release|Any CPU
{A7050BAE-3DB9-4FB3-A49D-303201415B13}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A7050BAE-3DB9-4FB3-A49D-303201415B13}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A7050BAE-3DB9-4FB3-A49D-303201415B13}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A7050BAE-3DB9-4FB3-A49D-303201415B13}.Release|Any CPU.Build.0 = Release|Any CPU
{42E76F87-92B6-45AB-BF07-6B811C0F2CAC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{42E76F87-92B6-45AB-BF07-6B811C0F2CAC}.Debug|Any CPU.Build.0 = Debug|Any CPU
{42E76F87-92B6-45AB-BF07-6B811C0F2CAC}.Release|Any CPU.ActiveCfg = Release|Any CPU
@ -123,18 +105,6 @@ Global
{59319B72-38BE-4041-8E5C-FF6938874CE8}.Debug|Any CPU.Build.0 = Debug|Any CPU
{59319B72-38BE-4041-8E5C-FF6938874CE8}.Release|Any CPU.ActiveCfg = Release|Any CPU
{59319B72-38BE-4041-8E5C-FF6938874CE8}.Release|Any CPU.Build.0 = Release|Any CPU
{FFFE71F8-E476-4BCD-9689-F106EE1C1497}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{FFFE71F8-E476-4BCD-9689-F106EE1C1497}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FFFE71F8-E476-4BCD-9689-F106EE1C1497}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FFFE71F8-E476-4BCD-9689-F106EE1C1497}.Release|Any CPU.Build.0 = Release|Any CPU
{8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8CBC1C71-AF0B-44E2-AEE9-D8024C07634D}.Release|Any CPU.Build.0 = Release|Any CPU
{58E771EC-8454-4558-B61A-C9D049065911}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{58E771EC-8454-4558-B61A-C9D049065911}.Debug|Any CPU.Build.0 = Debug|Any CPU
{58E771EC-8454-4558-B61A-C9D049065911}.Release|Any CPU.ActiveCfg = Release|Any CPU
{58E771EC-8454-4558-B61A-C9D049065911}.Release|Any CPU.Build.0 = Release|Any CPU
{300979F6-A02E-407A-B8DF-F6200806C18D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{300979F6-A02E-407A-B8DF-F6200806C18D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{300979F6-A02E-407A-B8DF-F6200806C18D}.Release|Any CPU.ActiveCfg = Release|Any CPU
@ -219,13 +189,8 @@ Global
{C4AEAB04-F341-4539-B6C0-52368FB4BF9E} = {C4BC9889-B49F-41B6-806B-F84941B2549B}
{1715EA8D-8E13-4ACF-8BCA-57D048E55ED8} = {DA69F624-5398-4884-87E4-B816698CDE65}
{AAD719D5-5E31-4ED1-A60F-6EB92EFA66D9} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{5D9DA986-2EAB-4C6D-BF15-9A4BDD4DE775} = {DA69F624-5398-4884-87E4-B816698CDE65}
{A7050BAE-3DB9-4FB3-A49D-303201415B13} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{42E76F87-92B6-45AB-BF07-6B811C0F2CAC} = {DA69F624-5398-4884-87E4-B816698CDE65}
{59319B72-38BE-4041-8E5C-FF6938874CE8} = {DA69F624-5398-4884-87E4-B816698CDE65}
{FFFE71F8-E476-4BCD-9689-F106EE1C1497} = {DA69F624-5398-4884-87E4-B816698CDE65}
{8CBC1C71-AF0B-44E2-AEE9-D8024C07634D} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{58E771EC-8454-4558-B61A-C9D049065911} = {6A35B453-52EC-48AF-89CA-D4A69800F131}
{300979F6-A02E-407A-B8DF-F6200806C18D} = {C4BC9889-B49F-41B6-806B-F84941B2549B}
{8D789F94-CB74-45FD-ACE7-92AF6E55042E} = {C4BC9889-B49F-41B6-806B-F84941B2549B}
{A0BF246B-FE7D-4E12-99BF-FFDC131B85D8} = {3A76C5A2-79ED-49BC-8BDC-6A3A766FFA1B}

View File

@ -1,4 +1,4 @@
<!DOCTYPE html>
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
@ -65,11 +65,11 @@ function invoke(connection, method, ...args) {
console.log("invocation completed successfully: " + (result === null ? '(null)' : result));
if (result) {
addLine('message', result);
addLine('messages', result);
}
})
.catch(err => {
addLine('message', err, 'red');
addLine('messages', err, 'red');
});
}
@ -83,15 +83,15 @@ document.getElementById('head1').innerHTML = signalR.TransportType[transportType
let connection = new signalR.HubConnection(`http://${document.location.host}/hubs`, 'formatType=json&format=text');
connection.on('Send', msg => {
addLine('message', msg);
addLine('messages', msg);
});
connection.onClosed = e => {
if (e) {
addLine('message', 'Connection closed with error: ' + e, 'red');
addLine('messages', 'Connection closed with error: ' + e, 'red');
}
else {
addLine('message', 'Disconnected', 'green');
addLine('messages', 'Disconnected', 'green');
}
}
@ -99,10 +99,10 @@ click('connect', event => {
connection.start(transportType)
.then(() => {
isConnected = true;
addLine('message', 'Connected successfully', 'green');
addLine('messages', 'Connected successfully', 'green');
})
.catch(err => {
addLine('message', err, 'red');
addLine('messages', err, 'red');
});
});
@ -130,7 +130,7 @@ click('leave-group', event => {
click('groupmsg', event => {
let groupName = getText('target');
let message = getText('message');
let message = getText('messages');
invoke(connection, 'SendToGroup', groupName, message);
});

View File

@ -13,7 +13,6 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
@ -89,7 +88,7 @@ namespace Microsoft.AspNetCore.Sockets
await DoPersistentConnection(socketDelegate, sse, context, state);
}
else if (context.Features.Get<IHttpWebSocketConnectionFeature>()?.IsWebSocketRequest == true)
else if (context.WebSockets.IsWebSocketRequest)
{
// Connection can be established lazily
var state = await GetOrCreateConnectionAsync(context);

View File

@ -14,12 +14,13 @@
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.WebSockets.Internal\Microsoft.AspNetCore.WebSockets.Internal.csproj" />
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.WebSockets" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="System.Threading.Tasks.Channels" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
</ItemGroup>
</Project>

View File

@ -18,7 +18,7 @@ namespace Microsoft.AspNetCore.Builder
callback(new SocketRouteBuilder(routes, dispatcher));
app.UseWebSocketConnections();
app.UseWebSockets();
app.UseRouter(routes.Build());
return app;
}

View File

@ -2,26 +2,19 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.WebSockets.Internal;
namespace Microsoft.AspNetCore.Sockets.Transports
{
public class WebSocketsTransport : IHttpTransport
{
private readonly WebSocketOptions _options;
private static readonly WebSocketAcceptContext _emptyContext = new WebSocketAcceptContext();
private WebSocketOpcode _lastOpcode = WebSocketOpcode.Continuation;
private bool _lastFrameIncomplete = false;
private readonly ILogger _logger;
private readonly IChannelConnection<Message> _application;
@ -49,11 +42,9 @@ namespace Microsoft.AspNetCore.Sockets.Transports
public async Task ProcessRequestAsync(HttpContext context, CancellationToken token)
{
var feature = context.Features.Get<IHttpWebSocketConnectionFeature>();
Debug.Assert(context.WebSockets.IsWebSocketRequest, "Not a websocket request");
Debug.Assert(feature != null, $"The {nameof(IHttpWebSocketConnectionFeature)} feature is missing!");
using (var ws = await feature.AcceptWebSocketConnectionAsync(_emptyContext))
using (var ws = await context.WebSockets.AcceptWebSocketAsync())
{
_logger.LogInformation("Socket opened.");
@ -62,10 +53,10 @@ namespace Microsoft.AspNetCore.Sockets.Transports
_logger.LogInformation("Socket closed.");
}
public async Task ProcessSocketAsync(IWebSocketConnection socket)
public async Task ProcessSocketAsync(WebSocket socket)
{
// Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync.
var receiving = socket.ExecuteAsync((frame, state) => ((WebSocketsTransport)state).HandleFrame(frame), this);
var receiving = StartReceiving(socket);
var sending = StartSending(socket);
// Wait for something to shut down.
@ -88,7 +79,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
// Shutting down because we received a close frame from the client.
// Complete the input writer so that the application knows there won't be any more input.
_logger.LogDebug("Client closed connection with status code '{0}' ({1}). Signaling end-of-input to application", receiving.Result.Status, receiving.Result.Description);
_logger.LogDebug("Client closed connection with status code '{status}' ({description}). Signaling end-of-input to application", receiving.Result.CloseStatus, receiving.Result.CloseStatusDescription);
_application.Output.TryComplete();
// Wait for the application to finish sending.
@ -96,7 +87,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
await sending;
// Send the server's close frame
await socket.CloseAsync(WebSocketCloseStatus.NormalClosure);
await socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
}
else
{
@ -104,7 +95,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
// The application finished sending. Close our end of the connection
_logger.LogDebug(!failed ? "Application finished sending. Sending close frame." : "Application failed during sending. Sending InternalServerError close frame");
await socket.CloseAsync(!failed ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.InternalServerError);
await socket.CloseOutputAsync(!failed ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.InternalServerError, "", CancellationToken.None);
// Now trigger the exception from the application, if there was one.
sending.GetAwaiter().GetResult();
@ -126,69 +117,93 @@ namespace Microsoft.AspNetCore.Sockets.Transports
}
}
private Task HandleFrame(WebSocketFrame frame)
private async Task<WebSocketReceiveResult> StartReceiving(WebSocket socket)
{
// Is this a frame we care about?
if (!frame.Opcode.IsMessage())
// REVIEW: This code was copied from the client, it's highly unoptimized at the moment (especially
// for server logic)
var incomingMessage = new List<ArraySegment<byte>>();
while (true)
{
return Task.CompletedTask;
}
const int bufferSize = 4096;
var totalBytes = 0;
WebSocketReceiveResult receiveResult;
do
{
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
LogFrame("Receiving", frame);
// Exceptions are handled above where the send and receive tasks are being run.
receiveResult = await socket.ReceiveAsync(buffer, CancellationToken.None);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
return receiveResult;
}
// Determine the effective opcode based on the continuation.
var effectiveOpcode = frame.Opcode;
if (frame.Opcode == WebSocketOpcode.Continuation)
{
effectiveOpcode = _lastOpcode;
}
else
{
_lastOpcode = frame.Opcode;
}
_logger.LogDebug("Message received. Type: {messageType}, size: {size}, EndOfMessage: {endOfMessage}",
receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage);
// Create a Message for the frame
// This has to copy the buffer :(.
var message = new Message(frame.Payload.ToArray(), effectiveOpcode == WebSocketOpcode.Binary ? MessageType.Binary : MessageType.Text, frame.EndOfMessage);
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
incomingMessage.Add(truncBuffer);
totalBytes += receiveResult.Count;
} while (!receiveResult.EndOfMessage);
// Write the message to the channel
return _application.Output.WriteAsync(message);
}
// Making sure the message type is either text or binary
Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type");
private void LogFrame(string action, WebSocketFrame frame)
{
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug(
$"{action} frame: Opcode={frame.Opcode}, Fin={frame.EndOfMessage}, Payload={frame.Payload.Length} bytes");
Message message;
var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? MessageType.Binary : MessageType.Text;
if (incomingMessage.Count > 1)
{
var messageBuffer = new byte[totalBytes];
var offset = 0;
for (var i = 0; i < incomingMessage.Count; i++)
{
Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count);
offset += incomingMessage[i].Count;
}
message = new Message(messageBuffer, messageType, receiveResult.EndOfMessage);
}
else
{
var buffer = new byte[incomingMessage[0].Count];
Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, buffer, 0, incomingMessage[0].Count);
message = new Message(buffer, messageType, receiveResult.EndOfMessage);
}
_logger.LogInformation("Passing message to application. Payload size: {length}", message.Payload.Length);
while (await _application.Output.WaitToWriteAsync())
{
if (_application.Output.TryWrite(message))
{
incomingMessage.Clear();
break;
}
}
}
}
private async Task StartSending(IWebSocketConnection ws)
private async Task StartSending(WebSocket ws)
{
while (await _application.Input.WaitToReadAsync())
{
// Get a frame from the application
Message message;
while (_application.Input.TryRead(out message))
while (_application.Input.TryRead(out var message))
{
if (message.Payload.Length > 0)
{
try
{
var opcode = message.Type == MessageType.Binary ?
WebSocketOpcode.Binary :
WebSocketOpcode.Text;
var messageType = message.Type == MessageType.Binary ?
WebSocketMessageType.Binary :
WebSocketMessageType.Text;
var frame = new WebSocketFrame(
endOfMessage: message.EndOfMessage,
opcode: _lastFrameIncomplete ? WebSocketOpcode.Continuation : opcode,
payload: ReadableBuffer.Create(message.Payload));
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Sending Type: {messageType}, size: {size}, EndOfMessage: {endOfMessage}",
message.Type, message.EndOfMessage, message.Payload.Length);
}
_lastFrameIncomplete = !message.EndOfMessage;
LogFrame("Sending", frame);
await ws.SendAsync(frame);
await ws.SendAsync(new ArraySegment<byte>(message.Payload), messageType, message.EndOfMessage, CancellationToken.None);
}
catch (Exception ex)
{

View File

@ -1,21 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.WebSockets.Internal
{
public static class Constants
{
public static class Headers
{
public const string Upgrade = "Upgrade";
public const string UpgradeWebSocket = "websocket";
public const string Connection = "Connection";
public const string ConnectionUpgrade = "Upgrade";
public const string SecWebSocketKey = "Sec-WebSocket-Key";
public const string SecWebSocketVersion = "Sec-WebSocket-Version";
public const string SecWebSocketProtocol = "Sec-WebSocket-Protocol";
public const string SecWebSocketAccept = "Sec-WebSocket-Accept";
public const string SupportedVersion = "13";
}
}
}

View File

@ -1,106 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Text;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.WebSockets.Internal
{
public static class HandshakeHelpers
{
// Verify Method, Upgrade, Connection, version, key, etc..
public static bool CheckSupportedWebSocketRequest(HttpRequest request)
{
bool validUpgrade = false, validConnection = false, validKey = false, validVersion = false;
if (!string.Equals("GET", request.Method, StringComparison.OrdinalIgnoreCase))
{
return false;
}
foreach (var value in request.Headers.GetCommaSeparatedValues(Constants.Headers.Connection))
{
if (string.Equals(Constants.Headers.ConnectionUpgrade, value, StringComparison.OrdinalIgnoreCase))
{
validConnection = true;
break;
}
}
foreach (var pair in request.Headers)
{
if (string.Equals(Constants.Headers.Upgrade, pair.Key, StringComparison.OrdinalIgnoreCase))
{
if (string.Equals(Constants.Headers.UpgradeWebSocket, pair.Value, StringComparison.OrdinalIgnoreCase))
{
validUpgrade = true;
}
}
else if (string.Equals(Constants.Headers.SecWebSocketVersion, pair.Key, StringComparison.OrdinalIgnoreCase))
{
if (string.Equals(Constants.Headers.SupportedVersion, pair.Value, StringComparison.OrdinalIgnoreCase))
{
validVersion = true;
}
}
else if (string.Equals(Constants.Headers.SecWebSocketKey, pair.Key, StringComparison.OrdinalIgnoreCase))
{
validKey = IsRequestKeyValid(pair.Value);
}
}
return validConnection && validUpgrade && validVersion && validKey;
}
public static IEnumerable<KeyValuePair<string, string>> GenerateResponseHeaders(string key, string subProtocol)
{
yield return new KeyValuePair<string, string>(Constants.Headers.Connection, Constants.Headers.ConnectionUpgrade);
yield return new KeyValuePair<string, string>(Constants.Headers.Upgrade, Constants.Headers.UpgradeWebSocket);
yield return new KeyValuePair<string, string>(Constants.Headers.SecWebSocketAccept, CreateResponseKey(key));
if (!string.IsNullOrWhiteSpace(subProtocol))
{
yield return new KeyValuePair<string, string>(Constants.Headers.SecWebSocketProtocol, subProtocol);
}
}
/// <summary>
/// Validates the Sec-WebSocket-Key request header
/// "The value of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has been base64-encoded."
/// </summary>
/// <param name="value"></param>
/// <returns></returns>
public static bool IsRequestKeyValid(string value)
{
if (string.IsNullOrWhiteSpace(value))
{
return false;
}
return value.Length == 24;
}
/// <summary>
/// "...the base64-encoded SHA-1 of the concatenation of the |Sec-WebSocket-Key| (as a string, not base64-decoded) with the string
/// '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'"
/// </summary>
/// <param name="requestKey"></param>
/// <returns></returns>
public static string CreateResponseKey(string requestKey)
{
if (requestKey == null)
{
throw new ArgumentNullException(nameof(requestKey));
}
using (var algorithm = SHA1.Create())
{
string merged = requestKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
byte[] mergedBytes = Encoding.UTF8.GetBytes(merged);
byte[] hashedBytes = algorithm.ComputeHash(mergedBytes);
return Convert.ToBase64String(hashedBytes);
}
}
}
}

View File

@ -1,15 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.WebSockets.Internal;
namespace Microsoft.AspNetCore.WebSockets.Internal
{
public interface IHttpWebSocketConnectionFeature
{
bool IsWebSocketRequest { get; }
ValueTask<IWebSocketConnection> AcceptWebSocketConnectionAsync(WebSocketAcceptContext context);
}
}

View File

@ -1,21 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<Description>WebSockets support for ASP.NET Core.</Description>
<VersionPrefix>0.1.0</VersionPrefix>
<TargetFramework>netstandard2.0</TargetFramework>
<NoWarn>$(NoWarn);CS1591</NoWarn>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageTags>aspnetcore;signalr</PackageTags>
<EnableApiCheck>false</EnableApiCheck>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.Extensions.WebSockets.Internal\Microsoft.Extensions.WebSockets.Internal.csproj" />
<PackageReference Include="Microsoft.AspNetCore.Http.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="$(AspNetCoreVersion)" />
</ItemGroup>
</Project>

View File

@ -1,40 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using Microsoft.AspNetCore.WebSockets.Internal;
namespace Microsoft.AspNetCore.Builder
{
public static class WebSocketAppBuilderExtensions
{
public static void UseWebSocketConnections(this IApplicationBuilder app)
{
// Only the GC can clean up this channel factory :(
app.UseWebSocketConnections(new PipeFactory(), new WebSocketConnectionOptions());
}
public static void UseWebSocketConnections(this IApplicationBuilder app, PipeFactory factory)
{
if (factory == null)
{
throw new ArgumentNullException(nameof(factory));
}
app.UseWebSocketConnections(factory, new WebSocketConnectionOptions());
}
public static void UseWebSocketConnections(this IApplicationBuilder app, PipeFactory factory, WebSocketConnectionOptions options)
{
if (factory == null)
{
throw new ArgumentNullException(nameof(factory));
}
if (options == null)
{
throw new ArgumentNullException(nameof(options));
}
app.UseMiddleware<WebSocketConnectionMiddleware>(factory, options);
}
}
}

View File

@ -1,79 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.WebSockets.Internal;
namespace Microsoft.AspNetCore.WebSockets.Internal
{
internal class WebSocketConnectionFeature : IHttpWebSocketConnectionFeature
{
private HttpContext _context;
private IHttpUpgradeFeature _upgradeFeature;
private ILogger _logger;
private readonly PipeFactory _factory;
public bool IsWebSocketRequest
{
get
{
if (!_upgradeFeature.IsUpgradableRequest)
{
return false;
}
return HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request);
}
}
public WebSocketConnectionFeature(HttpContext context, PipeFactory factory, IHttpUpgradeFeature upgradeFeature, ILoggerFactory loggerFactory)
{
_factory = factory;
_context = context;
_upgradeFeature = upgradeFeature;
_logger = loggerFactory.CreateLogger<WebSocketConnectionFeature>();
}
public ValueTask<IWebSocketConnection> AcceptWebSocketConnectionAsync(WebSocketAcceptContext acceptContext)
{
if (!IsWebSocketRequest)
{
throw new InvalidOperationException("Not a WebSocket request."); // TODO: LOC
}
string subProtocol = null;
if (acceptContext != null)
{
subProtocol = acceptContext.SubProtocol;
}
_logger.LogDebug("WebSocket Handshake completed. SubProtocol: {0}", subProtocol);
var key = string.Join(", ", _context.Request.Headers[Constants.Headers.SecWebSocketKey]);
var responseHeaders = HandshakeHelpers.GenerateResponseHeaders(key, subProtocol);
foreach (var headerPair in responseHeaders)
{
_context.Response.Headers[headerPair.Key] = headerPair.Value;
}
// TODO: Avoid task allocation if there's a ValueTask-based UpgradeAsync?
return new ValueTask<IWebSocketConnection>(AcceptWebSocketConnectionCoreAsync(subProtocol));
}
private async Task<IWebSocketConnection> AcceptWebSocketConnectionCoreAsync(string subProtocol)
{
_logger.LogDebug("Upgrading connection to WebSockets");
var opaqueTransport = await _upgradeFeature.UpgradeAsync();
var connection = new WebSocketConnection(
opaqueTransport.AsPipelineReader(),
_factory.CreateWriter(opaqueTransport),
subProtocol: subProtocol);
return connection;
}
}
}

View File

@ -1,59 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.WebSockets.Internal
{
public class WebSocketConnectionMiddleware
{
private readonly PipeFactory _factory;
private readonly ILoggerFactory _loggerFactory;
private readonly RequestDelegate _next;
private readonly WebSocketConnectionOptions _options;
public WebSocketConnectionMiddleware(RequestDelegate next, PipeFactory factory, WebSocketConnectionOptions options, ILoggerFactory loggerFactory)
{
if (next == null)
{
throw new ArgumentNullException(nameof(next));
}
if (factory == null)
{
throw new ArgumentNullException(nameof(factory));
}
if (options == null)
{
throw new ArgumentNullException(nameof(options));
}
if (loggerFactory == null)
{
throw new ArgumentNullException(nameof(loggerFactory));
}
_next = next;
_loggerFactory = loggerFactory;
_factory = factory;
_options = options;
}
public Task Invoke(HttpContext context)
{
var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>();
if (upgradeFeature != null)
{
if (_options.ReplaceFeature || context.Features.Get<IHttpWebSocketConnectionFeature>() == null)
{
context.Features.Set<IHttpWebSocketConnectionFeature>(new WebSocketConnectionFeature(context, _factory, upgradeFeature, _loggerFactory));
}
}
return _next(context);
}
}
}

View File

@ -1,10 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.WebSockets.Internal
{
public class WebSocketConnectionOptions
{
public bool ReplaceFeature { get; set; }
}
}

View File

@ -1,163 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Extensions.WebSockets.Internal
{
/// <summary>
/// Represents a connection to a WebSocket endpoint.
/// </summary>
/// <remarks>
/// <para>
/// Implementors of this type are generally considered thread-safe under the following condition: No two threads attempt to call either
/// <see cref="ExecuteAsync"/> or <see cref="SendAsync"/> simultaneously. Different threads may call each method, but the same method
/// cannot be re-entered while it is being run in a different thread. However, ensure you verify that the specific implementor is
/// thread-safe in this way. For example, <see cref="WebSocketConnection"/> (including the implementations returned by the
/// static factory methods on that type) is thread-safe in this way.
/// </para>
/// <para>
/// The general pattern of having a single thread running <see cref="ExecuteAsync"/> and a separate thread running <see cref="SendAsync"/> will
/// be thread-safe, as each method interacts with completely separate state.
/// </para>
/// </remarks>
public interface IWebSocketConnection : IDisposable
{
/// <summary>
/// Gets the sub-protocol value configured during handshaking.
/// </summary>
string SubProtocol { get; }
/// <summary>
/// Gets the current state of the connection
/// </summary>
WebSocketConnectionState State { get; }
/// <summary>
/// Sends the specified frame.
/// </summary>
/// <param name="message">The message to send.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> that indicates when/if the send is cancelled.</param>
/// <returns>A <see cref="Task"/> that completes when the message has been written to the outbound stream.</returns>
Task SendAsync(WebSocketFrame message, CancellationToken cancellationToken);
/// <summary>
/// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame.
/// </summary>
/// <remarks>
/// If the other party does not respond with a close frame, the connection will remain open and the <see cref="Task{WebSocketCloseResult}"/>
/// will remain active. Call the <see cref="IDisposable.Dispose"/> method on this instance to forcibly terminate the connection.
/// </remarks>
/// <param name="result">A <see cref="WebSocketCloseResult"/> with the payload for the close frame</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> that indicates when/if the send is cancelled.</param>
/// <returns>A <see cref="Task"/> that completes when the close frame has been sent</returns>
Task CloseAsync(WebSocketCloseResult result, CancellationToken cancellationToken);
/// <summary>
/// Runs the WebSocket receive loop, using the provided message handler. Note that <see cref="WebSocketOpcode.Ping"/> and
/// <see cref="WebSocketOpcode.Pong"/> frames will be passed to this handler for tracking/logging/monitoring, BUT will automatically be handled.
/// </summary>
/// <param name="messageHandler">The callback that will be invoked for each new frame</param>
/// <param name="state">A state parameter that will be passed to each invocation of <paramref name="messageHandler"/></param>
/// <returns>A <see cref="Task{WebSocketCloseResult}"/> that will complete when the client has sent a close frame, or the connection has been terminated</returns>
Task<WebSocketCloseResult> ExecuteAsync(Func<WebSocketFrame, object, Task> messageHandler, object state);
/// <summary>
/// Forcibly terminates the socket, cleaning up the necessary resources.
/// </summary>
void Abort();
}
public static class WebSocketConnectionExtensions
{
/// <summary>
/// Sends the specified frame.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="message">The message to send.</param>
/// <returns>A <see cref="Task"/> that completes when the message has been written to the outbound stream.</returns>
public static Task SendAsync(this IWebSocketConnection connection, WebSocketFrame message) => connection.SendAsync(message, CancellationToken.None);
/// <summary>
/// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="status">A <see cref="WebSocketCloseStatus"/> value to be sent to the client in the close frame</param>.
/// <returns>A <see cref="Task"/> that completes when the close frame has been sent</returns>
public static Task CloseAsync(this IWebSocketConnection connection, WebSocketCloseStatus status) => connection.CloseAsync(new WebSocketCloseResult(status), CancellationToken.None);
/// <summary>
/// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="status">A <see cref="WebSocketCloseStatus"/> value to be sent to the client in the close frame</param>.
/// <param name="description">A textual description of the reason for closing the connection.</param>
/// <returns>A <see cref="Task"/> that completes when the close frame has been sent</returns>
public static Task CloseAsync(this IWebSocketConnection connection, WebSocketCloseStatus status, string description) => connection.CloseAsync(new WebSocketCloseResult(status, description), CancellationToken.None);
/// <summary>
/// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="status">A <see cref="WebSocketCloseStatus"/> value to be sent to the client in the close frame</param>.
/// <param name="cancellationToken">A <see cref="CancellationToken"/> that indicates when/if the send is cancelled.</param>
/// <returns>A <see cref="Task"/> that completes when the close frame has been sent</returns>
public static Task CloseAsync(this IWebSocketConnection connection, WebSocketCloseStatus status, CancellationToken cancellationToken) => connection.CloseAsync(new WebSocketCloseResult(status), cancellationToken);
/// <summary>
/// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="status">A <see cref="WebSocketCloseStatus"/> value to be sent to the client in the close frame</param>.
/// <param name="description">A textual description of the reason for closing the connection.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> that indicates when/if the send is cancelled.</param>
/// <returns>A <see cref="Task"/> that completes when the close frame has been sent</returns>
public static Task CloseAsync(this IWebSocketConnection connection, WebSocketCloseStatus status, string description, CancellationToken cancellationToken) => connection.CloseAsync(new WebSocketCloseResult(status, description), cancellationToken);
/// <summary>
/// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="result">A <see cref="WebSocketCloseResult"/> with the payload for the close frame.</param>
/// <returns>A <see cref="Task"/> that completes when the close frame has been sent</returns>
public static Task CloseAsync(this IWebSocketConnection connection, WebSocketCloseResult result) => connection.CloseAsync(result, CancellationToken.None);
/// <summary>
/// Runs the WebSocket receive loop, using the provided message handler.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="messageHandler">The callback that will be invoked for each new frame</param>
/// <returns>A <see cref="Task{WebSocketCloseResult}"/> that will complete when the client has sent a close frame, or the connection has been terminated</returns>
public static Task<WebSocketCloseResult> ExecuteAsync(this IWebSocketConnection connection, Action<WebSocketFrame> messageHandler) =>
connection.ExecuteAsync((frame, _) =>
{
messageHandler(frame);
return Task.CompletedTask;
}, null);
/// <summary>
/// Runs the WebSocket receive loop, using the provided message handler.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="messageHandler">The callback that will be invoked for each new frame</param>
/// <param name="state">The state to pass to the callback when the delegate is invoked. This may be null.</param>
/// <returns>A <see cref="Task{WebSocketCloseResult}"/> that will complete when the client has sent a close frame, or the connection has been terminated</returns>
public static Task<WebSocketCloseResult> ExecuteAsync(this IWebSocketConnection connection, Action<WebSocketFrame, object> messageHandler, object state) =>
connection.ExecuteAsync((frame, s) =>
{
messageHandler(frame, s);
return Task.CompletedTask;
}, state);
/// <summary>
/// Runs the WebSocket receive loop, using the provided message handler.
/// </summary>
/// <param name="connection">The <see cref="IWebSocketConnection"/></param>
/// <param name="messageHandler">The callback that will be invoked for each new frame</param>
/// <returns>A <see cref="Task{WebSocketCloseResult}"/> that will complete when the client has sent a close frame, or the connection has been terminated</returns>
public static Task<WebSocketCloseResult> ExecuteAsync(this IWebSocketConnection connection, Func<WebSocketFrame, Task> messageHandler) =>
connection.ExecuteAsync((frame, _) => messageHandler(frame), null);
}
}

View File

@ -1,54 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Binary;
using System.IO.Pipelines;
namespace Microsoft.Extensions.WebSockets.Internal
{
internal static class MaskingUtilities
{
// Plenty of optimization to be done here but not our immediate priority right now.
// Including: Vectorization, striding by uints (even when not vectorized; we'd probably flip the
// overload that does the implementation in that case and do it in the uint version).
public static void ApplyMask(ref ReadableBuffer payload, uint maskingKey)
{
unsafe
{
// Write the masking key as bytes to simplify access. Use a stackalloc buffer because it's fixed-size
var maskingKeyBytes = stackalloc byte[4];
var maskingKeySpan = new Span<byte>(maskingKeyBytes, 4);
maskingKeySpan.WriteBigEndian(maskingKey);
ApplyMask(ref payload, maskingKeySpan);
}
}
public static void ApplyMask(ref ReadableBuffer payload, Span<byte> maskingKey)
{
var offset = 0;
foreach (var mem in payload)
{
var span = mem.Span;
ApplyMask(span, maskingKey, ref offset);
}
}
public static void ApplyMask(Span<byte> payload, Span<byte> maskingKey)
{
var i = 0;
ApplyMask(payload, maskingKey, ref i);
}
private static void ApplyMask(Span<byte> payload, Span<byte> maskingKey, ref int maskingKeyOffset)
{
for (int i = 0; i < payload.Length; i++)
{
payload[i] = (byte)(payload[i] ^ maskingKey[maskingKeyOffset % 4]);
maskingKeyOffset++;
}
}
}
}

View File

@ -1,21 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<Description>Low-allocation Push-oriented WebSockets based on Channels</Description>
<VersionPrefix>0.1.0</VersionPrefix>
<TargetFramework>netstandard2.0</TargetFramework>
<NoWarn>$(NoWarn);CS1591</NoWarn>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageTags>aspnetcore;signalr</PackageTags>
<EnableApiCheck>false</EnableApiCheck>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
</ItemGroup>
</Project>

View File

@ -1,45 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
using System.Threading.Tasks;
namespace Microsoft.Extensions.WebSockets.Internal
{
public static class PipeReaderExtensions
{
// TODO: Pull this up to Channels. We should be able to do it there without allocating a Task<T> in any case (rather than here where we can avoid allocation
// only if the buffer is already ready and has enough data)
public static async ValueTask<ReadResult> ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes)
{
var awaiter = input.ReadAsync(/* cancellationToken */);
// Short-cut path!
ReadResult result;
if (awaiter.IsCompleted)
{
// We have a buffer, is it big enough?
result = awaiter.GetResult();
if (result.IsCompleted || result.Buffer.Length >= minimumRequiredBytes)
{
return result;
}
// Buffer wasn't big enough, mark it as examined and continue to the "slow" path below
input.Advance(
consumed: result.Buffer.Start,
examined: result.Buffer.End);
}
result = await awaiter;
while (!result.IsCancelled && !result.IsCompleted && result.Buffer.Length < minimumRequiredBytes)
{
input.Advance(
consumed: result.Buffer.Start,
examined: result.Buffer.End);
result = await input.ReadAsync(/* cancelToken */);
}
return result;
}
}
}

View File

@ -1,7 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("Microsoft.Extensions.WebSockets.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")]

View File

@ -1,143 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.Extensions.WebSockets.Internal
{
/// <summary>
/// Stateful UTF-8 validator.
/// </summary>
public class Utf8Validator
{
// Table of UTF-8 code point widths. '0' indicates an invalid first byte.
// 0x80 - 0xBF are the continuation bytes and invalid as first byte.
// 0xC0 - 0xC1 are overlong encodings of ASCII characters
// 0xF5 - 0xFF encode numbers that are larger than the Unicode limit (0x10FFFF)
private static readonly byte[] _utf8Width = new byte[256]
{
/* 0x00 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x0F */
/* 0x10 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x1F */
/* 0x20 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x2F */
/* 0x30 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x3F */
/* 0x40 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x4F */
/* 0x50 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x5F */
/* 0x60 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x6F */
/* 0x70 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* 0x7F */
/* 0x80 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0x8F */
/* 0x90 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0x9F */
/* 0xA0 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0xAF */
/* 0xB0 */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0xBF */
/* 0xC0 */ 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, /* 0xCF */
/* 0xD0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, /* 0xDF */
/* 0xE0 */ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, /* 0xEF */
/* 0xF0 */ 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0xFF */
};
// Table of masks used to extract the code point bits from the first byte. Indexed by (width - 1)
private static readonly byte[] _utf8Mask = new byte[4] { 0x7F, 0x1F, 0x0F, 0x07 };
// Table of minimum valid code-points based on the width. Indexed by (width - 1)
private static readonly int[] _utf8Min = new int[4] { 0x00000, 0x00080, 0x00800, 0x10000 };
private struct Utf8ValidatorState
{
public bool _withinSequence;
public int _remainingBytesInChar;
public int _currentDecodedValue;
public int _minCodePoint;
public void Reset()
{
_withinSequence = false;
_remainingBytesInChar = 0;
_currentDecodedValue = 0;
_minCodePoint = 0;
}
}
private Utf8ValidatorState _state;
public void Reset()
{
_state.Reset();
}
public bool ValidateUtf8Frame(ReadableBuffer payload, bool fin) => ValidateUtf8(ref _state, payload, fin);
public static bool ValidateUtf8(ReadableBuffer payload)
{
var state = new Utf8ValidatorState();
return ValidateUtf8(ref state, payload, fin: true);
}
private static bool ValidateUtf8(ref Utf8ValidatorState state, ReadableBuffer payload, bool fin)
{
// Walk through the payload verifying it
var offset = 0;
foreach (var mem in payload)
{
var span = mem.Span;
for (int i = 0; i < span.Length; i++)
{
var b = span[i];
if (!state._withinSequence)
{
// This is the first byte of a char, so set things up
var width = _utf8Width[b];
state._remainingBytesInChar = width - 1;
if (state._remainingBytesInChar < 0)
{
// Invalid first byte
return false;
}
// Use the width (-1) to index into the mask and min tables.
state._currentDecodedValue = b & _utf8Mask[width - 1];
state._minCodePoint = _utf8Min[width - 1];
state._withinSequence = true;
}
else
{
// Add this byte to the value
state._currentDecodedValue = (state._currentDecodedValue << 6) | (b & 0x3F);
state._remainingBytesInChar--;
}
// Fast invalid exits
if (state._remainingBytesInChar == 1 && state._currentDecodedValue >= 0x360 && state._currentDecodedValue <= 0x37F)
{
// This will be a UTF-16 surrogate: 0xD800-0xDFFF
return false;
}
if (state._remainingBytesInChar == 2 && state._currentDecodedValue >= 0x110)
{
// This will be above the maximum Unicode character (0x10FFFF).
return false;
}
if (state._remainingBytesInChar == 0)
{
// Check the range of the final decoded value
if (state._currentDecodedValue < state._minCodePoint)
{
// This encoding is longer than it should be, which is not allowed.
return false;
}
// Reset state
state._withinSequence = false;
}
offset++;
}
}
// We're done.
// The value is valid if:
// 1. We haven't reached the end of the whole message yet (we'll be caching this state for the next message)
// 2. We aren't inside a character sequence (i.e. the last character isn't unterminated)
return !fin || !state._withinSequence;
}
}
}

View File

@ -1,78 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Binary;
using System.IO.Pipelines;
using System.IO.Pipelines.Text.Primitives;
using System.Text;
using System.Text.Formatting;
namespace Microsoft.Extensions.WebSockets.Internal
{
/// <summary>
/// Represents the payload of a Close frame (i.e. a <see cref="WebSocketFrame"/> with an <see cref="WebSocketFrame.Opcode"/> of <see cref="WebSocketOpcode.Close"/>).
/// </summary>
public struct WebSocketCloseResult
{
internal static WebSocketCloseResult AbnormalClosure = new WebSocketCloseResult(WebSocketCloseStatus.AbnormalClosure, "Underlying transport connection was terminated");
internal static WebSocketCloseResult Empty = new WebSocketCloseResult(WebSocketCloseStatus.Empty);
/// <summary>
/// Gets the close status code specified in the frame.
/// </summary>
public WebSocketCloseStatus Status { get; }
/// <summary>
/// Gets the close status description specified in the frame.
/// </summary>
public string Description { get; }
public WebSocketCloseResult(WebSocketCloseStatus status) : this(status, string.Empty) { }
public WebSocketCloseResult(WebSocketCloseStatus status, string description)
{
Status = status;
Description = description;
}
public int GetSize() => Encoding.UTF8.GetByteCount(Description) + sizeof(ushort);
public static bool TryParse(ReadableBuffer payload, out WebSocketCloseResult result, out ushort? actualCloseCode)
{
if (payload.Length == 0)
{
// Empty payload is OK
actualCloseCode = null;
result = new WebSocketCloseResult(WebSocketCloseStatus.Empty, string.Empty);
return true;
}
else if (payload.Length < 2)
{
actualCloseCode = null;
result = default(WebSocketCloseResult);
return false;
}
else
{
var status = payload.ReadBigEndian<ushort>();
actualCloseCode = status;
var description = string.Empty;
payload = payload.Slice(2);
if (payload.Length > 0)
{
description = payload.GetUtf8String();
}
result = new WebSocketCloseResult((WebSocketCloseStatus)status, description);
return true;
}
}
public void WriteTo(ref WritableBuffer buffer)
{
buffer.WriteBigEndian((ushort)Status);
if (!string.IsNullOrEmpty(Description))
{
buffer.Append(Description, TextEncoder.Utf8);
}
}
}
}

View File

@ -1,77 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.Extensions.WebSockets.Internal
{
/// <summary>
/// Represents well-known WebSocket Close frame status codes.
/// </summary>
/// <remarks>
/// See https://tools.ietf.org/html/rfc6455#section-7.4 for details
/// </remarks>
public enum WebSocketCloseStatus : ushort
{
/// <summary>
/// Indicates that the purpose for the connection was fulfilled and thus the connection was closed normally.
/// </summary>
NormalClosure = 1000,
/// <summary>
/// Indicates that the other endpoint is going away, such as a server shutting down or a browser navigating to a new page.
/// </summary>
EndpointUnavailable = 1001,
/// <summary>
/// Indicates that a protocol error has occurred, causing the connection to be terminated.
/// </summary>
ProtocolError = 1002,
/// <summary>
/// Indicates an invalid message type was received. For example, if the end point only supports <see cref="WebSocketOpcode.Text"/> messages
/// but received a <see cref="WebSocketOpcode.Binary"/> message.
/// </summary>
InvalidMessageType = 1003,
/// <summary>
/// Indicates that the Close frame did not have a status code. Not used in actual transmission.
/// </summary>
Empty = 1005,
/// <summary>
/// Indicates that the underlying transport connection was terminated without a proper close handshake. Not used in actual transmission.
/// </summary>
AbnormalClosure = 1006,
/// <summary>
/// Indicates that an invalid payload was encountered. For example, a frame of type <see cref="WebSocketOpcode.Text"/> contained non-UTF-8 data.
/// </summary>
InvalidPayloadData = 1007,
/// <summary>
/// Indicates that the connection is being terminated due to a violation of policy. This is a generic error code used whenever a party needs to terminate
/// a connection without disclosing the specific reason.
/// </summary>
PolicyViolation = 1008,
/// <summary>
/// Indicates that the connection is being terminated due to an endpoint receiving a message that is too large.
/// </summary>
MessageTooBig = 1009,
/// <summary>
/// Indicates that the connection is being terminated due to being unable to negotiate a mandatory extension with the other party. Usually sent
/// from the client to the server after the client finishes handshaking without negotiating the extension.
/// </summary>
MandatoryExtension = 1010,
/// <summary>
/// Indicates that a server is terminating the connection due to an internal error.
/// </summary>
InternalServerError = 1011,
/// <summary>
/// Indicates that the connection failed to establish because the TLS handshake failed. Not used in actual transmission.
/// </summary>
TLSHandshakeFailed = 1015
}
}

View File

@ -1,770 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Binary;
using System.Diagnostics;
using System.Globalization;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Extensions.WebSockets.Internal
{
/// <summary>
/// Provides the default implementation of <see cref="IWebSocketConnection"/>.
/// </summary>
/// <remarks>
/// <para>
/// This type is thread-safe, as long as only one thread ever calls <see cref="ExecuteAsync"/>. Multiple threads may call <see cref="SendAsync"/> simultaneously
/// and the sends will block until ongoing send operations complete.
/// </para>
/// <para>
/// The general pattern of having a single thread running <see cref="ExecuteAsync"/> and a separate thread running <see cref="SendAsync"/> will
/// be thread-safe, as each method interacts with completely separate state.
/// </para>
/// </remarks>
public class WebSocketConnection : IWebSocketConnection
{
private WebSocketOptions _options;
private readonly byte[] _maskingKeyBuffer;
private readonly IPipeReader _inbound;
private readonly IPipeWriter _outbound;
private readonly Timer _pinger;
private readonly CancellationTokenSource _timerCts = new CancellationTokenSource();
private Utf8Validator _validator = new Utf8Validator();
private WebSocketOpcode _currentMessageType = WebSocketOpcode.Continuation;
// Sends must be serialized between SendAsync, Pinger, and the Close frames sent when invalid messages are received.
private SemaphoreSlim _sendLock = new SemaphoreSlim(1, 1);
public string SubProtocol { get; }
public WebSocketConnectionState State { get; private set; } = WebSocketConnectionState.Created;
/// <summary>
/// Constructs a new, unmasked, <see cref="WebSocketConnection"/> from an <see cref="IPipeReader"/> and an <see cref="IPipeWriter"/> that represents an established WebSocket connection (i.e. after handshaking)
/// </summary>
/// <param name="inbound">A <see cref="IPipeReader"/> from which frames will be read when receiving.</param>
/// <param name="outbound">A <see cref="IPipeWriter"/> to which frame will be written when sending.</param>
public WebSocketConnection(IPipeReader inbound, IPipeWriter outbound) : this(inbound, outbound, options: WebSocketOptions.DefaultUnmasked) { }
/// <summary>
/// Constructs a new, unmasked, <see cref="WebSocketConnection"/> from an <see cref="IPipeReader"/> and an <see cref="IPipeWriter"/> that represents an established WebSocket connection (i.e. after handshaking)
/// </summary>
/// <param name="inbound">A <see cref="IPipeReader"/> from which frames will be read when receiving.</param>
/// <param name="outbound">A <see cref="IPipeWriter"/> to which frame will be written when sending.</param>
/// <param name="subProtocol">The sub-protocol provided during handshaking</param>
public WebSocketConnection(IPipeReader inbound, IPipeWriter outbound, string subProtocol) : this(inbound, outbound, subProtocol, options: WebSocketOptions.DefaultUnmasked) { }
/// <summary>
/// Constructs a new, <see cref="WebSocketConnection"/> from an <see cref="IPipeReader"/> and an <see cref="IPipeWriter"/> that represents an established WebSocket connection (i.e. after handshaking)
/// </summary>
/// <param name="inbound">A <see cref="IPipeReader"/> from which frames will be read when receiving.</param>
/// <param name="outbound">A <see cref="IPipeWriter"/> to which frame will be written when sending.</param>
/// <param name="options">A <see cref="WebSocketOptions"/> which provides the configuration options for the socket.</param>
public WebSocketConnection(IPipeReader inbound, IPipeWriter outbound, WebSocketOptions options) : this(inbound, outbound, subProtocol: string.Empty, options: options) { }
/// <summary>
/// Constructs a new <see cref="WebSocketConnection"/> from an <see cref="IPipeReader"/> and an <see cref="IPipeWriter"/> that represents an established WebSocket connection (i.e. after handshaking)
/// </summary>
/// <param name="inbound">A <see cref="IPipeReader"/> from which frames will be read when receiving.</param>
/// <param name="outbound">A <see cref="IPipeWriter"/> to which frame will be written when sending.</param>
/// <param name="subProtocol">The sub-protocol provided during handshaking</param>
/// <param name="options">A <see cref="WebSocketOptions"/> which provides the configuration options for the socket.</param>
public WebSocketConnection(IPipeReader inbound, IPipeWriter outbound, string subProtocol, WebSocketOptions options)
{
_inbound = inbound;
_outbound = outbound;
_options = options;
SubProtocol = subProtocol;
if (_options.FixedMaskingKey != null)
{
// Use the fixed key directly as the buffer.
_maskingKeyBuffer = _options.FixedMaskingKey;
// Clear the MaskingKeyGenerator just to ensure that nobody set it.
_options.MaskingKeyGenerator = null;
}
else if (_options.MaskingKeyGenerator != null)
{
// Establish a buffer for the random generator to use
_maskingKeyBuffer = new byte[4];
}
if (_options.PingInterval > TimeSpan.Zero)
{
var pingIntervalMillis = (int)_options.PingInterval.TotalMilliseconds;
// Set up the pinger
_pinger = new Timer(Pinger, this, pingIntervalMillis, pingIntervalMillis);
}
}
private static void Pinger(object state)
{
var connection = (WebSocketConnection)state;
// If we are cancelled, don't send the ping
// Also, if we can't immediately acquire the send lock, we're already sending something, so we don't need the ping.
if (!connection._timerCts.Token.IsCancellationRequested && connection._sendLock.Wait(0))
{
// We don't need to wait for this task to complete, we're "tail calling" and
// we are in a Timer thread-pool thread.
var ignore = connection.SendCoreLockAcquiredAsync(
fin: true,
opcode: WebSocketOpcode.Ping,
payloadAllocLength: 28,
payloadLength: 28,
payloadWriter: PingPayloadWriter,
payload: DateTime.UtcNow,
cancellationToken: connection._timerCts.Token);
}
}
public void Dispose()
{
State = WebSocketConnectionState.Closed;
_pinger?.Dispose();
_timerCts.Cancel();
_inbound.Complete();
_outbound.Complete();
}
public Task<WebSocketCloseResult> ExecuteAsync(Func<WebSocketFrame, object, Task> messageHandler, object state)
{
if (State == WebSocketConnectionState.Closed)
{
throw new ObjectDisposedException(nameof(WebSocketConnection));
}
if (State != WebSocketConnectionState.Created)
{
throw new InvalidOperationException("Connection is already running.");
}
State = WebSocketConnectionState.Connected;
return ReceiveLoop(messageHandler, state);
}
/// <summary>
/// Sends the specified frame.
/// </summary>
/// <param name="frame">The frame to send.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> that indicates when/if the send is cancelled.</param>
/// <returns>A <see cref="Task"/> that completes when the message has been written to the outbound stream.</returns>
// TODO: De-taskify this to allow consumers to create their own awaiter.
public Task SendAsync(WebSocketFrame frame, CancellationToken cancellationToken)
{
if (State == WebSocketConnectionState.Closed)
{
throw new ObjectDisposedException(nameof(WebSocketConnection));
}
// This clause is a bit of an artificial restriction to ensure people run "Execute". Maybe we don't care?
else if (State == WebSocketConnectionState.Created)
{
throw new InvalidOperationException($"Cannot send until the connection is started using {nameof(ExecuteAsync)}");
}
else if (State == WebSocketConnectionState.CloseSent)
{
throw new InvalidOperationException("Cannot send after sending a Close frame");
}
if (frame.Opcode == WebSocketOpcode.Close)
{
throw new InvalidOperationException($"Cannot use {nameof(SendAsync)} to send a Close frame, use {nameof(CloseAsync)} instead.");
}
return SendCoreAsync(
fin: frame.EndOfMessage,
opcode: frame.Opcode,
payloadAllocLength: 0, // We don't copy the payload, we append it, so we don't need any alloc for the payload
payloadLength: frame.Payload.Length,
payloadWriter: AppendPayloadWriter,
payload: frame.Payload,
cancellationToken: cancellationToken);
}
/// <summary>
/// Sends a Close frame to the other party. This does not guarantee that the client will send a responding close frame.
/// </summary>
/// <remarks>
/// If the other party does not respond with a close frame, the connection will remain open and the <see cref="Task{WebSocketCloseResult}"/>
/// will remain active. Call the <see cref="IDisposable.Dispose"/> method on this instance to forcibly terminate the connection.
/// </remarks>
/// <param name="result">A <see cref="WebSocketCloseResult"/> with the payload for the close frame</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> that indicates when/if the send is cancelled.</param>
/// <returns>A <see cref="Task"/> that completes when the close frame has been sent</returns>
public async Task CloseAsync(WebSocketCloseResult result, CancellationToken cancellationToken)
{
if (State == WebSocketConnectionState.Closed)
{
// Already closed
return;
}
else if (State == WebSocketConnectionState.Created)
{
throw new InvalidOperationException("Cannot send close frame when the connection hasn't been started");
}
else if (State == WebSocketConnectionState.CloseSent)
{
throw new InvalidOperationException("Cannot send multiple close frames");
}
var payloadSize = result.GetSize();
await SendCoreAsync(
fin: true,
opcode: WebSocketOpcode.Close,
payloadAllocLength: payloadSize,
payloadLength: payloadSize,
payloadWriter: CloseResultPayloadWriter,
payload: result,
cancellationToken: cancellationToken);
_timerCts.Cancel();
_pinger?.Dispose();
if (State == WebSocketConnectionState.CloseReceived)
{
State = WebSocketConnectionState.Closed;
}
else
{
State = WebSocketConnectionState.CloseSent;
}
}
private void WriteMaskingKey(Span<byte> buffer)
{
if (_options.MaskingKeyGenerator != null)
{
// Get a new random mask
// Until https://github.com/dotnet/corefx/issues/12323 is fixed we need to use this shared buffer and copy model
// Once we have that fix we should be able to generate the mask directly into the output buffer.
_options.MaskingKeyGenerator.GetBytes(_maskingKeyBuffer);
}
_maskingKeyBuffer.CopyTo(buffer);
}
/// <summary>
/// Terminates the socket abruptly.
/// </summary>
public void Abort()
{
// We duplicate some work from Dispose here, but that's OK.
_timerCts.Cancel();
_inbound.CancelPendingRead();
_outbound.Complete();
}
private async ValueTask<(bool Success, byte OpcodeByte, bool Masked, bool Fin, int Length, uint MaskingKey)> ReadHeaderAsync()
{
// Read at least 2 bytes
var readResult = await _inbound.ReadAtLeastAsync(2);
if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < 2))
{
_inbound.Advance(readResult.Buffer.End);
return (Success: false, OpcodeByte: 0, Masked: false, Fin: false, Length: 0, MaskingKey: 0);
}
var buffer = readResult.Buffer;
// Read the opcode and length
var opcodeByte = buffer.ReadBigEndian<byte>();
buffer = buffer.Slice(1);
// Read the first byte of the payload length
var lengthByte = buffer.ReadBigEndian<byte>();
buffer = buffer.Slice(1);
_inbound.Advance(buffer.Start);
// Determine how much header there still is to read
var fin = (opcodeByte & 0x80) != 0;
var masked = (lengthByte & 0x80) != 0;
var length = lengthByte & 0x7F;
// Calculate the rest of the header length
var headerLength = masked ? 4 : 0;
if (length == 126)
{
headerLength += 2;
}
else if (length == 127)
{
headerLength += 8;
}
// Read the next set of header data
uint maskingKey = 0;
if (headerLength > 0)
{
readResult = await _inbound.ReadAtLeastAsync(headerLength);
if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < headerLength))
{
_inbound.Advance(readResult.Buffer.End);
return (Success: false, OpcodeByte: 0, Masked: false, Fin: false, Length: 0, MaskingKey: 0);
}
buffer = readResult.Buffer;
// Read extended payload length (if any)
if (length == 126)
{
length = buffer.ReadBigEndian<ushort>();
buffer = buffer.Slice(sizeof(ushort));
}
else if (length == 127)
{
var longLen = buffer.ReadBigEndian<ulong>();
buffer = buffer.Slice(sizeof(ulong));
if (longLen > int.MaxValue)
{
throw new WebSocketException($"Frame is too large. Maximum frame size is {int.MaxValue} bytes");
}
length = (int)longLen;
}
// Read masking key
if (masked)
{
var maskingKeyStart = buffer.Start;
maskingKey = buffer.Slice(0, sizeof(uint)).ReadBigEndian<uint>();
buffer = buffer.Slice(sizeof(uint));
}
// Mark the length and masking key consumed
_inbound.Advance(buffer.Start);
}
return (Success: true, opcodeByte, masked, fin, length, maskingKey);
}
private async ValueTask<(bool Success, ReadableBuffer Buffer)> ReadPayloadAsync(int length, bool masked, uint maskingKey)
{
var payload = default(ReadableBuffer);
if (length > 0)
{
var readResult = await _inbound.ReadAtLeastAsync(length);
if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < length))
{
return (Success: false, Buffer: readResult.Buffer);
}
var buffer = readResult.Buffer;
payload = buffer.Slice(0, length);
if (masked)
{
// Unmask
MaskingUtilities.ApplyMask(ref payload, maskingKey);
}
}
return (Success: true, Buffer: payload);
}
private async Task<WebSocketCloseResult> ReceiveLoop(Func<WebSocketFrame, object, Task> messageHandler, object state)
{
try
{
while (true)
{
// WebSocket Frame layout (https://tools.ietf.org/html/rfc6455#section-5.2):
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
// |I|S|S|S| (4) |A| (7) | (16/64) |
// |N|V|V|V| |S| | (if payload len==126/127) |
// | |1|2|3| |K| | |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// | Extended payload length continued, if payload len == 127 |
// + - - - - - - - - - - - - - - - +-------------------------------+
// | |Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
// | Masking-key (continued) | Payload Data |
// +-------------------------------- - - - - - - - - - - - - - - - +
// : Payload Data continued ... :
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
// | Payload Data continued ... |
// +---------------------------------------------------------------+
var header = await ReadHeaderAsync();
if (!header.Success)
{
break;
}
// Validate Opcode
var opcodeNum = header.OpcodeByte & 0x0F;
if ((header.OpcodeByte & 0x70) != 0)
{
// Reserved bits set, this frame is invalid, close our side and terminate immediately
await CloseFromProtocolError("Reserved bits, which are required to be zero, were set.");
break;
}
else if ((opcodeNum >= 0x03 && opcodeNum <= 0x07) || (opcodeNum >= 0x0B && opcodeNum <= 0x0F))
{
// Reserved opcode
await CloseFromProtocolError($"Received frame using reserved opcode: 0x{opcodeNum:X}");
break;
}
var opcode = (WebSocketOpcode)opcodeNum;
var payload = await ReadPayloadAsync(header.Length, header.Masked, header.MaskingKey);
if (!payload.Success)
{
_inbound.Advance(payload.Buffer.End);
break;
}
var frame = new WebSocketFrame(header.Fin, opcode, payload.Buffer);
// Start a try-finally because we may get an exception while closing, if there's an error
// And we need to advance the buffer even if that happens. It wasn't needed above because
// we had already parsed the buffer before we verified it, so we had already advanced the
// buffer, if we encountered an error while closing we didn't have to advance the buffer.
// Side Note: Look at this gloriously aligned comment. You have anurse and brecon to thank for it. Oh wait, I ruined it.
try
{
if (frame.Opcode.IsControl() && !frame.EndOfMessage)
{
// Control frames cannot be fragmented.
await CloseFromProtocolError("Control frames may not be fragmented");
break;
}
else if (_currentMessageType != WebSocketOpcode.Continuation && opcode.IsMessage() && opcode != 0)
{
await CloseFromProtocolError("Received non-continuation frame during a fragmented message");
break;
}
else if (_currentMessageType == WebSocketOpcode.Continuation && frame.Opcode == WebSocketOpcode.Continuation)
{
await CloseFromProtocolError("Continuation Frame was received when expecting a new message");
break;
}
if (frame.Opcode == WebSocketOpcode.Close)
{
return await ProcessCloseFrameAsync(frame);
}
else
{
if (frame.Opcode == WebSocketOpcode.Ping)
{
// Check the ping payload length
if (frame.Payload.Length > 125)
{
// Payload too long
await CloseFromProtocolError("Ping frame exceeded maximum size of 125 bytes");
break;
}
await SendCoreAsync(
frame.EndOfMessage,
WebSocketOpcode.Pong,
payloadAllocLength: 0,
payloadLength: frame.Payload.Length,
payloadWriter: AppendPayloadWriter,
payload: frame.Payload,
cancellationToken: CancellationToken.None);
}
var effectiveOpcode = opcode == WebSocketOpcode.Continuation ? _currentMessageType : opcode;
if (effectiveOpcode == WebSocketOpcode.Text && !_validator.ValidateUtf8Frame(frame.Payload, frame.EndOfMessage))
{
// Drop the frame and immediately close with InvalidPayload
await CloseFromProtocolError("An invalid Text frame payload was received", statusCode: WebSocketCloseStatus.InvalidPayloadData);
break;
}
else if (_options.PassAllFramesThrough || (frame.Opcode != WebSocketOpcode.Ping && frame.Opcode != WebSocketOpcode.Pong))
{
await messageHandler(frame, state);
}
}
}
finally
{
if (frame.Payload.Length > 0)
{
_inbound.Advance(frame.Payload.End);
}
}
if (header.Fin)
{
// Reset the UTF8 validator
_validator.Reset();
// If it's a non-control frame, reset the message type tracker
if (opcode.IsMessage())
{
_currentMessageType = WebSocketOpcode.Continuation;
}
}
// If there isn't a current message type, and this was a fragmented message frame, set the current message type
else if (!header.Fin && _currentMessageType == WebSocketOpcode.Continuation && opcode.IsMessage())
{
_currentMessageType = opcode;
}
}
}
catch
{
// Abort the socket and rethrow
Abort();
throw;
}
return WebSocketCloseResult.AbnormalClosure;
}
private async ValueTask<WebSocketCloseResult> ProcessCloseFrameAsync(WebSocketFrame frame)
{
// Allowed frame lengths:
// 0 - No body
// 2 - Code with no reason phrase
// >2 - Code and reason phrase (must be valid UTF-8)
if (frame.Payload.Length > 125)
{
await CloseFromProtocolError("Close frame payload too long. Maximum size is 125 bytes");
return WebSocketCloseResult.AbnormalClosure;
}
else if ((frame.Payload.Length == 1) || (frame.Payload.Length > 2 && !Utf8Validator.ValidateUtf8(frame.Payload.Slice(2))))
{
await CloseFromProtocolError("Close frame payload invalid");
return WebSocketCloseResult.AbnormalClosure;
}
ushort? actualStatusCode;
var closeResult = ParseCloseFrame(frame.Payload, frame, out actualStatusCode);
// Verify the close result
if (actualStatusCode != null)
{
var statusCode = actualStatusCode.Value;
if (statusCode < 1000 || statusCode == 1004 || statusCode == 1005 || statusCode == 1006 || (statusCode > 1011 && statusCode < 3000))
{
await CloseFromProtocolError($"Invalid close status: {statusCode}.");
return WebSocketCloseResult.AbnormalClosure;
}
}
return closeResult;
}
private async Task CloseFromProtocolError(string reason, WebSocketCloseStatus statusCode = WebSocketCloseStatus.ProtocolError)
{
var closeResult = new WebSocketCloseResult(
statusCode,
reason);
await CloseAsync(closeResult, CancellationToken.None);
// We can now terminate our connection, according to the spec.
Abort();
}
private WebSocketCloseResult ParseCloseFrame(ReadableBuffer payload, WebSocketFrame frame, out ushort? actualStatusCode)
{
// Update state
if (State == WebSocketConnectionState.CloseSent)
{
State = WebSocketConnectionState.Closed;
}
else
{
State = WebSocketConnectionState.CloseReceived;
}
// Process the close frame
WebSocketCloseResult closeResult;
if (!WebSocketCloseResult.TryParse(frame.Payload, out closeResult, out actualStatusCode))
{
closeResult = WebSocketCloseResult.Empty;
}
return closeResult;
}
private static unsafe void PingPayloadWriter(WritableBuffer output, uint maskingKeyValue, int payloadLength, DateTime timestamp)
{
var maskingKey = new Span<byte>(&maskingKeyValue, sizeof(uint));
var payload = output.Buffer.Slice(0, payloadLength);
// TODO: Don't put this string on the heap? Is there a way to do that without re-implementing ToString?
// Ideally we'd like to render the string directly to the output buffer.
var str = timestamp.ToString("O", CultureInfo.InvariantCulture);
ArraySegment<byte> buffer;
if (payload.TryGetArray(out buffer))
{
// Fast path - Write the encoded bytes directly out.
Encoding.UTF8.GetBytes(str, 0, str.Length, buffer.Array, buffer.Offset);
}
else
{
// TODO: Could use TryGetPointer, GetBytes does take a byte*, but it seems like just waiting until we have a version that uses Span is best.
// Slow path - Allocate a heap buffer for the encoded bytes before writing them out.
Encoding.UTF8.GetBytes(str).CopyTo(payload.Span);
}
if (maskingKey.Length > 0)
{
MaskingUtilities.ApplyMask(payload.Span, maskingKey);
}
output.Advance(payloadLength);
}
private static unsafe void CloseResultPayloadWriter(WritableBuffer output, uint maskingKeyValue, int payloadLength, WebSocketCloseResult result)
{
var maskingKey = new Span<byte>(&maskingKeyValue, sizeof(uint));
// Write the close payload out
var payload = output.Buffer.Slice(0, payloadLength).Span;
result.WriteTo(ref output);
if (maskingKey.Length > 0)
{
MaskingUtilities.ApplyMask(payload, maskingKey);
}
}
private static unsafe void AppendPayloadWriter(WritableBuffer output, uint maskingKeyValue, int payloadLength, ReadableBuffer payload)
{
var maskingKey = new Span<byte>(&maskingKeyValue, sizeof(uint));
if (maskingKey.Length > 0)
{
// Mask the payload in it's own buffer
MaskingUtilities.ApplyMask(ref payload, maskingKey);
}
output.Append(payload);
}
private Task SendCoreAsync<T>(bool fin, WebSocketOpcode opcode, int payloadAllocLength, int payloadLength, Action<WritableBuffer, uint, int, T> payloadWriter, T payload, CancellationToken cancellationToken)
{
if (_sendLock.Wait(0))
{
return SendCoreLockAcquiredAsync(fin, opcode, payloadAllocLength, payloadLength, payloadWriter, payload, cancellationToken);
}
else
{
return SendCoreWaitForLockAsync(fin, opcode, payloadAllocLength, payloadLength, payloadWriter, payload, cancellationToken);
}
}
private async Task SendCoreWaitForLockAsync<T>(bool fin, WebSocketOpcode opcode, int payloadAllocLength, int payloadLength, Action<WritableBuffer, uint, int, T> payloadWriter, T payload, CancellationToken cancellationToken)
{
await _sendLock.WaitAsync(cancellationToken);
await SendCoreLockAcquiredAsync(fin, opcode, payloadAllocLength, payloadLength, payloadWriter, payload, cancellationToken);
}
private async Task SendCoreLockAcquiredAsync<T>(bool fin, WebSocketOpcode opcode, int payloadAllocLength, int payloadLength, Action<WritableBuffer, uint, int, T> payloadWriter, T payload, CancellationToken cancellationToken)
{
try
{
// Ensure the lock is held
Debug.Assert(_sendLock.CurrentCount == 0);
// Base header size is 2 bytes.
WritableBuffer buffer;
var allocSize = CalculateAllocSize(payloadAllocLength, payloadLength);
// Allocate a buffer
buffer = _outbound.Alloc(minimumSize: allocSize);
Debug.Assert(buffer.Buffer.Length >= allocSize);
// Write the opcode and FIN flag
var opcodeByte = (byte)opcode;
if (fin)
{
opcodeByte |= 0x80;
}
buffer.WriteBigEndian(opcodeByte);
// Write the length and mask flag
WritePayloadLength(payloadLength, buffer);
WritePayload(ref buffer, payloadLength, payloadWriter, payload);
// Flush.
await buffer.FlushAsync();
}
finally
{
// Unlock.
_sendLock.Release();
}
}
private void WritePayload<T>(ref WritableBuffer buffer, int payloadLength, Action<WritableBuffer, uint, int, T> payloadWriter, T payload)
{
var maskingKey = Span<byte>.Empty;
var keySize = sizeof(uint);
if (_maskingKeyBuffer != null)
{
// Get a span of the output buffer for the masking key, write it there, then advance the write head.
maskingKey = buffer.Buffer.Slice(0, keySize).Span;
WriteMaskingKey(maskingKey);
buffer.Advance(keySize);
// Write the payload
payloadWriter(buffer, maskingKey.Read<uint>(), payloadLength, payload);
}
else
{
// Write the payload un-masked
payloadWriter(buffer, 0, payloadLength, payload);
}
}
private int CalculateAllocSize(int payloadAllocLength, int payloadLength)
{
var allocSize = 2;
if (payloadLength > ushort.MaxValue)
{
// We're going to need an 8-byte length
allocSize += 8;
}
else if (payloadLength > 125)
{
// We're going to need a 2-byte length
allocSize += 2;
}
if (_maskingKeyBuffer != null)
{
// We need space for the masking key
allocSize += 4;
}
// We may need space for the payload too
return allocSize + payloadAllocLength;
}
private void WritePayloadLength(int payloadLength, WritableBuffer buffer)
{
var maskingByte = _maskingKeyBuffer != null ? 0x80 : 0x00;
if (payloadLength > ushort.MaxValue)
{
buffer.WriteBigEndian((byte)(0x7F | maskingByte));
// 8-byte length
buffer.WriteBigEndian((ulong)payloadLength);
}
else if (payloadLength > 125)
{
buffer.WriteBigEndian((byte)(0x7E | maskingByte));
// 2-byte length
buffer.WriteBigEndian((ushort)payloadLength);
}
else
{
// 1-byte length
buffer.WriteBigEndian((byte)(payloadLength | maskingByte));
}
}
}
}

View File

@ -1,14 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.Extensions.WebSockets.Internal
{
public enum WebSocketConnectionState
{
Created,
Connected,
CloseSent,
CloseReceived,
Closed
}
}

View File

@ -1,22 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.Extensions.WebSockets.Internal
{
public class WebSocketException : Exception
{
public WebSocketException()
{
}
public WebSocketException(string message) : base(message)
{
}
public WebSocketException(string message, Exception innerException) : base(message, innerException)
{
}
}
}

View File

@ -1,48 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
namespace Microsoft.Extensions.WebSockets.Internal
{
/// <summary>
/// Represents a single Frame received or sent on a <see cref="IWebSocketConnection"/>.
/// </summary>
public struct WebSocketFrame
{
/// <summary>
/// Indicates if the "FIN" flag is set on this frame, which indicates it is the final frame of a message.
/// </summary>
public bool EndOfMessage { get; }
/// <summary>
/// Gets the <see cref="WebSocketOpcode"/> value describing the opcode of the WebSocket frame.
/// </summary>
public WebSocketOpcode Opcode { get; }
/// <summary>
/// Gets the payload of the WebSocket frame.
/// </summary>
public ReadableBuffer Payload { get; }
public WebSocketFrame(bool endOfMessage, WebSocketOpcode opcode, ReadableBuffer payload)
{
EndOfMessage = endOfMessage;
Opcode = opcode;
Payload = payload;
}
/// <summary>
/// Creates a new <see cref="WebSocketFrame"/> containing the same information, but with all buffers
/// copied to new heap memory.
/// </summary>
/// <returns></returns>
public WebSocketFrame Copy()
{
return new WebSocketFrame(
endOfMessage: EndOfMessage,
opcode: Opcode,
payload: ReadableBuffer.Create(Payload.ToArray()));
}
}
}

View File

@ -1,62 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Runtime.CompilerServices;
namespace Microsoft.Extensions.WebSockets.Internal
{
/// <summary>
/// Represents the possible values for the "opcode" field of a WebSocket frame.
/// </summary>
public enum WebSocketOpcode
{
/// <summary>
/// Indicates that the frame is a continuation of the previous <see cref="Text"/> or <see cref="Binary"/> frame.
/// </summary>
Continuation = 0x0,
/// <summary>
/// Indicates that the frame is the first frame of a new Text message, formatted in UTF-8.
/// </summary>
Text = 0x1,
/// <summary>
/// Indicates that the frame is the first frame of a new Binary message.
/// </summary>
Binary = 0x2,
/* 0x3 - 0x7 are reserved */
/// <summary>
/// Indicates that the frame is a notification that the sender is closing their end of the connection
/// </summary>
Close = 0x8,
/// <summary>
/// Indicates a request from the sender to receive a <see cref="Pong"/>, in order to maintain the connection.
/// </summary>
Ping = 0x9,
/// <summary>
/// Indicates a response to a <see cref="Ping"/>, in order to maintain the connection.
/// </summary>
Pong = 0xA,
/* 0xB-0xF are reserved */
/* all opcodes above 0xF are invalid */
}
public static class WebSocketOpcodeExtensions
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool IsControl(this WebSocketOpcode opcode)
{
return opcode >= WebSocketOpcode.Close;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool IsMessage(this WebSocketOpcode opcode)
{
return opcode < WebSocketOpcode.Close;
}
}
}

View File

@ -1,140 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Security.Cryptography;
namespace Microsoft.Extensions.WebSockets.Internal
{
public class WebSocketOptions
{
/// <summary>
/// Gets the default ping interval of 30 seconds.
/// </summary>
public static TimeSpan DefaultPingInterval = TimeSpan.FromSeconds(30);
/// <summary>
/// Gets the default <see cref="WebSocketOptions"/> for an unmasked sender.
/// </summary>
/// <remarks>
/// Uses the default ping interval defined in <see cref="DefaultPingInterval"/>, no masking key,
/// and automatically responds to pings.
/// </remarks>
public static readonly WebSocketOptions DefaultUnmasked = new WebSocketOptions()
{
PingInterval = DefaultPingInterval,
MaskingKeyGenerator = null,
FixedMaskingKey = null
};
/// <summary>
/// Gets the default <see cref="WebSocketOptions"/> for an unmasked sender.
/// </summary>
/// <remarks>
/// Uses the default ping interval defined in <see cref="DefaultPingInterval"/>, the system random
/// key generator, and automatically responds to pings.
/// </remarks>
public static readonly WebSocketOptions DefaultMasked = new WebSocketOptions()
{
PingInterval = DefaultPingInterval,
MaskingKeyGenerator = RandomNumberGenerator.Create(),
FixedMaskingKey = null
};
/// <summary>
/// Gets or sets a boolean indicating if all frames, even those automatically handled (<see cref="WebSocketOpcode.Ping"/> and <see cref="WebSocketOpcode.Pong"/> frames),
/// should be passed to the <see cref="WebSocketConnection.ExecuteAsync"/> callback. NOTE: The frames will STILL be automatically handled, they are
/// only passed along for diagnostic purposes.
/// </summary>
public bool PassAllFramesThrough { get; private set; }
/// <summary>
/// Gets or sets the time between pings sent from the local endpoint
/// </summary>
public TimeSpan PingInterval { get; private set; }
/// <summary>
/// Gets or sets the <see cref="RandomNumberGenerator"/> used to generate masking keys used to mask outgoing frames.
/// If <see cref="FixedMaskingKey"/> is set, this value is ignored. If neither this value nor
/// <see cref="FixedMaskingKey"/> is set, no masking will be performed.
/// </summary>
public RandomNumberGenerator MaskingKeyGenerator { get; internal set; }
/// <summary>
/// Gets or sets a fixed masking key used to mask outgoing frames. If this value is set, <see cref="MaskingKeyGenerator"/>
/// is ignored. If neither this value nor <see cref="MaskingKeyGenerator"/> is set, no masking will be performed.
/// </summary>
public byte[] FixedMaskingKey { get; private set; }
/// <summary>
/// Sets the ping interval for this <see cref="WebSocketOptions"/>.
/// </summary>
/// <param name="pingInterval">The interval at which ping frames will be sent</param>
/// <returns>A new <see cref="WebSocketOptions"/> with the specified ping interval</returns>
public WebSocketOptions WithPingInterval(TimeSpan pingInterval)
{
return new WebSocketOptions()
{
PingInterval = pingInterval,
FixedMaskingKey = FixedMaskingKey,
MaskingKeyGenerator = MaskingKeyGenerator
};
}
/// <summary>
/// Enables frame pass-through in this <see cref="WebSocketOptions"/>. Generally for diagnostic or testing purposes only.
/// </summary>
/// <returns>A new <see cref="WebSocketOptions"/> with <see cref="PassAllFramesThrough"/> set to true</returns>
public WebSocketOptions WithAllFramesPassedThrough()
{
return new WebSocketOptions()
{
PassAllFramesThrough = true,
PingInterval = PingInterval,
FixedMaskingKey = FixedMaskingKey,
MaskingKeyGenerator = MaskingKeyGenerator
};
}
/// <summary>
/// Enables random masking in this <see cref="WebSocketOptions"/>, using the system random number generator.
/// </summary>
/// <returns>A new <see cref="WebSocketOptions"/> with random masking enabled</returns>
public WebSocketOptions WithRandomMasking() => WithRandomMasking(RandomNumberGenerator.Create());
/// <summary>
/// Enables random masking in this <see cref="WebSocketOptions"/>, using the provided random number generator.
/// </summary>
/// <param name="rng">The <see cref="RandomNumberGenerator"/> to use to generate masking keys</param>
/// <returns>A new <see cref="WebSocketOptions"/> with random masking enabled</returns>
public WebSocketOptions WithRandomMasking(RandomNumberGenerator rng)
{
return new WebSocketOptions()
{
PingInterval = PingInterval,
FixedMaskingKey = null,
MaskingKeyGenerator = rng
};
}
/// <summary>
/// Enables fixed masking in this <see cref="WebSocketOptions"/>. FOR DEVELOPMENT PURPOSES ONLY.
/// </summary>
/// <param name="maskingKey">The masking key to use for all outgoing frames.</param>
/// <returns>A new <see cref="WebSocketOptions"/> with fixed masking enabled</returns>
public WebSocketOptions WithFixedMaskingKey(byte[] maskingKey)
{
if (maskingKey.Length != 4)
{
throw new ArgumentException("Masking Key must be exactly 4 bytes", nameof(maskingKey));
}
return new WebSocketOptions()
{
PingInterval = PingInterval,
FixedMaskingKey = maskingKey,
MaskingKeyGenerator = null
};
}
}
}

View File

@ -4,20 +4,20 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.WebSockets;
using System.Security.Claims;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using Microsoft.Extensions.WebSockets.Internal;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
@ -368,17 +368,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var webSocketTask = Task.CompletedTask;
var ws = (TestWebSocketConnectionFeature)context1.Features.Get<IHttpWebSocketConnectionFeature>();
var ws = (TestWebSocketConnectionFeature)context1.Features.Get<IHttpWebSocketFeature>();
if (ws != null)
{
webSocketTask = ws.Client.ExecuteAsync(frame => Task.CompletedTask);
await ws.Client.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure), CancellationToken.None);
await ws.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
}
manager.CloseConnections();
await webSocketTask.OrTimeout();
await request1.OrTimeout();
}
@ -975,7 +972,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
switch (transportType)
{
case TransportType.WebSockets:
context.Features.Set<IHttpWebSocketConnectionFeature>(new TestWebSocketConnectionFeature());
context.Features.Set<IHttpWebSocketFeature>(new TestWebSocketConnectionFeature());
break;
case TransportType.ServerSentEvents:
context.Request.Headers["Accept"] = "text/event-stream";

View File

@ -9,7 +9,6 @@
<ItemGroup>
<Compile Include="..\Common\TaskExtensions.cs" Link="TaskExtensions.cs" />
<Compile Include="..\Microsoft.Extensions.WebSockets.Internal.Tests\WebSocketConnectionExtensions.cs;..\Microsoft.Extensions.WebSockets.Internal.Tests\WebSocketConnectionSummary.cs;..\Microsoft.Extensions.WebSockets.Internal.Tests\WebSocketPair.cs" />
</ItemGroup>
<ItemGroup>

View File

@ -1,36 +1,177 @@
using System;
using System.Buffers;
using System.IO.Pipelines;
using System.Collections.Generic;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.WebSockets.Internal;
using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.Sockets.Tests
{
internal class TestWebSocketConnectionFeature : IHttpWebSocketConnectionFeature, IDisposable
internal class TestWebSocketConnectionFeature : IHttpWebSocketFeature, IDisposable
{
private PipeFactory _factory = new PipeFactory(BufferPool.Default);
public bool IsWebSocketRequest => true;
public WebSocketConnection Client { get; private set; }
public WebSocketChannel Client { get; private set; }
public ValueTask<IWebSocketConnection> AcceptWebSocketConnectionAsync(WebSocketAcceptContext context)
public Task<WebSocket> AcceptAsync() => AcceptAsync(new WebSocketAcceptContext());
public Task<WebSocket> AcceptAsync(WebSocketAcceptContext context)
{
var clientToServer = _factory.Create();
var serverToClient = _factory.Create();
var clientToServer = Channel.CreateUnbounded<WebSocketMessage>();
var serverToClient = Channel.CreateUnbounded<WebSocketMessage>();
var clientSocket = new WebSocketConnection(serverToClient.Reader, clientToServer.Writer);
var serverSocket = new WebSocketConnection(clientToServer.Reader, serverToClient.Writer);
var clientSocket = new WebSocketChannel(serverToClient.In, clientToServer.Out);
var serverSocket = new WebSocketChannel(clientToServer.In, serverToClient.Out);
Client = clientSocket;
return new ValueTask<IWebSocketConnection>(serverSocket);
return Task.FromResult<WebSocket>(serverSocket);
}
public void Dispose()
{
_factory.Dispose();
}
public class WebSocketChannel : WebSocket
{
private readonly ReadableChannel<WebSocketMessage> _input;
private readonly WritableChannel<WebSocketMessage> _output;
private WebSocketCloseStatus? _closeStatus;
private string _closeStatusDescription;
private WebSocketState _state;
public WebSocketChannel(ReadableChannel<WebSocketMessage> input, WritableChannel<WebSocketMessage> output)
{
_input = input;
_output = output;
}
public override WebSocketCloseStatus? CloseStatus => _closeStatus;
public override string CloseStatusDescription => _closeStatusDescription;
public override WebSocketState State => _state;
public override string SubProtocol => null;
public override void Abort()
{
_output.TryComplete(new OperationCanceledException());
_state = WebSocketState.Aborted;
}
public override async Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
await _output.WriteAsync(new WebSocketMessage
{
CloseStatus = closeStatus,
CloseStatusDescription = statusDescription,
MessageType = WebSocketMessageType.Close,
},
cancellationToken);
_state = WebSocketState.CloseSent;
_output.TryComplete();
}
public override async Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
await _output.WriteAsync(new WebSocketMessage
{
CloseStatus = closeStatus,
CloseStatusDescription = statusDescription,
MessageType = WebSocketMessageType.Close,
},
cancellationToken);
_state = WebSocketState.CloseSent;
_output.TryComplete();
}
public override void Dispose()
{
_state = WebSocketState.Closed;
_output.TryComplete();
}
public override async Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
{
var message = await _input.ReadAsync();
if (message.MessageType == WebSocketMessageType.Close)
{
_state = WebSocketState.CloseReceived;
_closeStatus = message.CloseStatus;
_closeStatusDescription = message.CloseStatusDescription;
return new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, message.CloseStatus, message.CloseStatusDescription);
}
// REVIEW: This assumes the buffer passed in is > the buffer received
Buffer.BlockCopy(message.Buffer, 0, buffer.Array, buffer.Offset, message.Buffer.Length);
return new WebSocketReceiveResult(message.Buffer.Length, message.MessageType, message.EndOfMessage);
}
public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
var copy = new byte[buffer.Count];
Buffer.BlockCopy(buffer.Array, buffer.Offset, copy, 0, buffer.Count);
return _output.WriteAsync(new WebSocketMessage
{
Buffer = copy,
MessageType = messageType,
EndOfMessage = endOfMessage
},
cancellationToken);
}
public async Task<WebSocketConnectionSummary> ExecuteAndCaptureFramesAsync()
{
var frames = new List<WebSocketMessage>();
while (await _input.WaitToReadAsync())
{
while (_input.TryRead(out var message))
{
if (message.MessageType == WebSocketMessageType.Close)
{
_state = WebSocketState.CloseReceived;
_closeStatus = message.CloseStatus;
_closeStatusDescription = message.CloseStatusDescription;
return new WebSocketConnectionSummary(frames, new WebSocketReceiveResult(0, message.MessageType, message.EndOfMessage, message.CloseStatus, message.CloseStatusDescription));
}
frames.Add(message);
}
}
_state = WebSocketState.Closed;
_closeStatus = WebSocketCloseStatus.InternalServerError;
return new WebSocketConnectionSummary(frames, new WebSocketReceiveResult(0, WebSocketMessageType.Close, endOfMessage: true, closeStatus: WebSocketCloseStatus.InternalServerError, closeStatusDescription: ""));
}
}
public class WebSocketConnectionSummary
{
public IList<WebSocketMessage> Received { get; }
public WebSocketReceiveResult CloseResult { get; }
public WebSocketConnectionSummary(IList<WebSocketMessage> received, WebSocketReceiveResult closeResult)
{
Received = received;
CloseResult = closeResult;
}
}
public class WebSocketMessage
{
public byte[] Buffer { get; set; }
public WebSocketMessageType MessageType { get; set; }
public bool EndOfMessage { get; set; }
public WebSocketCloseStatus? CloseStatus { get; set; }
public string CloseStatusDescription { get; set; }
}
}
}

View File

@ -2,16 +2,15 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.WebSockets.Internal;
using Microsoft.Extensions.WebSockets.Internal.Tests;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
@ -19,9 +18,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public class WebSocketsTests
{
[Theory]
[InlineData(MessageType.Text, WebSocketOpcode.Text)]
[InlineData(MessageType.Binary, WebSocketOpcode.Binary)]
public async Task ReceivedFramesAreWrittenToChannel(MessageType format, WebSocketOpcode opcode)
[InlineData(MessageType.Text, WebSocketMessageType.Text)]
[InlineData(MessageType.Binary, WebSocketMessageType.Binary)]
public async Task ReceivedFramesAreWrittenToChannel(MessageType format, WebSocketMessageType webSocketMessageType)
{
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
@ -29,23 +28,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipeFactory())
using (var pair = WebSocketPair.Create(factory))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Send a frame, then close
await pair.ClientSocket.SendAsync(new WebSocketFrame(
await feature.Client.SendAsync(
buffer: new ArraySegment<byte>(Encoding.UTF8.GetBytes("Hello")),
messageType: webSocketMessageType,
endOfMessage: true,
opcode: opcode,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
cancellationToken: CancellationToken.None);
await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
var message = await applicationSide.Input.In.ReadAsync();
Assert.True(message.EndOfMessage);
@ -60,14 +59,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// The connection should close after this, which means the client will get a close frame.
var clientSummary = await client;
Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status);
Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.CloseStatus);
}
}
[Theory]
[InlineData(MessageType.Text, WebSocketOpcode.Text)]
[InlineData(MessageType.Binary, WebSocketOpcode.Binary)]
public async Task MultiFrameMessagesArePropagatedToTheChannel(MessageType format, WebSocketOpcode opcode)
[InlineData(MessageType.Text, WebSocketMessageType.Text)]
[InlineData(MessageType.Binary, WebSocketMessageType.Binary)]
public async Task IncompleteMessagesAreWrittenAsMultiFrameWebSocketMessages(MessageType format, WebSocketMessageType webSocketMessageType)
{
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
@ -75,71 +74,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipeFactory())
using (var pair = WebSocketPair.Create(factory))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Send a frame, then close
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: false,
opcode: opcode,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Continuation,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("World"))));
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
var message1 = await applicationSide.Input.In.ReadAsync();
Assert.False(message1.EndOfMessage);
Assert.Equal(format, message1.Type);
Assert.Equal("Hello", Encoding.UTF8.GetString(message1.Payload));
var message2 = await applicationSide.Input.In.ReadAsync();
Assert.True(message2.EndOfMessage);
Assert.Equal(format, message2.Type);
Assert.Equal("World", Encoding.UTF8.GetString(message2.Payload));
Assert.True(applicationSide.Output.Out.TryComplete());
// The transport should finish now
await transport;
// The connection should close after this, which means the client will get a close frame.
var clientSummary = await client;
Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status);
}
}
[Theory]
[InlineData(MessageType.Text, WebSocketOpcode.Text)]
[InlineData(MessageType.Binary, WebSocketOpcode.Binary)]
public async Task IncompleteMessagesAreWrittenAsMultiFrameWebSocketMessages(MessageType format, WebSocketOpcode opcode)
{
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipeFactory())
using (var pair = WebSocketPair.Create(factory))
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Write multi-frame message to the output channel, and then complete it
await applicationSide.Output.Out.WriteAsync(new Message(
@ -154,23 +97,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// The client should finish now, as should the server
var clientSummary = await client;
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
await transport;
Assert.Equal(2, clientSummary.Received.Count);
Assert.False(clientSummary.Received[0].EndOfMessage);
Assert.Equal(opcode, clientSummary.Received[0].Opcode);
Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Payload.ToArray()));
Assert.Equal(webSocketMessageType, clientSummary.Received[0].MessageType);
Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Buffer));
Assert.True(clientSummary.Received[1].EndOfMessage);
Assert.Equal(WebSocketOpcode.Continuation, clientSummary.Received[1].Opcode);
Assert.Equal("World", Encoding.UTF8.GetString(clientSummary.Received[1].Payload.ToArray()));
Assert.Equal(webSocketMessageType, clientSummary.Received[1].MessageType);
Assert.Equal("World", Encoding.UTF8.GetString(clientSummary.Received[1].Buffer));
}
}
[Theory]
[InlineData(MessageType.Text, WebSocketOpcode.Text)]
[InlineData(MessageType.Binary, WebSocketOpcode.Binary)]
public async Task DataWrittenToOutputPipelineAreSentAsFrames(MessageType format, WebSocketOpcode opcode)
[InlineData(MessageType.Text, WebSocketMessageType.Text)]
[InlineData(MessageType.Binary, WebSocketMessageType.Binary)]
public async Task DataWrittenToOutputPipelineAreSentAsFrames(MessageType format, WebSocketMessageType webSocketMessageType)
{
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
@ -178,16 +121,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipeFactory())
using (var pair = WebSocketPair.Create(factory))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Write to the output channel, and then complete it
await applicationSide.Output.Out.WriteAsync(new Message(
@ -198,20 +140,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// The client should finish now, as should the server
var clientSummary = await client;
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
await transport;
Assert.Equal(1, clientSummary.Received.Count);
Assert.True(clientSummary.Received[0].EndOfMessage);
Assert.Equal(opcode, clientSummary.Received[0].Opcode);
Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Payload.ToArray()));
Assert.Equal(webSocketMessageType, clientSummary.Received[0].MessageType);
Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Buffer));
}
}
[Theory]
[InlineData(MessageType.Text, WebSocketOpcode.Text)]
[InlineData(MessageType.Binary, WebSocketOpcode.Binary)]
public async Task FrameReceivedAfterServerCloseSent(MessageType format, WebSocketOpcode opcode)
[InlineData(MessageType.Text, WebSocketMessageType.Text)]
[InlineData(MessageType.Binary, WebSocketMessageType.Binary)]
public async Task FrameReceivedAfterServerCloseSent(MessageType format, WebSocketMessageType webSocketMessageType)
{
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
@ -219,27 +161,27 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipeFactory())
using (var pair = WebSocketPair.Create(factory))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Close the output and wait for the close frame
Assert.True(applicationSide.Output.Out.TryComplete());
await client;
// Send another frame. Then close
await pair.ClientSocket.SendAsync(new WebSocketFrame(
await feature.Client.SendAsync(
buffer: new ArraySegment<byte>(Encoding.UTF8.GetBytes("Hello")),
endOfMessage: true,
opcode: opcode,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
messageType: webSocketMessageType,
cancellationToken: CancellationToken.None);
await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
// Read that frame from the input
var message = await applicationSide.Input.In.ReadAsync();
@ -260,22 +202,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipeFactory())
using (var pair = WebSocketPair.Create(factory))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Terminate the client to server channel with an exception
pair.TerminateFromClient(new InvalidOperationException());
feature.Client.Abort();
// Wait for the transport
await Assert.ThrowsAsync<InvalidOperationException>(() => transport);
await Assert.ThrowsAsync<OperationCanceledException>(() => transport);
}
}
@ -288,24 +229,23 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipeFactory())
using (var pair = WebSocketPair.Create(factory))
using (var feature = new TestWebSocketConnectionFeature())
{
var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
var transport = ws.ProcessSocketAsync(await feature.AcceptAsync());
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var client = feature.Client.ExecuteAndCaptureFramesAsync();
// Fail in the app
Assert.True(applicationSide.Output.Out.TryComplete(new InvalidOperationException()));
var clientSummary = await client;
Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.Status);
Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.CloseStatus);
// Close from the client
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
await transport.OrTimeout();
}
@ -320,8 +260,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
using (var factory = new PipeFactory())
using (var pair = WebSocketPair.Create(factory))
using (var feature = new TestWebSocketConnectionFeature())
{
var options = new WebSocketOptions()
{
@ -330,21 +269,19 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var ws = new WebSocketsTransport(options, transportSide, new LoggerFactory());
var serverSocket = await feature.AcceptAsync();
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
var transport = ws.ProcessSocketAsync(serverSocket);
// End the app
applicationSide.Dispose();
await transport.OrTimeout(TimeSpan.FromSeconds(10));
// We're still in the closed sent state since the client never sent the close frame
Assert.Equal(WebSocketConnectionState.CloseSent, pair.ServerSocket.State);
pair.ServerSocket.Dispose();
// Now we're closed
Assert.Equal(WebSocketConnectionState.Closed, pair.ServerSocket.State);
Assert.Equal(WebSocketState.Aborted, serverSocket.State);
serverSocket.Dispose();
}
}
}

View File

@ -1,33 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public class AutobahnCaseResult
{
public string Name { get; }
public string ActualBehavior { get; }
public AutobahnCaseResult(string name, string actualBehavior)
{
Name = name;
ActualBehavior = actualBehavior;
}
public static AutobahnCaseResult FromJson(JProperty prop)
{
var caseObj = (JObject)prop.Value;
var actualBehavior = (string)caseObj["behavior"];
return new AutobahnCaseResult(prop.Name, actualBehavior);
}
public bool BehaviorIs(params string[] behaviors)
{
return behaviors.Any(b => string.Equals(b, ActualBehavior, StringComparison.Ordinal));
}
}
}

View File

@ -1,82 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.AspNetCore.Server.IntegrationTesting;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public class AutobahnExpectations
{
private Dictionary<string, Expectation> _expectations = new Dictionary<string, Expectation>();
public bool Ssl { get; }
public ServerType Server { get; }
public AutobahnExpectations(ServerType server, bool ssl)
{
Server = server;
Ssl = ssl;
}
public AutobahnExpectations Fail(params string[] caseSpecs) => Expect(Expectation.Fail, caseSpecs);
public AutobahnExpectations NonStrict(params string[] caseSpecs) => Expect(Expectation.NonStrict, caseSpecs);
public AutobahnExpectations OkOrFail(params string[] caseSpecs) => Expect(Expectation.OkOrFail, caseSpecs);
public AutobahnExpectations Expect(Expectation expectation, params string[] caseSpecs)
{
foreach (var caseSpec in caseSpecs)
{
_expectations[caseSpec] = expectation;
}
return this;
}
internal void Verify(AutobahnServerResult serverResult, StringBuilder failures)
{
foreach (var caseResult in serverResult.Cases)
{
// If this is an informational test result, we can't compare it to anything
if (!string.Equals(caseResult.ActualBehavior, "INFORMATIONAL", StringComparison.Ordinal))
{
Expectation expectation;
if (!_expectations.TryGetValue(caseResult.Name, out expectation))
{
expectation = Expectation.Ok;
}
switch (expectation)
{
case Expectation.Fail:
if (!caseResult.BehaviorIs("FAILED"))
{
failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'FAILED', but got '{caseResult.ActualBehavior}'");
}
break;
case Expectation.NonStrict:
if (!caseResult.BehaviorIs("NON-STRICT"))
{
failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'NON-STRICT', but got '{caseResult.ActualBehavior}'");
}
break;
case Expectation.Ok:
if (!caseResult.BehaviorIs("NON-STRICT") && !caseResult.BehaviorIs("OK"))
{
failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'NON-STRICT' or 'OK', but got '{caseResult.ActualBehavior}'");
}
break;
case Expectation.OkOrFail:
if (!caseResult.BehaviorIs("NON-STRICT") && !caseResult.BehaviorIs("FAILED") && !caseResult.BehaviorIs("OK"))
{
failures.AppendLine($"Case {serverResult.Name}:{caseResult.Name}. Expected 'FAILED', 'NON-STRICT' or 'OK', but got '{caseResult.ActualBehavior}'");
}
break;
default:
break;
}
}
}
}
}
}

View File

@ -1,25 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Linq;
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public class AutobahnResult
{
public IEnumerable<AutobahnServerResult> Servers { get; }
public AutobahnResult(IEnumerable<AutobahnServerResult> servers)
{
Servers = servers;
}
public static AutobahnResult FromReportJson(JObject indexJson)
{
// Load the report
return new AutobahnResult(indexJson.Properties().Select(AutobahnServerResult.FromJson));
}
}
}

View File

@ -1,40 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.Server.IntegrationTesting;
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public class AutobahnServerResult
{
public ServerType Server { get; }
public bool Ssl { get; }
public string Name { get; }
public IEnumerable<AutobahnCaseResult> Cases { get; }
public AutobahnServerResult(string name, IEnumerable<AutobahnCaseResult> cases)
{
Name = name;
var splat = name.Split('|');
if (splat.Length < 2)
{
throw new FormatException("Results incorrectly formatted");
}
Server = (ServerType)Enum.Parse(typeof(ServerType), splat[0]);
Ssl = string.Equals(splat[1], "SSL", StringComparison.Ordinal);
Cases = cases;
}
public static AutobahnServerResult FromJson(JProperty prop)
{
var valueObj = ((JObject)prop.Value);
return new AutobahnServerResult(prop.Name, valueObj.Properties().Select(AutobahnCaseResult.FromJson));
}
}
}

View File

@ -1,62 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public class AutobahnSpec
{
public string OutputDirectory { get; }
public IList<ServerSpec> Servers { get; } = new List<ServerSpec>();
public IList<string> Cases { get; } = new List<string>();
public IList<string> ExcludedCases { get; } = new List<string>();
public AutobahnSpec(string outputDirectory)
{
OutputDirectory = outputDirectory;
}
public AutobahnSpec WithServer(string name, string url)
{
Servers.Add(new ServerSpec(name, url));
return this;
}
public AutobahnSpec IncludeCase(params string[] caseSpecs)
{
foreach (var caseSpec in caseSpecs)
{
Cases.Add(caseSpec);
}
return this;
}
public AutobahnSpec ExcludeCase(params string[] caseSpecs)
{
foreach (var caseSpec in caseSpecs)
{
ExcludedCases.Add(caseSpec);
}
return this;
}
public void WriteJson(string file)
{
File.WriteAllText(file, GetJson().ToString(Formatting.Indented));
}
public JObject GetJson() => new JObject(
new JProperty("options", new JObject(
new JProperty("failByDrop", false))),
new JProperty("outdir", OutputDirectory),
new JProperty("servers", new JArray(Servers.Select(s => s.GetJson()).ToArray())),
new JProperty("cases", new JArray(Cases.ToArray())),
new JProperty("exclude-cases", new JArray(ExcludedCases.ToArray())),
new JProperty("exclude-agent-cases", new JObject()));
}
}

View File

@ -1,147 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.IntegrationTesting;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json.Linq;
using Xunit;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public class AutobahnTester : IDisposable
{
private int _nextPort;
private readonly List<IApplicationDeployer> _deployers = new List<IApplicationDeployer>();
private readonly List<AutobahnExpectations> _expectations = new List<AutobahnExpectations>();
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
public AutobahnSpec Spec { get; }
public AutobahnTester(ILoggerFactory loggerFactory, AutobahnSpec baseSpec) : this(7000, loggerFactory, baseSpec) { }
public AutobahnTester(int startPort, ILoggerFactory loggerFactory, AutobahnSpec baseSpec)
{
_nextPort = startPort;
_loggerFactory = loggerFactory;
_logger = _loggerFactory.CreateLogger("AutobahnTester");
Spec = baseSpec;
}
public async Task<AutobahnResult> Run()
{
var specFile = Path.GetTempFileName();
try
{
Spec.WriteJson(specFile);
// Run the test (write something to the console so people know this will take a while...)
_logger.LogInformation("Now launching Autobahn Test Suite. This will take a while.");
var exitCode = await Wstest.Default.ExecAsync("-m fuzzingclient -s " + specFile);
if (exitCode != 0)
{
throw new Exception("wstest failed");
}
}
finally
{
if (File.Exists(specFile))
{
File.Delete(specFile);
}
}
// Parse the output.
var outputFile = Path.Combine(Directory.GetCurrentDirectory(), Spec.OutputDirectory, "index.json");
using (var reader = new StreamReader(File.OpenRead(outputFile)))
{
return AutobahnResult.FromReportJson(JObject.Parse(await reader.ReadToEndAsync()));
}
}
public void Verify(AutobahnResult result)
{
var failures = new StringBuilder();
foreach (var serverResult in result.Servers)
{
var serverExpectation = _expectations.FirstOrDefault(e => e.Server == serverResult.Server && e.Ssl == serverResult.Ssl);
if (serverExpectation == null)
{
failures.AppendLine($"Expected no results for server: {serverResult.Name} but found results!");
}
else
{
serverExpectation.Verify(serverResult, failures);
}
}
Assert.True(failures.Length == 0, "Autobahn results did not meet expectations:" + Environment.NewLine + failures.ToString());
}
public async Task DeployTestAndAddToSpec(ServerType server, bool ssl, Action<AutobahnExpectations> expectationConfig = null)
{
var port = Interlocked.Increment(ref _nextPort);
var baseUrl = ssl ? $"https://localhost:{port}" : $"http://localhost:{port}";
var sslNamePart = ssl ? "SSL" : "NoSSL";
var name = $"{server}|{sslNamePart}";
var logger = _loggerFactory.CreateLogger($"AutobahnTestApp:{server}:{sslNamePart}");
var appPath = Helpers.GetApplicationPath("WebSocketsTestApp");
var parameters = new DeploymentParameters(appPath, server, RuntimeFlavor.CoreClr, RuntimeArchitecture.x64)
{
ApplicationBaseUriHint = baseUrl,
ApplicationType = ApplicationType.Portable,
TargetFramework = "netcoreapp2.0",
EnvironmentName = "Development"
};
var deployer = ApplicationDeployerFactory.Create(parameters, _loggerFactory);
var result = await deployer.DeployAsync();
result.HostShutdownToken.ThrowIfCancellationRequested();
var handler = new HttpClientHandler();
if (ssl)
{
// Don't take this out of the "if(ssl)". If we set it on some platforms, it crashes
// So we avoid running SSL tests on those platforms (for now).
// See https://github.com/dotnet/corefx/issues/9728
handler.ServerCertificateCustomValidationCallback = (_, __, ___, ____) => true;
}
var client = new HttpClient(handler);
// Make sure the server works
var resp = await RetryHelper.RetryRequest(() =>
{
return client.GetAsync(result.ApplicationBaseUri);
}, logger, result.HostShutdownToken); // High retry count because Travis macOS is slow
resp.EnsureSuccessStatusCode();
// Add to the current spec
var wsUrl = result.ApplicationBaseUri.Replace("https://", "wss://").Replace("http://", "ws://");
Spec.WithServer(name, wsUrl);
_deployers.Add(deployer);
var expectations = new AutobahnExpectations(server, ssl);
expectationConfig?.Invoke(expectations);
_expectations.Add(expectations);
}
public void Dispose()
{
foreach (var deployer in _deployers)
{
deployer.Dispose();
}
}
}
}

View File

@ -1,57 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public class Executable
{
private static readonly string _exeSuffix = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? ".exe" : string.Empty;
private readonly string _path;
protected Executable(string path)
{
_path = path;
}
public static string Locate(string name)
{
foreach (var dir in Environment.GetEnvironmentVariable("PATH").Split(Path.PathSeparator))
{
var candidate = Path.Combine(dir, name + _exeSuffix);
if (File.Exists(candidate))
{
return candidate;
}
}
return null;
}
public Task<int> ExecAsync(string args)
{
var process = new Process()
{
StartInfo = new ProcessStartInfo()
{
FileName = _path,
Arguments = args,
UseShellExecute = false,
},
EnableRaisingEvents = true
};
var tcs = new TaskCompletionSource<int>();
process.Exited += (_, __) => tcs.TrySetResult(process.ExitCode);
process.Start();
return tcs.Task;
}
}
}

View File

@ -1,13 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public enum Expectation
{
Fail,
NonStrict,
OkOrFail,
Ok
}
}

View File

@ -1,25 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
public class ServerSpec
{
public string Name { get; }
public string Url { get; }
public ServerSpec(string name, string url)
{
Name = name;
Url = url;
}
public JObject GetJson() => new JObject(
new JProperty("agent", Name),
new JProperty("url", Url),
new JProperty("options", new JObject(
new JProperty("version", 18))));
}
}

View File

@ -1,25 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn
{
/// <summary>
/// Wrapper around the Autobahn Test Suite's "wstest" app.
/// </summary>
public class Wstest : Executable
{
private static Lazy<Wstest> _instance = new Lazy<Wstest>(Create);
public static Wstest Default => _instance.Value;
public Wstest(string path) : base(path) { }
private static Wstest Create()
{
var location = Locate("wstest");
return location == null ? null : new Wstest(location);
}
}
}

View File

@ -1,92 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.IntegrationTesting;
using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn;
using Microsoft.Extensions.Logging;
using Xunit.Abstractions;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest
{
public class AutobahnTests
{
private readonly ITestOutputHelper _output;
public AutobahnTests(ITestOutputHelper output)
{
_output = output;
}
[ConditionalFact(Skip = "Too flaky. See https://github.com/aspnet/SignalR/issues/336")]
[SkipIfWsTestNotPresent]
public async Task AutobahnTestSuite()
{
var reportDir = Environment.GetEnvironmentVariable("AUTOBAHN_SUITES_REPORT_DIR");
var outDir = !string.IsNullOrEmpty(reportDir) ?
reportDir :
Path.Combine(AppContext.BaseDirectory, "autobahnreports");
if (Directory.Exists(outDir))
{
Directory.Delete(outDir, recursive: true);
}
outDir = outDir.Replace("\\", "\\\\");
// 9.* is Limits/Performance which is VERY SLOW; 12.*/13.* are compression which we don't implement
var spec = new AutobahnSpec(outDir)
.IncludeCase("*")
.ExcludeCase("9.*", "12.*", "13.*");
var loggerFactory = new LoggerFactory(); // No logging by default! It's very loud...
if (string.Equals(Environment.GetEnvironmentVariable("AUTOBAHN_SUITES_LOG"), "1", StringComparison.Ordinal))
{
loggerFactory.AddXunit(_output);
loggerFactory.AddConsole();
_output.WriteLine("Logging enabled");
}
AutobahnResult result;
using (var tester = new AutobahnTester(loggerFactory, spec))
{
await tester.DeployTestAndAddToSpec(ServerType.Kestrel, ssl: false, expectationConfig: expect => expect
.NonStrict("6.4.3", "6.4.4"));
result = await tester.Run();
tester.Verify(result);
}
}
private bool IsWindows8OrHigher()
{
const string WindowsName = "Microsoft Windows ";
const int VersionOffset = 18;
if (RuntimeInformation.OSDescription.StartsWith(WindowsName))
{
var versionStr = RuntimeInformation.OSDescription.Substring(VersionOffset);
Version version;
if (Version.TryParse(versionStr, out version))
{
return version.Major > 6 || (version.Major == 6 && version.Minor >= 2);
}
}
return false;
}
private bool IsIISExpress10Installed()
{
var pf = Environment.GetEnvironmentVariable("PROGRAMFILES");
var iisExpressExe = Path.Combine(pf, "IIS Express", "iisexpress.exe");
return File.Exists(iisExpressExe) && FileVersionInfo.GetVersionInfo(iisExpressExe).FileMajorPart >= 10;
}
}
}

View File

@ -1,31 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest
{
public class Helpers
{
public static string GetApplicationPath(string projectName)
{
var applicationBasePath = AppContext.BaseDirectory;
var directoryInfo = new DirectoryInfo(applicationBasePath);
do
{
var solutionFileInfo = new FileInfo(Path.Combine(directoryInfo.FullName, "SignalR.sln"));
if (solutionFileInfo.Exists)
{
return Path.GetFullPath(Path.Combine(directoryInfo.FullName, "test", projectName));
}
directoryInfo = directoryInfo.Parent;
}
while (directoryInfo.Parent != null);
throw new Exception($"Solution root could not be found using {applicationBasePath}");
}
}
}

View File

@ -1,22 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<TargetFrameworks>netcoreapp2.0;net461</TargetFrameworks>
<TargetFrameworks Condition="'$(OS)' != 'Windows_NT'">netcoreapp2.0</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Server.IntegrationTesting" Version="$(AspNetCoreIntegrationTestingVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Testing" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Testing" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(TestSdkVersion)" />
<PackageReference Include="Newtonsoft.Json" Version="$(JsonNetVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XunitVersion)" />
<PackageReference Include="xunit" Version="$(XunitVersion)" />
</ItemGroup>
</Project>

View File

@ -1,16 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest.Autobahn;
namespace Microsoft.AspNetCore.WebSockets.Internal.ConformanceTest
{
[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)]
public class SkipIfWsTestNotPresentAttribute : Attribute, ITestCondition
{
public bool IsMet => Wstest.Default != null;
public string SkipReason => "Autobahn Test Suite is not installed on the host machine.";
}
}

View File

@ -1,23 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<TargetFrameworks>netcoreapp2.0;net461</TargetFrameworks>
<TargetFrameworks Condition="'$(OS)' != 'Windows_NT'">netcoreapp2.0</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\TaskExtensions.cs" Link="TaskExtensions.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.Extensions.WebSockets.Internal\Microsoft.Extensions.WebSockets.Internal.csproj" />
<PackageReference Include="Microsoft.AspNetCore.Testing" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(TestSdkVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XunitVersion)" />
<PackageReference Include="xunit" Version="$(XunitVersion)" />
</ItemGroup>
</Project>

View File

@ -1,22 +0,0 @@
using System;
using Microsoft.AspNetCore.Testing.xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
public class SkipIfEnvVarPresentAttribute : Attribute, ITestCondition
{
private readonly string _environmentVariable;
private readonly string _skipReason;
public bool IsMet => string.IsNullOrEmpty(Environment.GetEnvironmentVariable(_environmentVariable));
public string SkipReason => _skipReason;
public SkipIfEnvVarPresentAttribute(string environmentVariable, string skipReason)
{
_environmentVariable = environmentVariable;
_skipReason = skipReason;
}
}
}

View File

@ -1,32 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
internal static class TestUtil
{
private static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(1);
public static CancellationToken CreateTimeoutToken() => CreateTimeoutToken(DefaultTimeout);
public static CancellationToken CreateTimeoutToken(TimeSpan timeout)
{
if (Debugger.IsAttached)
{
return CancellationToken.None;
}
else
{
var cts = new CancellationTokenSource();
cts.CancelAfter(timeout);
return cts.Token;
}
}
}
}

View File

@ -1,148 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using Xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public class Utf8ValidatorTests
{
[Theory]
[InlineData(new byte[] { 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello")]
[InlineData(new byte[] { 0xC2, 0xA7, 0x31, 0x2C, 0x20, 0x39, 0x35, 0xC2, 0xA2 }, "§1, 95¢")]
[InlineData(new byte[] { 0xE0, 0xA0, 0x80, 0xE0, 0xA4, 0x80 }, "\u0800\u0900")]
[InlineData(new byte[] { 0xF0, 0x90, 0x80, 0x80 }, "\U00010000")]
public void ValidSingleFramePayloads(byte[] payload, string decoded)
{
var validator = new Utf8Validator();
Assert.True(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload), fin: true));
// Not really part of the test, but it ensures that the "decoded" string matches the "payload",
// so that the "decoded" string can be used as a human-readable explanation of the string in question
Assert.Equal(decoded, Encoding.UTF8.GetString(payload));
}
[Theory]
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x6C, 0x6C, 0x6F }, "Hello")]
[InlineData(new byte[0], new byte[] { 0xC2, 0xA7 }, "§")]
[InlineData(new byte[] { 0xC2 }, new byte[] { 0xA7 }, "§")]
[InlineData(new byte[] { 0xC2, 0xA7 }, new byte[0], "§")]
[InlineData(new byte[0], new byte[] { 0xC2, 0xA2 }, "¢")]
[InlineData(new byte[] { 0xC2 }, new byte[] { 0xA2 }, "¢")]
[InlineData(new byte[] { 0xC2, 0xA2 }, new byte[0], "¢")]
[InlineData(new byte[0], new byte[] { 0xE0, 0xA0, 0x80 }, "\u0800")]
[InlineData(new byte[] { 0xE0 }, new byte[] { 0xA0, 0x80 }, "\u0800")]
[InlineData(new byte[] { 0xE0, 0xA0 }, new byte[] { 0x80 }, "\u0800")]
[InlineData(new byte[] { 0xE0, 0xA0, 0x80 }, new byte[0], "\u0800")]
[InlineData(new byte[0], new byte[] { 0xE0, 0xA4, 0x80 }, "\u0900")]
[InlineData(new byte[] { 0xE0 }, new byte[] { 0xA4, 0x80 }, "\u0900")]
[InlineData(new byte[] { 0xE0, 0xA4 }, new byte[] { 0x80 }, "\u0900")]
[InlineData(new byte[] { 0xE0, 0xA4, 0x80 }, new byte[0], "\u0900")]
[InlineData(new byte[0], new byte[] { 0xF0, 0x90, 0x80, 0x80 }, "\U00010000")]
[InlineData(new byte[] { 0xF0 }, new byte[] { 0x90, 0x80, 0x80 }, "\U00010000")]
[InlineData(new byte[] { 0xF0, 0x90 }, new byte[] { 0x80, 0x80 }, "\U00010000")]
[InlineData(new byte[] { 0xF0, 0x90, 0x80 }, new byte[] { 0x80 }, "\U00010000")]
[InlineData(new byte[] { 0xF0, 0x90, 0x80, 0x80 }, new byte[0], "\U00010000")]
public void ValidMultiFramePayloads(byte[] payload1, byte[] payload2, string decoded)
{
var validator = new Utf8Validator();
Assert.True(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload1), fin: false));
Assert.True(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload2), fin: true));
// Not really part of the test, but it ensures that the "decoded" string matches the "payload",
// so that the "decoded" string can be used as a human-readable explanation of the string in question
Assert.Equal(decoded, Encoding.UTF8.GetString(Enumerable.Concat(payload1, payload2).ToArray()));
}
[Theory]
// Continuation byte as first byte of code point
[InlineData(new byte[] { 0x48, 0x65, 0x80, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65, 0x99, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65, 0xAB, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65, 0xB0, 0x6C, 0x6F })]
// Incomplete Code Point
[InlineData(new byte[] { 0xC2 })]
[InlineData(new byte[] { 0xE0 })]
[InlineData(new byte[] { 0xE0, 0xA0 })]
[InlineData(new byte[] { 0xE0, 0xA4 })]
[InlineData(new byte[] { 0xF0, 0x90, 0x80 })]
// Overlong Encoding
// 'H' (1 byte char) encoded with 2, 3 and 4 bytes
[InlineData(new byte[] { 0xC1, 0x88 })]
[InlineData(new byte[] { 0xE0, 0x81, 0x88 })]
[InlineData(new byte[] { 0xF0, 0x80, 0x81, 0x88 })]
// '§' (2 byte char) encoded with 3 and 4 bytes
[InlineData(new byte[] { 0xE0, 0x82, 0xA7 })]
[InlineData(new byte[] { 0xF0, 0x80, 0x82, 0xA7 })]
// '\u0800' (3 byte char) encoded with 4 bytes
[InlineData(new byte[] { 0xF0, 0x80, 0xA0, 0x80 })]
// Code point larger than what is allowed
[InlineData(new byte[] { 0xF5, 0x80, 0x80, 0x80 })]
public void InvalidSingleFramePayloads(byte[] payload)
{
var validator = new Utf8Validator();
Assert.False(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload), fin: true));
}
[Theory]
[InlineData(new byte[] { 0xC0 })] // overlong encoding of ASCII
[InlineData(new byte[] { 0xC1 })] // overlong encoding of ASCII
[InlineData(new byte[] { 0xF5 })] // larger than the unicode limit
public void InvalidMultiByteSequencesByFirstByte(byte[] payload)
{
var validator = new Utf8Validator();
Assert.False(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload), fin: false));
}
[Theory]
// Continuation byte as first byte of code point
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x80, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x99, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0xAB, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0xB0, 0x6C, 0x6F })]
// Incomplete Code Point
[InlineData(new byte[] { 0xC2 }, new byte[0])]
[InlineData(new byte[] { 0xE0 }, new byte[0])]
[InlineData(new byte[] { 0xE0, 0xA0 }, new byte[0])]
[InlineData(new byte[] { 0xE0, 0xA4 }, new byte[0])]
[InlineData(new byte[] { 0xF0, 0x90, 0x80 }, new byte[0])]
// Overlong Encoding
// 'H' (1 byte char) encoded with 3 and 4 bytes
[InlineData(new byte[] { 0xE0 }, new byte[] { 0x81, 0x88 })]
[InlineData(new byte[] { 0xF0 }, new byte[] { 0x80, 0x81, 0x88 })]
// '§' (2 byte char) encoded with 3 and 4 bytes
[InlineData(new byte[] { 0xE0, 0x82 }, new byte[] { 0xA7 })]
[InlineData(new byte[] { 0xF0, 0x80 }, new byte[] { 0x82, 0xA7 })]
// '\u0800' (3 byte char) encoded with 4 bytes
[InlineData(new byte[] { 0xF0, 0x80 }, new byte[] { 0xA0, 0x80 })]
public void InvalidMultiFramePayloads(byte[] payload1, byte[] payload2)
{
var validator = new Utf8Validator();
Assert.True(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload1), fin: false));
Assert.False(validator.ValidateUtf8Frame(ReadableBuffer.Create(payload2), fin: true));
}
}
}

View File

@ -1,18 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Threading.Tasks;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public static class WebSocketConnectionExtensions
{
public static async Task<WebSocketConnectionSummary> ExecuteAndCaptureFramesAsync(this IWebSocketConnection connection)
{
var frames = new List<WebSocketFrame>();
var closeResult = await connection.ExecuteAsync(frame => frames.Add(frame.Copy()));
return new WebSocketConnectionSummary(frames, closeResult);
}
}
}

View File

@ -1,19 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public class WebSocketConnectionSummary
{
public IList<WebSocketFrame> Received { get; }
public WebSocketCloseResult CloseResult { get; }
public WebSocketConnectionSummary(IList<WebSocketFrame> received, WebSocketCloseResult closeResult)
{
Received = received;
CloseResult = closeResult;
}
}
}

View File

@ -1,168 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.Extensions.Internal;
using System;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
[Fact]
public async Task SendReceiveFrames()
{
using (var pair = WebSocketPair.Create())
{
var client = pair.ClientSocket.ExecuteAsync(_ =>
{
Assert.False(true, "did not expect the client to receive any frames!");
return TaskCache.CompletedTask;
});
// Send Frames
await pair.ClientSocket.SendAsync(CreateTextFrame("Hello")).OrTimeout();
await pair.ClientSocket.SendAsync(CreateTextFrame("World")).OrTimeout();
await pair.ClientSocket.SendAsync(CreateBinaryFrame(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF })).OrTimeout();
await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)).OrTimeout();
var summary = await pair.ServerSocket.ExecuteAndCaptureFramesAsync().OrTimeout();
Assert.Equal(3, summary.Received.Count);
Assert.Equal("Hello", Encoding.UTF8.GetString(summary.Received[0].Payload.ToArray()));
Assert.Equal("World", Encoding.UTF8.GetString(summary.Received[1].Payload.ToArray()));
Assert.Equal(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }, summary.Received[2].Payload.ToArray());
await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)).OrTimeout();
await client.OrTimeout();
}
}
[Fact]
public async Task ExecuteReturnsWhenCloseFrameReceived()
{
using (var pair = WebSocketPair.Create())
{
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.InvalidMessageType, "Abc")).OrTimeout();
var serverSummary = await pair.ServerSocket.ExecuteAndCaptureFramesAsync().OrTimeout();
await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure, "Ok")).OrTimeout();
var clientSummary = await client.OrTimeout();
Assert.Equal(0, serverSummary.Received.Count);
Assert.Equal(WebSocketCloseStatus.InvalidMessageType, serverSummary.CloseResult.Status);
Assert.Equal("Abc", serverSummary.CloseResult.Description);
Assert.Equal(0, clientSummary.Received.Count);
Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status);
Assert.Equal("Ok", clientSummary.CloseResult.Description);
}
}
[Fact]
public async Task AbnormalTerminationOfInboundChannelCausesExecuteToThrow()
{
using (var pair = WebSocketPair.Create())
{
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
pair.TerminateFromClient(new InvalidOperationException("It broke!"));
await Assert.ThrowsAsync<InvalidOperationException>(() => server);
}
}
[Fact]
public async Task StateTransitions()
{
using (var pair = WebSocketPair.Create())
{
// Initial State
Assert.Equal(WebSocketConnectionState.Created, pair.ServerSocket.State);
Assert.Equal(WebSocketConnectionState.Created, pair.ClientSocket.State);
// Start the sockets
var serverReceiving = new TaskCompletionSource<object>();
var clientReceiving = new TaskCompletionSource<object>();
var server = pair.ServerSocket.ExecuteAsync(frame => serverReceiving.TrySetResult(null));
var client = pair.ClientSocket.ExecuteAsync(frame => clientReceiving.TrySetResult(null));
// Send a frame from each and verify that the state transitioned.
// We need to do this because it's the only way to correctly wait for the state transition (which happens asynchronously in ExecuteAsync)
await pair.ClientSocket.SendAsync(CreateTextFrame("Hello")).OrTimeout();
await pair.ServerSocket.SendAsync(CreateTextFrame("Hello")).OrTimeout();
await Task.WhenAll(serverReceiving.Task, clientReceiving.Task).OrTimeout();
// Check state
Assert.Equal(WebSocketConnectionState.Connected, pair.ServerSocket.State);
Assert.Equal(WebSocketConnectionState.Connected, pair.ClientSocket.State);
// Close the server socket
await pair.ServerSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)).OrTimeout();
await client.OrTimeout();
// Check state
Assert.Equal(WebSocketConnectionState.CloseSent, pair.ServerSocket.State);
Assert.Equal(WebSocketConnectionState.CloseReceived, pair.ClientSocket.State);
// Close the client socket
await pair.ClientSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure)).OrTimeout();
await server.OrTimeout();
// Check state
Assert.Equal(WebSocketConnectionState.Closed, pair.ServerSocket.State);
Assert.Equal(WebSocketConnectionState.Closed, pair.ClientSocket.State);
// Verify we can't restart the connection or send a message
await Assert.ThrowsAsync<ObjectDisposedException>(async () => await pair.ServerSocket.ExecuteAsync(f => { }));
await Assert.ThrowsAsync<ObjectDisposedException>(async () => await pair.ClientSocket.SendAsync(CreateTextFrame("Nope")));
}
}
[Fact]
public async Task CanReceiveControlFrameInTheMiddleOfFragmentedMessage()
{
using (var pair = WebSocketPair.Create())
{
// Start the sockets
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
// Send (Fin=false, "Hello"), (Ping), (Fin=true, "World")
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: false,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Ping,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("ping"))));
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Continuation,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("World"))));
// Close the socket
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
var serverSummary = await server;
await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
var clientSummary = await client;
// Assert
var nonControlFrames = serverSummary.Received.Where(f => f.Opcode < WebSocketOpcode.Close).ToList();
Assert.Equal(2, nonControlFrames.Count);
Assert.False(nonControlFrames[0].EndOfMessage);
Assert.True(nonControlFrames[1].EndOfMessage);
Assert.Equal(WebSocketOpcode.Text, nonControlFrames[0].Opcode);
Assert.Equal(WebSocketOpcode.Continuation, nonControlFrames[1].Opcode);
Assert.Equal("Hello", Encoding.UTF8.GetString(nonControlFrames[0].Payload.ToArray()));
Assert.Equal("World", Encoding.UTF8.GetString(nonControlFrames[1].Payload.ToArray()));
}
}
}
}

View File

@ -1,102 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Globalization;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
[Fact]
public async Task AutomaticPingTransmission()
{
var startTime = DateTime.UtcNow;
// Arrange
using (var pair = WebSocketPair.Create(
serverOptions: new WebSocketOptions().WithAllFramesPassedThrough().WithPingInterval(TimeSpan.FromMilliseconds(10)),
clientOptions: new WebSocketOptions().WithAllFramesPassedThrough()))
{
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
// Act
// Wait for pings to be sent
await Task.Delay(500);
await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout();
var clientSummary = await client.OrTimeout();
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout();
var serverSummary = await server.OrTimeout();
// Assert
Assert.NotEqual(0, clientSummary.Received.Count);
Assert.True(clientSummary.Received.All(f => f.EndOfMessage));
Assert.True(clientSummary.Received.All(f => f.Opcode == WebSocketOpcode.Ping));
Assert.True(clientSummary.Received.All(f =>
{
var str = Encoding.UTF8.GetString(f.Payload.ToArray());
// We can't verify the exact timestamp, but we can verify that it is a timestamp created after we started.
if (DateTime.TryParseExact(str, "O", CultureInfo.InvariantCulture, DateTimeStyles.AdjustToUniversal, out var dt))
{
return dt >= startTime;
}
return false;
}));
}
}
[Fact]
public async Task AutomaticPingResponse()
{
// Arrange
using (var pair = WebSocketPair.Create(
serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(),
clientOptions: new WebSocketOptions().WithAllFramesPassedThrough()))
{
var payload = Encoding.UTF8.GetBytes("ping payload");
var pongTcs = new TaskCompletionSource<WebSocketFrame>();
var client = pair.ClientSocket.ExecuteAsync(f =>
{
if (f.Opcode == WebSocketOpcode.Pong)
{
pongTcs.TrySetResult(f.Copy());
}
else
{
Assert.False(true, "Received non-pong frame from server!");
}
});
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
// Act
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Ping,
payload: ReadableBuffer.Create(payload)));
var pongFrame = await pongTcs.Task.OrTimeout();
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout();
await server.OrTimeout();
await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout();
await client.OrTimeout();
// Assert
Assert.True(pongFrame.EndOfMessage);
Assert.Equal(WebSocketOpcode.Pong, pongFrame.Opcode);
Assert.Equal(payload, pongFrame.Payload.ToArray());
}
}
}
}

View File

@ -1,246 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Tests.Common;
using System;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
[Theory]
[InlineData(new byte[] { 0x11, 0x00 })]
[InlineData(new byte[] { 0x21, 0x00 })]
[InlineData(new byte[] { 0x31, 0x00 })]
[InlineData(new byte[] { 0x41, 0x00 })]
[InlineData(new byte[] { 0x51, 0x00 })]
[InlineData(new byte[] { 0x61, 0x00 })]
[InlineData(new byte[] { 0x71, 0x00 })]
public Task TerminatesConnectionOnReservedBitSet(byte[] rawFrame)
{
return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, "Reserved bits, which are required to be zero, were set.");
}
[Theory]
[InlineData(0x03)]
[InlineData(0x04)]
[InlineData(0x05)]
[InlineData(0x06)]
[InlineData(0x07)]
[InlineData(0x0B)]
[InlineData(0x0C)]
[InlineData(0x0D)]
[InlineData(0x0E)]
[InlineData(0x0F)]
public Task ReservedOpcodes(byte opcode)
{
var payload = Encoding.UTF8.GetBytes("hello");
var frame = new WebSocketFrame(
endOfMessage: true,
opcode: (WebSocketOpcode)opcode,
payload: ReadableBuffer.Create(payload));
return SendFrameAndExpectClose(frame, WebSocketCloseStatus.ProtocolError, $"Received frame using reserved opcode: 0x{opcode:X}");
}
[Theory]
[InlineData(new byte[] { 0x88, 0x01, 0xAB })]
// Invalid UTF-8 reason
[InlineData(new byte[] { 0x88, 0x07, 0x03, 0xE8, 0x48, 0x65, 0x80, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x88, 0x07, 0x03, 0xE8, 0x48, 0x65, 0x99, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x88, 0x07, 0x03, 0xE8, 0x48, 0x65, 0xAB, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x88, 0x07, 0x03, 0xE8, 0x48, 0x65, 0xB0, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x88, 0x03, 0x03, 0xE8, 0xC2 })]
[InlineData(new byte[] { 0x88, 0x03, 0x03, 0xE8, 0xE0 })]
[InlineData(new byte[] { 0x88, 0x04, 0x03, 0xE8, 0xE0, 0xA0 })]
[InlineData(new byte[] { 0x88, 0x04, 0x03, 0xE8, 0xE0, 0xA4 })]
[InlineData(new byte[] { 0x88, 0x05, 0x03, 0xE8, 0xF0, 0x90, 0x80 })]
[InlineData(new byte[] { 0x88, 0x04, 0x03, 0xE8, 0xC1, 0x88 })]
[InlineData(new byte[] { 0x88, 0x05, 0x03, 0xE8, 0xE0, 0x81, 0x88 })]
[InlineData(new byte[] { 0x88, 0x06, 0x03, 0xE8, 0xF0, 0x80, 0x81, 0x88 })]
[InlineData(new byte[] { 0x88, 0x05, 0x03, 0xE8, 0xE0, 0x82, 0xA7 })]
[InlineData(new byte[] { 0x88, 0x06, 0x03, 0xE8, 0xF0, 0x80, 0x82, 0xA7 })]
[InlineData(new byte[] { 0x88, 0x06, 0x03, 0xE8, 0xF0, 0x80, 0xA0, 0x80 })]
public Task InvalidCloseFrames(byte[] rawFrame)
{
return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, "Close frame payload invalid");
}
[Fact]
public Task CloseFrameTooLong()
{
var rawFrame = new byte[256];
new Random().NextBytes(rawFrame);
// Put a WebSocket frame header in front
rawFrame[0] = 0x88; // Close frame, FIN=true
rawFrame[1] = 0x7E; // Mask=false, LEN=126
rawFrame[2] = 0x00; // Extended Len = 252 (256 - 4 bytes for header)
rawFrame[3] = 0xFC;
return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, "Close frame payload too long. Maximum size is 125 bytes");
}
[Theory]
// 0-999 reserved
[InlineData(0)]
[InlineData(999)]
// Specifically reserved status codes, or codes that should not be sent in frames.
[InlineData(1004)]
[InlineData(1005)]
[InlineData(1006)]
[InlineData(1012)]
[InlineData(1013)]
[InlineData(1014)]
[InlineData(1015)]
// Undefined status codes
[InlineData(1016)]
[InlineData(1100)]
[InlineData(2000)]
[InlineData(2999)]
public Task InvalidCloseStatuses(ushort status)
{
var rawFrame = new byte[] { 0x88, 0x02, (byte)(status >> 8), (byte)(status) };
return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, $"Invalid close status: {status}.");
}
[Theory]
[InlineData(new byte[] { 0x08, 0x00 })]
[InlineData(new byte[] { 0x09, 0x00 })]
[InlineData(new byte[] { 0x0A, 0x00 })]
public Task TerminatesConnectionOnFragmentedControlFrame(byte[] rawFrame)
{
return WriteFrameAndExpectClose(rawFrame, WebSocketCloseStatus.ProtocolError, "Control frames may not be fragmented");
}
[Fact]
public async Task TerminatesConnectionOnNonContinuationFrameFollowingFragmentedMessageStart()
{
// Arrange
using (var pair = WebSocketPair.Create(
serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(),
clientOptions: new WebSocketOptions().WithAllFramesPassedThrough()))
{
var payload = Encoding.UTF8.GetBytes("hello");
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
// Act
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: false,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(payload)));
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(payload)));
// Server should terminate
var clientSummary = await client.OrTimeout();
Assert.Equal(WebSocketCloseStatus.ProtocolError, clientSummary.CloseResult.Status);
Assert.Equal("Received non-continuation frame during a fragmented message", clientSummary.CloseResult.Description);
await server.OrTimeout();
}
}
[Fact]
public async Task TerminatesConnectionOnUnsolicitedContinuationFrame()
{
// Arrange
using (var pair = WebSocketPair.Create(
serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(),
clientOptions: new WebSocketOptions().WithAllFramesPassedThrough()))
{
var payload = Encoding.UTF8.GetBytes("hello");
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
// Act
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(payload)));
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Continuation,
payload: ReadableBuffer.Create(payload)));
// Server should terminate
var clientSummary = await client.OrTimeout();
Assert.Equal(WebSocketCloseStatus.ProtocolError, clientSummary.CloseResult.Status);
Assert.Equal("Continuation Frame was received when expecting a new message", clientSummary.CloseResult.Description);
await server.OrTimeout();
}
}
[Fact]
public Task TerminatesConnectionOnPingFrameLargerThan125Bytes()
{
var payload = new byte[126];
new Random().NextBytes(payload);
return SendFrameAndExpectClose(
new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Ping,
payload: ReadableBuffer.Create(payload)),
WebSocketCloseStatus.ProtocolError,
"Ping frame exceeded maximum size of 125 bytes");
}
private static async Task SendFrameAndExpectClose(WebSocketFrame frame, WebSocketCloseStatus closeStatus, string closeReason)
{
// Arrange
using (var pair = WebSocketPair.Create(
serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(),
clientOptions: new WebSocketOptions().WithAllFramesPassedThrough()))
{
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
// Act
await pair.ClientSocket.SendAsync(frame);
// Server should terminate
var clientSummary = await client.OrTimeout();
Assert.Equal(closeStatus, clientSummary.CloseResult.Status);
Assert.Equal(closeReason, clientSummary.CloseResult.Description);
await server.OrTimeout();
}
}
private static async Task WriteFrameAndExpectClose(byte[] rawFrame, WebSocketCloseStatus closeStatus, string closeReason)
{
// Arrange
using (var pair = WebSocketPair.Create(
serverOptions: new WebSocketOptions().WithAllFramesPassedThrough(),
clientOptions: new WebSocketOptions().WithAllFramesPassedThrough()))
{
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
// Act
await pair.ClientToServer.Writer.WriteAsync(rawFrame);
// Server should terminate
var clientSummary = await client.OrTimeout();
Assert.Equal(closeStatus, clientSummary.CloseResult.Status);
Assert.Equal(closeReason, clientSummary.CloseResult.Description);
await server.OrTimeout();
}
}
}
}

View File

@ -1,240 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
[Theory]
[InlineData(new byte[] { 0x81, 0x00 }, "", true)]
[InlineData(new byte[] { 0x81, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello", true)]
[InlineData(new byte[] { 0x81, 0x85, 0x1, 0x2, 0x3, 0x4, 0x48 ^ 0x1, 0x65 ^ 0x2, 0x6C ^ 0x3, 0x6C ^ 0x4, 0x6F ^ 0x1 }, "Hello", true)]
[InlineData(new byte[] { 0x01, 0x00 }, "", false)]
[InlineData(new byte[] { 0x01, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello", false)]
[InlineData(new byte[] { 0x01, 0x85, 0x1, 0x2, 0x3, 0x4, 0x48 ^ 0x1, 0x65 ^ 0x2, 0x6C ^ 0x3, 0x6C ^ 0x4, 0x6F ^ 0x1 }, "Hello", false)]
public Task ReadTextFrames(byte[] rawFrame, string message, bool endOfMessage)
{
return RunSingleFrameTest(
rawFrame,
endOfMessage,
WebSocketOpcode.Text,
b => Assert.Equal(message, Encoding.UTF8.GetString(b)));
}
[Theory]
// Opcode = Binary
[InlineData(new byte[] { 0x82, 0x00 }, new byte[0], WebSocketOpcode.Binary, true)]
[InlineData(new byte[] { 0x82, 0x05, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, true)]
[InlineData(new byte[] { 0x82, 0x85, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, true)]
[InlineData(new byte[] { 0x02, 0x00 }, new byte[0], WebSocketOpcode.Binary, false)]
[InlineData(new byte[] { 0x02, 0x05, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, false)]
[InlineData(new byte[] { 0x02, 0x85, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Binary, false)]
// Opcode = Ping
[InlineData(new byte[] { 0x89, 0x00 }, new byte[0], WebSocketOpcode.Ping, true)]
[InlineData(new byte[] { 0x89, 0x05, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, true)]
[InlineData(new byte[] { 0x89, 0x85, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Ping, true)]
// Control frames can't have fin=false
// Opcode = Pong
[InlineData(new byte[] { 0x8A, 0x00 }, new byte[0], WebSocketOpcode.Pong, true)]
[InlineData(new byte[] { 0x8A, 0x05, 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, true)]
[InlineData(new byte[] { 0x8A, 0x85, 0x1, 0x2, 0x3, 0x4, 0xDE ^ 0x1, 0xAD ^ 0x2, 0xBE ^ 0x3, 0xEF ^ 0x4, 0xAB ^ 0x1 }, new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, WebSocketOpcode.Pong, true)]
// Control frames can't have fin=false
public Task ReadBinaryFormattedFrames(byte[] rawFrame, byte[] payload, WebSocketOpcode opcode, bool endOfMessage)
{
return RunSingleFrameTest(
rawFrame,
endOfMessage,
opcode,
b => Assert.Equal(payload, b));
}
[Fact]
public async Task ReadMultipleFramesAcrossMultipleBuffers()
{
var result = await RunReceiveTest(
producer: async (channel, cancellationToken) =>
{
await channel.WriteAsync(new byte[] { 0x02, 0x05 }).OrTimeout();
await Task.Yield();
await channel.WriteAsync(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB, 0x80, 0x05 }).OrTimeout();
await Task.Yield();
await channel.WriteAsync(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }).OrTimeout();
await Task.Yield();
await channel.WriteAsync(new byte[] { 0xAB }).OrTimeout();
await Task.Yield();
});
Assert.Equal(2, result.Received.Count);
Assert.False(result.Received[0].EndOfMessage);
Assert.Equal(WebSocketOpcode.Binary, result.Received[0].Opcode);
Assert.Equal(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, result.Received[0].Payload.ToArray());
Assert.True(result.Received[1].EndOfMessage);
Assert.Equal(WebSocketOpcode.Continuation, result.Received[1].Opcode);
Assert.Equal(new byte[] { 0xDE, 0xAD, 0xBE, 0xEF, 0xAB }, result.Received[1].Payload.ToArray());
}
[Fact]
public async Task ReadLargeMaskedPayload()
{
// This test was added to ensure we don't break a behavior discovered while running the Autobahn test suite.
// Larger than one page, which means it will span blocks in the memory pool.
var expectedPayload = new byte[4192];
for (int i = 0; i < expectedPayload.Length; i++)
{
expectedPayload[i] = (byte)(i % byte.MaxValue);
}
var maskingKey = new byte[] { 0x01, 0x02, 0x03, 0x04 };
var sendPayload = new byte[4192];
for (int i = 0; i < expectedPayload.Length; i++)
{
sendPayload[i] = (byte)(expectedPayload[i] ^ maskingKey[i % 4]);
}
var result = await RunReceiveTest(
producer: async (channel, cancellationToken) =>
{
// We use a 64-bit length because we want to ensure that the first page of data ends at an
// offset within the frame that is NOT divisible by 4. This ensures that when the unmasking
// moves from one buffer to the other, we are at a non-zero position within the masking key.
// This ensures that we're tracking the masking key offset properly.
// Header: (Opcode=Binary, Fin=true), (Mask=false, Len=126), (64-bit big endian length)
await channel.WriteAsync(new byte[] { 0x82, 0xFF, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x60 }).OrTimeout();
await channel.WriteAsync(maskingKey).OrTimeout();
await Task.Yield();
await channel.WriteAsync(sendPayload).OrTimeout();
});
Assert.Equal(1, result.Received.Count);
var frame = result.Received[0];
Assert.True(frame.EndOfMessage);
Assert.Equal(WebSocketOpcode.Binary, frame.Opcode);
Assert.Equal(expectedPayload, frame.Payload.ToArray());
}
[Fact]
public async Task Read16BitPayloadLength()
{
var expectedPayload = new byte[1024];
new Random().NextBytes(expectedPayload);
var result = await RunReceiveTest(
producer: async (channel, cancellationToken) =>
{
// Header: (Opcode=Binary, Fin=true), (Mask=false, Len=126), (16-bit big endian length)
await channel.WriteAsync(new byte[] { 0x82, 0x7E, 0x04, 0x00 }).OrTimeout();
await Task.Yield();
await channel.WriteAsync(expectedPayload).OrTimeout();
});
Assert.Equal(1, result.Received.Count);
var frame = result.Received[0];
Assert.True(frame.EndOfMessage);
Assert.Equal(WebSocketOpcode.Binary, frame.Opcode);
Assert.Equal(expectedPayload, frame.Payload.ToArray());
}
[Fact]
public async Task Read64bitPayloadLength()
{
// Allocating an actual (2^32 + 1) byte payload is crazy for this test. We just need to test that we can USE a 64-bit length
var expectedPayload = new byte[1024];
new Random().NextBytes(expectedPayload);
var result = await RunReceiveTest(
producer: async (channel, cancellationToken) =>
{
// Header: (Opcode=Binary, Fin=true), (Mask=false, Len=127), (64-bit big endian length)
await channel.WriteAsync(new byte[] { 0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00 }).OrTimeout();
await Task.Yield();
await channel.WriteAsync(expectedPayload).OrTimeout();
});
Assert.Equal(1, result.Received.Count);
var frame = result.Received[0];
Assert.True(frame.EndOfMessage);
Assert.Equal(WebSocketOpcode.Binary, frame.Opcode);
Assert.Equal(expectedPayload, frame.Payload.ToArray());
}
private static async Task RunSingleFrameTest(byte[] rawFrame, bool endOfMessage, WebSocketOpcode expectedOpcode, Action<byte[]> payloadAssert)
{
var result = await RunReceiveTest(
producer: async (channel, cancellationToken) =>
{
await channel.WriteAsync(rawFrame).OrTimeout();
});
var frames = result.Received;
Assert.Equal(1, frames.Count);
var frame = frames[0];
Assert.Equal(endOfMessage, frame.EndOfMessage);
Assert.Equal(expectedOpcode, frame.Opcode);
payloadAssert(frame.Payload.ToArray());
}
private static async Task<WebSocketConnectionSummary> RunReceiveTest(Func<IPipeWriter, CancellationToken, Task> producer)
{
using (var factory = new PipeFactory())
{
var outbound = factory.Create();
var inbound = factory.Create();
var timeoutToken = TestUtil.CreateTimeoutToken();
var producerTask = Task.Run(async () =>
{
await producer(inbound.Writer, timeoutToken).OrTimeout();
inbound.Writer.Complete();
}, timeoutToken);
var consumerTask = Task.Run(async () =>
{
var connection = new WebSocketConnection(inbound.Reader, outbound.Writer, options: new WebSocketOptions().WithAllFramesPassedThrough());
using (timeoutToken.Register(() => connection.Dispose()))
using (connection)
{
// Receive frames until we're closed
return await connection.ExecuteAndCaptureFramesAsync().OrTimeout();
}
}, timeoutToken);
await Task.WhenAll(producerTask, consumerTask);
return consumerTask.Result;
}
}
private static WebSocketFrame CreateTextFrame(string message)
{
var payload = Encoding.UTF8.GetBytes(message);
return CreateFrame(endOfMessage: true, opcode: WebSocketOpcode.Text, payload: payload);
}
private static WebSocketFrame CreateBinaryFrame(byte[] payload)
{
return CreateFrame(endOfMessage: true, opcode: WebSocketOpcode.Binary, payload: payload);
}
private static WebSocketFrame CreateFrame(bool endOfMessage, WebSocketOpcode opcode, byte[] payload)
{
return new WebSocketFrame(endOfMessage, opcode, payload: ReadableBuffer.Create(payload, 0, payload.Length));
}
}
}

View File

@ -1,238 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.Extensions.Internal;
using System;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
// No auto-pinging for us!
private readonly static WebSocketOptions DefaultTestOptions = new WebSocketOptions().WithAllFramesPassedThrough();
[Theory]
[InlineData("", true, new byte[] { 0x81, 0x00 })]
[InlineData("Hello", true, new byte[] { 0x81, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F })]
[InlineData("", false, new byte[] { 0x01, 0x00 })]
[InlineData("Hello", false, new byte[] { 0x01, 0x05, 0x48, 0x65, 0x6C, 0x6C, 0x6F })]
public async Task WriteTextFrames(string message, bool endOfMessage, byte[] expectedRawFrame)
{
var data = await RunSendTest(
producer: async (socket) =>
{
var payload = Encoding.UTF8.GetBytes(message);
await socket.SendAsync(CreateFrame(
endOfMessage,
opcode: WebSocketOpcode.Text,
payload: payload)).OrTimeout();
}, options: DefaultTestOptions);
Assert.Equal(expectedRawFrame, data);
}
[Theory]
// Opcode = Binary
[InlineData(new byte[0], WebSocketOpcode.Binary, true, new byte[] { 0x82, 0x00 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, true, new byte[] { 0x82, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })]
[InlineData(new byte[0], WebSocketOpcode.Binary, false, new byte[] { 0x02, 0x00 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, false, new byte[] { 0x02, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })]
// Opcode = Continuation
[InlineData(new byte[0], WebSocketOpcode.Continuation, true, new byte[] { 0x80, 0x00 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, true, new byte[] { 0x80, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })]
[InlineData(new byte[0], WebSocketOpcode.Continuation, false, new byte[] { 0x00, 0x00 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, false, new byte[] { 0x00, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })]
// Opcode = Ping
[InlineData(new byte[0], WebSocketOpcode.Ping, true, new byte[] { 0x89, 0x00 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, true, new byte[] { 0x89, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })]
[InlineData(new byte[0], WebSocketOpcode.Ping, false, new byte[] { 0x09, 0x00 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, false, new byte[] { 0x09, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })]
// Opcode = Pong
[InlineData(new byte[0], WebSocketOpcode.Pong, true, new byte[] { 0x8A, 0x00 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, true, new byte[] { 0x8A, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })]
[InlineData(new byte[0], WebSocketOpcode.Pong, false, new byte[] { 0x0A, 0x00 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, false, new byte[] { 0x0A, 0x05, 0xA, 0xB, 0xC, 0xD, 0xE })]
public async Task WriteBinaryFormattedFrames(byte[] payload, WebSocketOpcode opcode, bool endOfMessage, byte[] expectedRawFrame)
{
var data = await RunSendTest(
producer: async (socket) =>
{
await socket.SendAsync(CreateFrame(
endOfMessage,
opcode,
payload: payload)).OrTimeout();
}, options: DefaultTestOptions);
Assert.Equal(expectedRawFrame, data);
}
[Theory]
[InlineData("", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x81, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData("Hello", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x81, 0x85, 0x01, 0x02, 0x03, 0x04, 0x48 ^ 0x01, 0x65 ^ 0x02, 0x6C ^ 0x03, 0x6C ^ 0x04, 0x6F ^ 0x01 })]
public async Task WriteMaskedTextFrames(string message, byte[] maskingKey, byte[] expectedRawFrame)
{
var data = await RunSendTest(
producer: async (socket) =>
{
var payload = Encoding.UTF8.GetBytes(message);
await socket.SendAsync(CreateFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: payload)).OrTimeout();
}, options: DefaultTestOptions.WithFixedMaskingKey(maskingKey));
Assert.Equal(expectedRawFrame, data);
}
[Theory]
// Opcode = Binary
[InlineData(new byte[0], WebSocketOpcode.Binary, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x82, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })]
[InlineData(new byte[0], WebSocketOpcode.Binary, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x02, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Binary, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x02, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })]
// Opcode = Continuation
[InlineData(new byte[0], WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x80, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x80, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })]
[InlineData(new byte[0], WebSocketOpcode.Continuation, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x00, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Continuation, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x00, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })]
// Opcode = Ping
[InlineData(new byte[0], WebSocketOpcode.Ping, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x89, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x89, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })]
[InlineData(new byte[0], WebSocketOpcode.Ping, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x09, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Ping, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x09, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })]
// Opcode = Pong
[InlineData(new byte[0], WebSocketOpcode.Pong, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x8A, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, true, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x8A, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })]
[InlineData(new byte[0], WebSocketOpcode.Pong, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x0A, 0x80, 0x01, 0x02, 0x03, 0x04 })]
[InlineData(new byte[] { 0xA, 0xB, 0xC, 0xD, 0xE }, WebSocketOpcode.Pong, false, new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x0A, 0x85, 0x01, 0x02, 0x03, 0x04, 0x0A ^ 0x01, 0x0B ^ 0x02, 0x0C ^ 0x03, 0x0D ^ 0x04, 0x0E ^ 0x01 })]
public async Task WriteMaskedBinaryFormattedFrames(byte[] payload, WebSocketOpcode opcode, bool endOfMessage, byte[] maskingKey, byte[] expectedRawFrame)
{
var data = await RunSendTest(
producer: async (socket) =>
{
await socket.SendAsync(CreateFrame(
endOfMessage,
opcode,
payload: payload)).OrTimeout();
}, options: DefaultTestOptions.WithFixedMaskingKey(maskingKey));
Assert.Equal(expectedRawFrame, data);
}
[Fact]
public async Task WriteRandomMaskedFrame()
{
var data = await RunSendTest(
producer: async (socket) =>
{
await socket.SendAsync(CreateFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Binary,
payload: new byte[] { 0x0A, 0x0B, 0x0C, 0x0D, 0x0E })).OrTimeout();
}, options: DefaultTestOptions.WithRandomMasking());
// Verify the header
Assert.Equal(0x82, data[0]);
Assert.Equal(0x85, data[1]);
// We don't know the mask, so we have to read it in order to verify this frame
var mask = new byte[] { data[2], data[3], data[4], data[5] };
var actualPayload = new byte[data.Length - 6];
// Unmask the payload
for (int i = 0; i < actualPayload.Length; i++)
{
actualPayload[i] = (byte)(mask[i % 4] ^ data[i + 6]);
}
Assert.Equal(new byte[] { 0x0A, 0x0B, 0x0C, 0x0D, 0x0E }, actualPayload);
}
[Theory]
[InlineData(WebSocketCloseStatus.MandatoryExtension, "Hi", null, new byte[] { 0x88, 0x04, 0x03, 0xF2, (byte)'H', (byte)'i' })]
[InlineData(WebSocketCloseStatus.PolicyViolation, "", null, new byte[] { 0x88, 0x02, 0x03, 0xF0 })]
[InlineData(WebSocketCloseStatus.MandatoryExtension, "Hi", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x88, 0x84, 0x01, 0x02, 0x03, 0x04, 0x03 ^ 0x01, 0xF2 ^ 0x02, (byte)'H' ^ 0x03, (byte)'i' ^ 0x04 })]
[InlineData(WebSocketCloseStatus.PolicyViolation, "", new byte[] { 0x01, 0x02, 0x03, 0x04 }, new byte[] { 0x88, 0x82, 0x01, 0x02, 0x03, 0x04, 0x03 ^ 0x01, 0xF0 ^ 0x02 })]
public async Task WriteCloseFrames(WebSocketCloseStatus status, string description, byte[] maskingKey, byte[] expectedRawFrame)
{
var data = await RunSendTest(
producer: async (socket) =>
{
await socket.CloseAsync(new WebSocketCloseResult(status, description)).OrTimeout();
}, options: maskingKey == null ? DefaultTestOptions : DefaultTestOptions.WithFixedMaskingKey(maskingKey));
Assert.Equal(expectedRawFrame, data);
}
[Fact]
public async Task WriteMultipleFrames()
{
var data = await RunSendTest(
producer: async (socket) =>
{
await socket.SendAsync(CreateFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Binary,
payload: new byte[0])).OrTimeout();
await socket.SendAsync(CreateFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Binary,
payload: new byte[] { 0x01 })).OrTimeout();
await socket.SendAsync(CreateFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: new byte[0])).OrTimeout();
await socket.SendAsync(CreateFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: Encoding.UTF8.GetBytes("Hello"))).OrTimeout();
}, options: DefaultTestOptions);
Assert.Equal(new byte[]
{
0x82, 0x00, // Frame 1
0x82, 0x01, 0x01, // Frame 2
0x81, 0x00, // Frame 3
0x81, 0x05, (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o' // Frame 4
}, data);
}
private static async Task<byte[]> RunSendTest(Func<WebSocketConnection, Task> producer, WebSocketOptions options)
{
using (var factory = new PipeFactory())
{
var outbound = factory.Create();
var inbound = factory.Create();
using (var connection = new WebSocketConnection(inbound.Reader, outbound.Writer, options))
{
var executeTask = connection.ExecuteAndCaptureFramesAsync();
await producer(connection).OrTimeout();
connection.Abort();
inbound.Writer.Complete();
await executeTask.OrTimeout();
}
var buffer = await outbound.Reader.ReadToEndAsync();
var data = buffer.ToArray();
outbound.Reader.Advance(buffer.End);
inbound.Reader.Complete();
CompleteChannels(outbound);
return data;
}
}
private static void CompleteChannels(params IPipe[] readerWriters)
{
foreach (var readerWriter in readerWriters)
{
readerWriter.Reader.Complete();
readerWriter.Writer.Complete();
}
}
}
}

View File

@ -1,223 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Tests.Common;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
[Theory]
[InlineData(new byte[] { 0x48, 0x65, 0x6C, 0x6C, 0x6F }, "Hello")]
[InlineData(new byte[] { 0xC2, 0xA7, 0x31, 0x2C, 0x20, 0x39, 0x35, 0xC2, 0xA2 }, "§1, 95¢")]
[InlineData(new byte[] { 0xE0, 0xA0, 0x80, 0xE0, 0xA4, 0x80 }, "\u0800\u0900")]
[InlineData(new byte[] { 0xF0, 0x90, 0x80, 0x80 }, "\U00010000")]
public async Task ValidSingleFramePayloads(byte[] payload, string decoded)
{
using (var pair = WebSocketPair.Create())
{
var timeoutToken = TestUtil.CreateTimeoutToken();
using (timeoutToken.Register(() => pair.Dispose()))
{
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var frame = new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(payload));
await pair.ClientSocket.SendAsync(frame).OrTimeout();
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout();
var serverSummary = await server.OrTimeout();
await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout();
var clientSummary = await client.OrTimeout();
Assert.Equal(0, clientSummary.Received.Count);
Assert.Equal(1, serverSummary.Received.Count);
Assert.True(serverSummary.Received[0].EndOfMessage);
Assert.Equal(WebSocketOpcode.Text, serverSummary.Received[0].Opcode);
Assert.Equal(decoded, Encoding.UTF8.GetString(serverSummary.Received[0].Payload.ToArray()));
}
}
}
[Theory]
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x6C, 0x6C, 0x6F }, "Hello")]
[InlineData(new byte[0], new byte[] { 0xC2, 0xA7 }, "§")]
[InlineData(new byte[] { 0xC2 }, new byte[] { 0xA7 }, "§")]
[InlineData(new byte[] { 0xC2, 0xA7 }, new byte[0], "§")]
[InlineData(new byte[0], new byte[] { 0xC2, 0xA2 }, "¢")]
[InlineData(new byte[] { 0xC2 }, new byte[] { 0xA2 }, "¢")]
[InlineData(new byte[] { 0xC2, 0xA2 }, new byte[0], "¢")]
[InlineData(new byte[0], new byte[] { 0xE0, 0xA0, 0x80 }, "\u0800")]
[InlineData(new byte[] { 0xE0 }, new byte[] { 0xA0, 0x80 }, "\u0800")]
[InlineData(new byte[] { 0xE0, 0xA0 }, new byte[] { 0x80 }, "\u0800")]
[InlineData(new byte[] { 0xE0, 0xA0, 0x80 }, new byte[0], "\u0800")]
[InlineData(new byte[0], new byte[] { 0xE0, 0xA4, 0x80 }, "\u0900")]
[InlineData(new byte[] { 0xE0 }, new byte[] { 0xA4, 0x80 }, "\u0900")]
[InlineData(new byte[] { 0xE0, 0xA4 }, new byte[] { 0x80 }, "\u0900")]
[InlineData(new byte[] { 0xE0, 0xA4, 0x80 }, new byte[0], "\u0900")]
[InlineData(new byte[0], new byte[] { 0xF0, 0x90, 0x80, 0x80 }, "\U00010000")]
[InlineData(new byte[] { 0xF0 }, new byte[] { 0x90, 0x80, 0x80 }, "\U00010000")]
[InlineData(new byte[] { 0xF0, 0x90 }, new byte[] { 0x80, 0x80 }, "\U00010000")]
[InlineData(new byte[] { 0xF0, 0x90, 0x80 }, new byte[] { 0x80 }, "\U00010000")]
[InlineData(new byte[] { 0xF0, 0x90, 0x80, 0x80 }, new byte[0], "\U00010000")]
public async Task ValidMultiFramePayloads(byte[] payload1, byte[] payload2, string decoded)
{
using (var pair = WebSocketPair.Create())
{
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var frame = new WebSocketFrame(
endOfMessage: false,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(payload1));
await pair.ClientSocket.SendAsync(frame).OrTimeout();
frame = new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Continuation,
payload: ReadableBuffer.Create(payload2));
await pair.ClientSocket.SendAsync(frame).OrTimeout();
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout();
var serverSummary = await server.OrTimeout();
await pair.ServerSocket.CloseAsync(WebSocketCloseStatus.NormalClosure).OrTimeout();
var clientSummary = await client.OrTimeout();
Assert.Equal(0, clientSummary.Received.Count);
Assert.Equal(2, serverSummary.Received.Count);
Assert.False(serverSummary.Received[0].EndOfMessage);
Assert.Equal(WebSocketOpcode.Text, serverSummary.Received[0].Opcode);
Assert.True(serverSummary.Received[1].EndOfMessage);
Assert.Equal(WebSocketOpcode.Continuation, serverSummary.Received[1].Opcode);
var finalPayload = serverSummary.Received.SelectMany(f => f.Payload.ToArray()).ToArray();
Assert.Equal(decoded, Encoding.UTF8.GetString(finalPayload));
}
}
[Theory]
// Continuation byte as first byte of code point
[InlineData(new byte[] { 0x48, 0x65, 0x80, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65, 0x99, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65, 0xAB, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65, 0xB0, 0x6C, 0x6F })]
// Incomplete Code Point
[InlineData(new byte[] { 0xC2 })]
[InlineData(new byte[] { 0xE0 })]
[InlineData(new byte[] { 0xE0, 0xA0 })]
[InlineData(new byte[] { 0xE0, 0xA4 })]
[InlineData(new byte[] { 0xF0, 0x90, 0x80 })]
// Overlong Encoding
// 'H' (1 byte char) encoded with 2, 3 and 4 bytes
[InlineData(new byte[] { 0xC1, 0x88 })]
[InlineData(new byte[] { 0xE0, 0x81, 0x88 })]
[InlineData(new byte[] { 0xF0, 0x80, 0x81, 0x88 })]
// '§' (2 byte char) encoded with 3 and 4 bytes
[InlineData(new byte[] { 0xE0, 0x82, 0xA7 })]
[InlineData(new byte[] { 0xF0, 0x80, 0x82, 0xA7 })]
// '\u0800' (3 byte char) encoded with 4 bytes
[InlineData(new byte[] { 0xF0, 0x80, 0xA0, 0x80 })]
public async Task InvalidSingleFramePayloads(byte[] payload)
{
using (var pair = WebSocketPair.Create())
{
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var frame = new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(payload));
await pair.ClientSocket.SendAsync(frame).OrTimeout();
var clientSummary = await client.OrTimeout();
var serverSummary = await server.OrTimeout();
Assert.Equal(0, serverSummary.Received.Count);
Assert.Equal(0, clientSummary.Received.Count);
Assert.Equal(WebSocketCloseStatus.InvalidPayloadData, clientSummary.CloseResult.Status);
Assert.Equal("An invalid Text frame payload was received", clientSummary.CloseResult.Description);
}
}
[Theory]
// Continuation byte as first byte of code point
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x80, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0x99, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0xAB, 0x6C, 0x6F })]
[InlineData(new byte[] { 0x48, 0x65 }, new byte[] { 0xB0, 0x6C, 0x6F })]
// Incomplete Code Point
[InlineData(new byte[] { 0xC2 }, new byte[0])]
[InlineData(new byte[] { 0xE0 }, new byte[0])]
[InlineData(new byte[] { 0xE0, 0xA0 }, new byte[0])]
[InlineData(new byte[] { 0xE0, 0xA4 }, new byte[0])]
[InlineData(new byte[] { 0xF0, 0x90, 0x80 }, new byte[0])]
// Overlong Encoding
// 'H' (1 byte char) encoded with 3 and 4 bytes
[InlineData(new byte[] { 0xE0 }, new byte[] { 0x81, 0x88 })]
[InlineData(new byte[] { 0xF0 }, new byte[] { 0x80, 0x81, 0x88 })]
// '§' (2 byte char) encoded with 3 and 4 bytes
[InlineData(new byte[] { 0xE0, 0x82 }, new byte[] { 0xA7 })]
[InlineData(new byte[] { 0xF0, 0x80 }, new byte[] { 0x82, 0xA7 })]
// '\u0800' (3 byte char) encoded with 4 bytes
[InlineData(new byte[] { 0xF0, 0x80 }, new byte[] { 0xA0, 0x80 })]
public async Task InvalidMultiFramePayloads(byte[] payload1, byte[] payload2)
{
using (var pair = WebSocketPair.Create())
{
var timeoutToken = TestUtil.CreateTimeoutToken();
using (timeoutToken.Register(() => pair.Dispose()))
{
var server = pair.ServerSocket.ExecuteAndCaptureFramesAsync();
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
var frame = new WebSocketFrame(
endOfMessage: false,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(payload1));
await pair.ClientSocket.SendAsync(frame).OrTimeout();
frame = new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Continuation,
payload: ReadableBuffer.Create(payload2));
await pair.ClientSocket.SendAsync(frame).OrTimeout();
var clientSummary = await client.OrTimeout();
var serverSummary = await server.OrTimeout();
Assert.Equal(1, serverSummary.Received.Count);
Assert.False(serverSummary.Received[0].EndOfMessage);
Assert.Equal(WebSocketOpcode.Text, serverSummary.Received[0].Opcode);
Assert.Equal(payload1, serverSummary.Received[0].Payload.ToArray());
Assert.Equal(0, clientSummary.Received.Count);
Assert.Equal(WebSocketCloseStatus.InvalidPayloadData, clientSummary.CloseResult.Status);
Assert.Equal("An invalid Text frame payload was received", clientSummary.CloseResult.Description);
}
}
}
}
}

View File

@ -1,66 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
internal class WebSocketPair : IDisposable
{
private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking();
private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough();
private PipeFactory _factory;
private readonly bool _ownFactory;
public IPipe ServerToClient { get; }
public IPipe ClientToServer { get; }
public IWebSocketConnection ClientSocket { get; }
public IWebSocketConnection ServerSocket { get; }
public WebSocketPair(bool ownFactory, PipeFactory factory, IPipe serverToClient, IPipe clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket)
{
_ownFactory = ownFactory;
_factory = factory;
ServerToClient = serverToClient;
ClientToServer = clientToServer;
ClientSocket = clientSocket;
ServerSocket = serverSocket;
}
public static WebSocketPair Create() => Create(new PipeFactory(), DefaultServerOptions, DefaultClientOptions, ownFactory: true);
public static WebSocketPair Create(PipeFactory factory) => Create(factory, DefaultServerOptions, DefaultClientOptions, ownFactory: false);
public static WebSocketPair Create(WebSocketOptions serverOptions, WebSocketOptions clientOptions) => Create(new PipeFactory(), serverOptions, clientOptions, ownFactory: true);
public static WebSocketPair Create(PipeFactory factory, WebSocketOptions serverOptions, WebSocketOptions clientOptions) => Create(factory, serverOptions, clientOptions, ownFactory: false);
private static WebSocketPair Create(PipeFactory factory, WebSocketOptions serverOptions, WebSocketOptions clientOptions, bool ownFactory)
{
// Create channels
var serverToClient = factory.Create();
var clientToServer = factory.Create();
var serverSocket = new WebSocketConnection(clientToServer.Reader, serverToClient.Writer, options: serverOptions);
var clientSocket = new WebSocketConnection(serverToClient.Reader, clientToServer.Writer, options: clientOptions);
return new WebSocketPair(ownFactory, factory, serverToClient, clientToServer, clientSocket, serverSocket);
}
public void Dispose()
{
ServerSocket.Dispose();
ClientSocket.Dispose();
if (_ownFactory)
{
_factory.Dispose();
}
}
public void TerminateFromClient(Exception ex = null)
{
ClientToServer.Writer.Complete(ex);
}
}
}

View File

@ -1,36 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Console;
namespace WebSocketsTestApp
{
public class Program
{
public static void Main(string[] args)
{
var config = new ConfigurationBuilder()
.AddCommandLine(args)
.Build();
var host = new WebHostBuilder()
.UseConfiguration(config)
.ConfigureLogging(factory =>
{
factory.AddConsole();
factory.AddFilter<ConsoleLoggerProvider>(level => level >= LogLevel.Debug);
})
.UseKestrel()
.UseContentRoot(Directory.GetCurrentDirectory())
.UseIISIntegration()
.UseStartup<Startup>()
.Build();
host.Run();
}
}
}

View File

@ -1,116 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.WebSockets.Internal;
namespace WebSocketsTestApp
{
public class Startup
{
// This method gets called by the runtime. Use this method to add services to the container.
// For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940
public void ConfigureServices(IServiceCollection services)
{
services.AddSingleton<PipeFactory>();
}
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
public void Configure(IApplicationBuilder app, IHostingEnvironment env, ILoggerFactory loggerFactory, PipeFactory PipeFactory)
{
if (env.IsDevelopment())
{
app.UseDeveloperExceptionPage();
}
app.UseWebSocketConnections(PipeFactory);
app.Use(async (context, next) =>
{
var webSocketConnectionFeature = context.Features.Get<IHttpWebSocketConnectionFeature>();
if (webSocketConnectionFeature != null && webSocketConnectionFeature.IsWebSocketRequest)
{
using (var webSocket = await webSocketConnectionFeature.AcceptWebSocketConnectionAsync(new WebSocketAcceptContext()))
{
await Echo(context, webSocket, loggerFactory.CreateLogger("Echo"));
}
}
else
{
await next();
}
});
app.UseFileServer();
}
private async Task Echo(HttpContext context, IWebSocketConnection webSocket, ILogger logger)
{
var lastFrameOpcode = WebSocketOpcode.Continuation;
var closeResult = await webSocket.ExecuteAsync(frame =>
{
if (frame.Opcode == WebSocketOpcode.Ping || frame.Opcode == WebSocketOpcode.Pong)
{
// Already handled
return Task.CompletedTask;
}
LogFrame(logger, lastFrameOpcode, ref frame);
// If the client send "ServerClose", then they want a server-originated close to occur
string content = "<<binary>>";
if (frame.Opcode == WebSocketOpcode.Text)
{
// Slooooow
content = Encoding.UTF8.GetString(frame.Payload.ToArray());
if (content.Equals("ServerClose"))
{
logger.LogDebug($"Sending Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server");
return webSocket.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure, "Closing from Server"));
}
else if (content.Equals("ServerAbort"))
{
context.Abort();
}
}
if (frame.Opcode != WebSocketOpcode.Continuation)
{
lastFrameOpcode = frame.Opcode;
}
logger.LogDebug($"Sending {frame.Opcode}: Len={frame.Payload.Length}, Fin={frame.EndOfMessage}: {content}");
return webSocket.SendAsync(frame);
});
if (webSocket.State == WebSocketConnectionState.CloseReceived)
{
// Close the connection from our end
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
logger.LogDebug("Socket closed");
}
else if (webSocket.State != WebSocketConnectionState.Closed)
{
logger.LogError("WebSocket closed but not closed?");
}
}
private void LogFrame(ILogger logger, WebSocketOpcode lastFrameOpcode, ref WebSocketFrame frame)
{
var opcode = frame.Opcode;
if (opcode == WebSocketOpcode.Continuation)
{
opcode = lastFrameOpcode;
}
logger.LogDebug($"Received {frame.Opcode} frame (FIN={frame.EndOfMessage}, LEN={frame.Payload.Length})");
}
}
}

View File

@ -1,24 +0,0 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<TargetFrameworks>netcoreapp2.0;net461</TargetFrameworks>
<TargetFrameworks Condition="'$(OS)' != 'Windows_NT'">netcoreapp2.0</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.WebSockets.Internal\Microsoft.AspNetCore.WebSockets.Internal.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Diagnostics" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Server.IISIntegration" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.StaticFiles" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Configuration.CommandLine" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Configuration" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="$(AspNetCoreVersion)" />
</ItemGroup>
</Project>

View File

@ -1,43 +0,0 @@
#
# RunAutobahnTests.ps1
#
param([Parameter(Mandatory=$true)][string]$ServerUrl, [string[]]$Cases = @("*"), [string]$OutputDir, [int]$Iterations = 1)
if(!(Get-Command wstest -ErrorAction SilentlyContinue)) {
throw "Missing required command 'wstest'. See README.md in Microsoft.AspNetCore.WebSockets.Server.Test project for information on installing Autobahn Test Suite."
}
if(!$OutputDir) {
$OutputDir = Convert-Path "."
$OutputDir = Join-Path $OutputDir "autobahnreports"
}
Write-Host "Launching Autobahn Test Suite ($Iterations iteration(s))..."
0..($Iterations-1) | % {
$iteration = $_
$Spec = Convert-Path (Join-Path $PSScriptRoot "autobahn.spec.json")
$CasesArray = [string]::Join(",", @($Cases | ForEach-Object { "`"$_`"" }))
$SpecJson = [IO.File]::ReadAllText($Spec).Replace("OUTPUTDIR", $OutputDir.Replace("\", "\\")).Replace("WEBSOCKETURL", $ServerUrl).Replace("`"CASES`"", $CasesArray)
$TempFile = [IO.Path]::GetTempFileName()
try {
[IO.File]::WriteAllText($TempFile, $SpecJson)
$wstestOutput = & wstest -m fuzzingclient -s $TempFile
} finally {
if(Test-Path $TempFile) {
rm $TempFile
}
}
$report = ConvertFrom-Json ([IO.File]::ReadAllText((Convert-Path (Join-Path $OutputDir "index.json"))))
$report.Server | gm | ? { $_.MemberType -eq "NoteProperty" } | % {
$case = $report.Server."$($_.Name)"
Write-Host "[#$($iteration.ToString().PadRight(2))] [$($case.behavior.PadRight(6))] Case $($_.Name)"
}
}

View File

@ -1,14 +0,0 @@
{
"options": { "failByDrop": false },
"outdir": "OUTPUTDIR",
"servers": [
{
"agent": "Server",
"url": "WEBSOCKETURL",
"options": { "version": 18 }
}
],
"cases": ["CASES"],
"exclude-cases": ["12.*", "13.*"],
"exclude-agent-cases": {}
}

View File

@ -1,14 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<configuration>
<!--
Configure your application settings in appsettings.json. Learn more at http://go.microsoft.com/fwlink/?LinkId=786380
-->
<system.webServer>
<handlers>
<add name="aspNetCore" path="*" verb="*" modules="AspNetCoreModule" resourceType="Unspecified"/>
</handlers>
<aspNetCore processPath="%LAUNCHER_PATH%" arguments="%LAUNCHER_ARGS%" stdoutLogEnabled="false" stdoutLogFile=".\logs\stdout" forwardWindowsAuthToken="false"/>
</system.webServer>
</configuration>

View File

@ -1,151 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title></title>
<style>
table { border: 0 }
.commslog-data { font-family: Consolas, Courier New, Courier, monospace; }
.commslog-server { background-color: red; color: white }
.commslog-client { background-color: green; color: white }
</style>
</head>
<body>
<h1>WebSocket Test Page</h1>
<p id="stateLabel">Ready to connect...</p>
<div>
<label for="connectionUrl">WebSocket Server URL:</label>
<input id="connectionUrl" />
<button id="connectButton" type="submit">Connect</button>
</div>
<div>
<label for="sendMessage">Message to send:</label>
<input id="sendMessage" disabled />
<button id="sendButton" type="submit" disabled>Send</button>
<button id="closeButton" disabled>Close Socket</button>
</div>
<p>Note: When connected to the default server (i.e. the server in the address bar ;)), the message "ServerClose" will cause the server to close the connection. Similarly, the message "ServerAbort" will cause the server to forcibly terminate the connection without a closing handshake</p>
<h2>Communication Log</h2>
<table style="width: 800px">
<thead>
<tr>
<td style="width: 100px">From</td>
<td style="width: 100px">To</td>
<td>Data</td>
</tr>
</thead>
<tbody id="commsLog">
</tbody>
</table>
<script>
var connectionForm = document.getElementById("connectionForm");
var connectionUrl = document.getElementById("connectionUrl");
var connectButton = document.getElementById("connectButton");
var stateLabel = document.getElementById("stateLabel");
var sendMessage = document.getElementById("sendMessage");
var sendButton = document.getElementById("sendButton");
var sendForm = document.getElementById("sendForm");
var commsLog = document.getElementById("commsLog");
var socket;
var scheme = document.location.protocol == "https:" ? "wss" : "ws";
var port = document.location.port ? (":" + document.location.port) : "";
connectionUrl.value = scheme + "://" + document.location.hostname + port;
function updateState() {
function disable() {
sendMessage.disabled = true;
sendButton.disabled = true;
closeButton.disabled = true;
}
function enable() {
sendMessage.disabled = false;
sendButton.disabled = false;
closeButton.disabled = false;
}
connectionUrl.disabled = true;
connectButton.disabled = true;
if (!socket) {
disable();
} else {
switch (socket.readyState) {
case WebSocket.CLOSED:
stateLabel.innerHTML = "Closed";
disable();
connectionUrl.disabled = false;
connectButton.disabled = false;
break;
case WebSocket.CLOSING:
stateLabel.innerHTML = "Closing...";
disable();
break;
case WebSocket.CONNECTING:
stateLabel.innerHTML = "Connecting...";
disable();
break;
case WebSocket.OPEN:
stateLabel.innerHTML = "Open";
enable();
break;
default:
stateLabel.innerHTML = "Unknown WebSocket State: " + socket.readyState;
disable();
break;
}
}
}
closeButton.onclick = function () {
if (!socket || socket.readyState != WebSocket.OPEN) {
alert("socket not connected");
}
socket.close(1000, "Closing from client");
}
sendButton.onclick = function () {
if (!socket || socket.readyState != WebSocket.OPEN) {
alert("socket not connected");
}
var data = sendMessage.value;
socket.send(data);
commsLog.innerHTML += '<tr>' +
'<td class="commslog-client">Client</td>' +
'<td class="commslog-server">Server</td>' +
'<td class="commslog-data">' + data + '</td>'
'</tr>';
}
connectButton.onclick = function() {
stateLabel.innerHTML = "Connecting";
socket = new WebSocket(connectionUrl.value);
socket.onopen = function (event) {
updateState();
commsLog.innerHTML += '<tr>' +
'<td colspan="3" class="commslog-data">Connection opened</td>' +
'</tr>';
};
socket.onclose = function (event) {
updateState();
commsLog.innerHTML += '<tr>' +
'<td colspan="3" class="commslog-data">Connection closed. Code: ' + event.code + '. Reason: ' + event.reason + '</td>' +
'</tr>';
};
socket.onerror = updateState;
socket.onmessage = function (event) {
commsLog.innerHTML += '<tr>' +
'<td class="commslog-server">Server</td>' +
'<td class="commslog-client">Client</td>' +
'<td class="commslog-data">' + event.data + '</td>'
'</tr>';
};
};
</script>
</body>
</html>