Fix #215 and restore tests (#218)

* fix #215 by properly handling pipe closure

* pr feedback

* pr feedback
This commit is contained in:
Andrew Stanton-Nurse 2017-02-21 15:27:52 -08:00 committed by GitHub
parent 9709139a27
commit 755ba7613e
12 changed files with 304 additions and 257 deletions

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@ -75,7 +75,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
if (receiving.IsCanceled || receiving.IsFaulted)
{
// The receiver faulted or cancelled. This means the client is probably broken. Just propagate the exception and exit
// The receiver faulted or cancelled. This means the socket is probably broken. Abort the socket and propagate the exception
receiving.GetAwaiter().GetResult();
// Should never get here because GetResult above will throw

View File

@ -64,6 +64,11 @@ namespace Microsoft.Extensions.WebSockets.Internal
/// <param name="state">A state parameter that will be passed to each invocation of <paramref name="messageHandler"/></param>
/// <returns>A <see cref="Task{WebSocketCloseResult}"/> that will complete when the client has sent a close frame, or the connection has been terminated</returns>
Task<WebSocketCloseResult> ExecuteAsync(Func<WebSocketFrame, object, Task> messageHandler, object state);
/// <summary>
/// Forcibly terminates the socket, cleaning up the necessary resources.
/// </summary>
void Abort();
}
public static class WebSocketConnectionExtensions

View File

@ -13,9 +13,10 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="System.ValueTuple" Version="$(CoreFxVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.IO.Pipelines.Text.Primitives" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="1.2.0-*" PrivateAssets="All"/>
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="1.2.0-*" PrivateAssets="All" />
</ItemGroup>
</Project>

View File

@ -1,31 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading;
using System.Threading.Tasks;
using System.IO.Pipelines;
using System.Threading.Tasks;
namespace Microsoft.Extensions.WebSockets.Internal
{
public static class PipeReaderExtensions
{
public static ValueTask<ReadResult> ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes) => ReadAtLeastAsync(input, minimumRequiredBytes, CancellationToken.None);
// TODO: Pull this up to Channels. We should be able to do it there without allocating a Task<T> in any case (rather than here where we can avoid allocation
// only if the buffer is already ready and has enough data)
public static ValueTask<ReadResult> ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes, CancellationToken cancellationToken)
public static async ValueTask<ReadResult> ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes)
{
var awaiter = input.ReadAsync(/* cancellationToken */);
// Short-cut path!
ReadResult result;
if (awaiter.IsCompleted)
{
// We have a buffer, is it big enough?
var result = awaiter.GetResult();
result = awaiter.GetResult();
if (result.IsCompleted || result.Buffer.Length >= minimumRequiredBytes)
{
return new ValueTask<ReadResult>(result);
return result;
}
// Buffer wasn't big enough, mark it as examined and continue to the "slow" path below
@ -33,15 +31,9 @@ namespace Microsoft.Extensions.WebSockets.Internal
consumed: result.Buffer.Start,
examined: result.Buffer.End);
}
return new ValueTask<ReadResult>(ReadAtLeastSlowAsync(awaiter, input, minimumRequiredBytes, cancellationToken));
}
private static async Task<ReadResult> ReadAtLeastSlowAsync(ReadableBufferAwaitable awaitable, IPipeReader input, int minimumRequiredBytes, CancellationToken cancellationToken)
{
var result = await awaitable;
while (!result.IsCompleted && result.Buffer.Length < minimumRequiredBytes)
result = await awaiter;
while (!result.IsCancelled && !result.IsCompleted && result.Buffer.Length < minimumRequiredBytes)
{
cancellationToken.ThrowIfCancellationRequested();
input.Advance(
consumed: result.Buffer.Start,
examined: result.Buffer.End);

View File

@ -31,7 +31,6 @@ namespace Microsoft.Extensions.WebSockets.Internal
private readonly byte[] _maskingKeyBuffer;
private readonly IPipeReader _inbound;
private readonly IPipeWriter _outbound;
private readonly CancellationTokenSource _terminateReceiveCts = new CancellationTokenSource();
private readonly Timer _pinger;
private readonly CancellationTokenSource _timerCts = new CancellationTokenSource();
private Utf8Validator _validator = new Utf8Validator();
@ -113,8 +112,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
{
// We don't need to wait for this task to complete, we're "tail calling" and
// we are in a Timer thread-pool thread.
#pragma warning disable 4014
connection.SendCoreLockAcquiredAsync(
var ignore = connection.SendCoreLockAcquiredAsync(
fin: true,
opcode: WebSocketOpcode.Ping,
payloadAllocLength: 28,
@ -122,7 +120,6 @@ namespace Microsoft.Extensions.WebSockets.Internal
payloadWriter: PingPayloadWriter,
payload: DateTime.UtcNow,
cancellationToken: connection._timerCts.Token);
#pragma warning restore 4014
}
}
@ -131,7 +128,6 @@ namespace Microsoft.Extensions.WebSockets.Internal
State = WebSocketConnectionState.Closed;
_pinger?.Dispose();
_timerCts.Cancel();
_terminateReceiveCts.Cancel();
_inbound.Complete();
_outbound.Complete();
}
@ -148,7 +144,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
throw new InvalidOperationException("Connection is already running.");
}
State = WebSocketConnectionState.Connected;
return ReceiveLoop(messageHandler, state, _terminateReceiveCts.Token);
return ReceiveLoop(messageHandler, state);
}
/// <summary>
@ -249,269 +245,322 @@ namespace Microsoft.Extensions.WebSockets.Internal
_maskingKeyBuffer.CopyTo(buffer);
}
private async Task<WebSocketCloseResult> ReceiveLoop(Func<WebSocketFrame, object, Task> messageHandler, object state, CancellationToken cancellationToken)
/// <summary>
/// Terminates the socket abruptly.
/// </summary>
public void Abort()
{
while (!cancellationToken.IsCancellationRequested)
// We duplicate some work from Dispose here, but that's OK.
_timerCts.Cancel();
_inbound.CancelPendingRead();
_outbound.Complete();
}
private async ValueTask<(bool Success, byte OpcodeByte, bool Masked, bool Fin, int Length, uint MaskingKey)> ReadHeaderAsync()
{
// Read at least 2 bytes
var readResult = await _inbound.ReadAtLeastAsync(2);
if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < 2))
{
// WebSocket Frame layout (https://tools.ietf.org/html/rfc6455#section-5.2):
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
// |I|S|S|S| (4) |A| (7) | (16/64) |
// |N|V|V|V| |S| | (if payload len==126/127) |
// | |1|2|3| |K| | |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// | Extended payload length continued, if payload len == 127 |
// + - - - - - - - - - - - - - - - +-------------------------------+
// | |Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
// | Masking-key (continued) | Payload Data |
// +-------------------------------- - - - - - - - - - - - - - - - +
// : Payload Data continued ... :
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
// | Payload Data continued ... |
// +---------------------------------------------------------------+
_inbound.Advance(readResult.Buffer.End);
return (Success: false, OpcodeByte: 0, Masked: false, Fin: false, Length: 0, MaskingKey: 0);
}
var buffer = readResult.Buffer;
// Read at least 2 bytes
var result = await _inbound.ReadAtLeastAsync(2, cancellationToken);
cancellationToken.ThrowIfCancellationRequested();
if (result.IsCompleted && result.Buffer.Length < 2)
// Read the opcode and length
var opcodeByte = buffer.ReadBigEndian<byte>();
buffer = buffer.Slice(1);
// Read the first byte of the payload length
var lengthByte = buffer.ReadBigEndian<byte>();
buffer = buffer.Slice(1);
_inbound.Advance(buffer.Start);
// Determine how much header there still is to read
var fin = (opcodeByte & 0x80) != 0;
var masked = (lengthByte & 0x80) != 0;
var length = lengthByte & 0x7F;
// Calculate the rest of the header length
var headerLength = masked ? 4 : 0;
if (length == 126)
{
headerLength += 2;
}
else if (length == 127)
{
headerLength += 8;
}
// Read the next set of header data
uint maskingKey = 0;
if (headerLength > 0)
{
readResult = await _inbound.ReadAtLeastAsync(headerLength);
if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < headerLength))
{
return WebSocketCloseResult.AbnormalClosure;
_inbound.Advance(readResult.Buffer.End);
return (Success: false, OpcodeByte: 0, Masked: false, Fin: false, Length: 0, MaskingKey: 0);
}
var buffer = result.Buffer;
buffer = readResult.Buffer;
// Read the opcode
var opcodeByte = buffer.ReadBigEndian<byte>();
buffer = buffer.Slice(1);
var fin = (opcodeByte & 0x80) != 0;
var opcodeNum = opcodeByte & 0x0F;
var opcode = (WebSocketOpcode)opcodeNum;
if ((opcodeByte & 0x70) != 0)
// Read extended payload length (if any)
if (length == 126)
{
// Reserved bits set, this frame is invalid, close our side and terminate immediately
return await CloseFromProtocolError(cancellationToken, 0, default(ReadableBuffer), "Reserved bits, which are required to be zero, were set.");
length = buffer.ReadBigEndian<ushort>();
buffer = buffer.Slice(sizeof(ushort));
}
else if ((opcodeNum >= 0x03 && opcodeNum <= 0x07) || (opcodeNum >= 0x0B && opcodeNum <= 0x0F))
else if (length == 127)
{
// Reserved opcode
return await CloseFromProtocolError(cancellationToken, 0, default(ReadableBuffer), $"Received frame using reserved opcode: 0x{opcodeNum:X}");
var longLen = buffer.ReadBigEndian<ulong>();
buffer = buffer.Slice(sizeof(ulong));
if (longLen > int.MaxValue)
{
throw new WebSocketException($"Frame is too large. Maximum frame size is {int.MaxValue} bytes");
}
length = (int)longLen;
}
// Read the first byte of the payload length
var lenByte = buffer.ReadBigEndian<byte>();
buffer = buffer.Slice(1);
// Read masking key
if (masked)
{
var maskingKeyStart = buffer.Start;
maskingKey = buffer.Slice(0, sizeof(uint)).ReadBigEndian<uint>();
buffer = buffer.Slice(sizeof(uint));
}
var masked = (lenByte & 0x80) != 0;
var payloadLen = (lenByte & 0x7F);
// Mark what we've got so far as consumed
// Mark the length and masking key consumed
_inbound.Advance(buffer.Start);
}
// Calculate the rest of the header length
var headerLength = masked ? 4 : 0;
if (payloadLen == 126)
return (Success: true, opcodeByte, masked, fin, length, maskingKey);
}
private async ValueTask<(bool Success, ReadableBuffer Buffer)> ReadPayloadAsync(int length, bool masked, uint maskingKey)
{
var payload = default(ReadableBuffer);
if (length > 0)
{
var readResult = await _inbound.ReadAtLeastAsync(length);
if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < length))
{
headerLength += 2;
return (Success: false, Buffer: readResult.Buffer);
}
else if (payloadLen == 127)
var buffer = readResult.Buffer;
payload = buffer.Slice(0, length);
if (masked)
{
headerLength += 8;
// Unmask
MaskingUtilities.ApplyMask(ref payload, maskingKey);
}
}
return (Success: true, Buffer: payload);
}
uint maskingKey = 0;
if (headerLength > 0)
private async Task<WebSocketCloseResult> ReceiveLoop(Func<WebSocketFrame, object, Task> messageHandler, object state)
{
try
{
while (true)
{
result = await _inbound.ReadAtLeastAsync(headerLength, cancellationToken);
cancellationToken.ThrowIfCancellationRequested();
if (result.IsCompleted && result.Buffer.Length < headerLength)
{
return WebSocketCloseResult.AbnormalClosure;
}
buffer = result.Buffer;
// WebSocket Frame layout (https://tools.ietf.org/html/rfc6455#section-5.2):
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
// |I|S|S|S| (4) |A| (7) | (16/64) |
// |N|V|V|V| |S| | (if payload len==126/127) |
// | |1|2|3| |K| | |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// | Extended payload length continued, if payload len == 127 |
// + - - - - - - - - - - - - - - - +-------------------------------+
// | |Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
// | Masking-key (continued) | Payload Data |
// +-------------------------------- - - - - - - - - - - - - - - - +
// : Payload Data continued ... :
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
// | Payload Data continued ... |
// +---------------------------------------------------------------+
// Read extended payload length (if any)
if (payloadLen == 126)
var header = await ReadHeaderAsync();
if (!header.Success)
{
payloadLen = buffer.ReadBigEndian<ushort>();
buffer = buffer.Slice(sizeof(ushort));
break;
}
else if (payloadLen == 127)
// Validate Opcode
var opcodeNum = header.OpcodeByte & 0x0F;
if ((header.OpcodeByte & 0x70) != 0)
{
var longLen = buffer.ReadBigEndian<ulong>();
buffer = buffer.Slice(sizeof(ulong));
if (longLen > int.MaxValue)
// Reserved bits set, this frame is invalid, close our side and terminate immediately
await CloseFromProtocolError("Reserved bits, which are required to be zero, were set.");
break;
}
else if ((opcodeNum >= 0x03 && opcodeNum <= 0x07) || (opcodeNum >= 0x0B && opcodeNum <= 0x0F))
{
// Reserved opcode
await CloseFromProtocolError($"Received frame using reserved opcode: 0x{opcodeNum:X}");
break;
}
var opcode = (WebSocketOpcode)opcodeNum;
var payload = await ReadPayloadAsync(header.Length, header.Masked, header.MaskingKey);
if (!payload.Success)
{
_inbound.Advance(payload.Buffer.End);
break;
}
var frame = new WebSocketFrame(header.Fin, opcode, payload.Buffer);
// Start a try-finally because we may get an exception while closing, if there's an error
// And we need to advance the buffer even if that happens. It wasn't needed above because
// we had already parsed the buffer before we verified it, so we had already advanced the
// buffer, if we encountered an error while closing we didn't have to advance the buffer.
// Side Note: Look at this gloriously aligned comment. You have anurse and brecon to thank for it. Oh wait, I ruined it.
try
{
if (frame.Opcode.IsControl() && !frame.EndOfMessage)
{
throw new WebSocketException($"Frame is too large. Maximum frame size is {int.MaxValue} bytes");
// Control frames cannot be fragmented.
await CloseFromProtocolError("Control frames may not be fragmented");
break;
}
payloadLen = (int)longLen;
}
// Read masking key
if (masked)
{
var maskingKeyStart = buffer.Start;
maskingKey = buffer.Slice(0, 4).ReadBigEndian<uint>();
buffer = buffer.Slice(4);
}
// Mark the length and masking key consumed
_inbound.Advance(buffer.Start);
}
var payload = default(ReadableBuffer);
if (payloadLen > 0)
{
result = await _inbound.ReadAtLeastAsync(payloadLen, cancellationToken);
cancellationToken.ThrowIfCancellationRequested();
if (result.IsCompleted && result.Buffer.Length < payloadLen)
{
return WebSocketCloseResult.AbnormalClosure;
}
buffer = result.Buffer;
payload = buffer.Slice(0, payloadLen);
if (masked)
{
// Unmask
MaskingUtilities.ApplyMask(ref payload, maskingKey);
}
}
// Run the callback, if we're not cancelled.
cancellationToken.ThrowIfCancellationRequested();
var frame = new WebSocketFrame(fin, opcode, payload);
if (frame.Opcode.IsControl() && !frame.EndOfMessage)
{
// Control frames cannot be fragmented.
return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Control frames may not be fragmented");
}
else if (_currentMessageType != WebSocketOpcode.Continuation && opcode.IsMessage() && opcode != 0)
{
return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Received non-continuation frame during a fragmented message");
}
else if (_currentMessageType == WebSocketOpcode.Continuation && frame.Opcode == WebSocketOpcode.Continuation)
{
return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Continuation Frame was received when expecting a new message");
}
if (frame.Opcode == WebSocketOpcode.Close)
{
// Allowed frame lengths:
// 0 - No body
// 2 - Code with no reason phrase
// >2 - Code and reason phrase (must be valid UTF-8)
if (frame.Payload.Length > 125)
{
return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Close frame payload too long. Maximum size is 125 bytes");
}
else if ((frame.Payload.Length == 1) || (frame.Payload.Length > 2 && !Utf8Validator.ValidateUtf8(payload.Slice(2))))
{
return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Close frame payload invalid");
}
ushort? actualStatusCode;
var closeResult = HandleCloseFrame(payload, frame, out actualStatusCode);
// Verify the close result
if (actualStatusCode != null)
{
var statusCode = actualStatusCode.Value;
if (statusCode < 1000 || statusCode == 1004 || statusCode == 1005 || statusCode == 1006 || (statusCode > 1011 && statusCode < 3000))
else if (_currentMessageType != WebSocketOpcode.Continuation && opcode.IsMessage() && opcode != 0)
{
return await CloseFromProtocolError(cancellationToken, payloadLen, payload, $"Invalid close status: {statusCode}.");
await CloseFromProtocolError("Received non-continuation frame during a fragmented message");
break;
}
else if (_currentMessageType == WebSocketOpcode.Continuation && frame.Opcode == WebSocketOpcode.Continuation)
{
await CloseFromProtocolError("Continuation Frame was received when expecting a new message");
break;
}
if (frame.Opcode == WebSocketOpcode.Close)
{
return await ProcessCloseFrameAsync(frame);
}
else
{
if (frame.Opcode == WebSocketOpcode.Ping)
{
// Check the ping payload length
if (frame.Payload.Length > 125)
{
// Payload too long
await CloseFromProtocolError("Ping frame exceeded maximum size of 125 bytes");
break;
}
await SendCoreAsync(
frame.EndOfMessage,
WebSocketOpcode.Pong,
payloadAllocLength: 0,
payloadLength: frame.Payload.Length,
payloadWriter: AppendPayloadWriter,
payload: frame.Payload,
cancellationToken: CancellationToken.None);
}
var effectiveOpcode = opcode == WebSocketOpcode.Continuation ? _currentMessageType : opcode;
if (effectiveOpcode == WebSocketOpcode.Text && !_validator.ValidateUtf8Frame(frame.Payload, frame.EndOfMessage))
{
// Drop the frame and immediately close with InvalidPayload
await CloseFromProtocolError("An invalid Text frame payload was received", statusCode: WebSocketCloseStatus.InvalidPayloadData);
break;
}
else if (_options.PassAllFramesThrough || (frame.Opcode != WebSocketOpcode.Ping && frame.Opcode != WebSocketOpcode.Pong))
{
await messageHandler(frame, state);
}
}
}
finally
{
if (frame.Payload.Length > 0)
{
_inbound.Advance(frame.Payload.End);
}
}
// Make the payload as consumed
if (payloadLen > 0)
if (header.Fin)
{
_inbound.Advance(payload.End);
}
// Reset the UTF8 validator
_validator.Reset();
return closeResult;
}
else
{
if (frame.Opcode == WebSocketOpcode.Ping)
{
// Check the ping payload length
if (frame.Payload.Length > 125)
// If it's a non-control frame, reset the message type tracker
if (opcode.IsMessage())
{
// Payload too long
return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Ping frame exceeded maximum size of 125 bytes");
_currentMessageType = WebSocketOpcode.Continuation;
}
await SendCoreAsync(
frame.EndOfMessage,
WebSocketOpcode.Pong,
payloadAllocLength: 0,
payloadLength: payload.Length,
payloadWriter: AppendPayloadWriter,
payload: payload,
cancellationToken: cancellationToken);
}
var effectiveOpcode = opcode == WebSocketOpcode.Continuation ? _currentMessageType : opcode;
if (effectiveOpcode == WebSocketOpcode.Text && !_validator.ValidateUtf8Frame(frame.Payload, frame.EndOfMessage))
// If there isn't a current message type, and this was a fragmented message frame, set the current message type
else if (!header.Fin && _currentMessageType == WebSocketOpcode.Continuation && opcode.IsMessage())
{
// Drop the frame and immediately close with InvalidPayload
return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "An invalid Text frame payload was received", statusCode: WebSocketCloseStatus.InvalidPayloadData);
}
else if (_options.PassAllFramesThrough || (frame.Opcode != WebSocketOpcode.Ping && frame.Opcode != WebSocketOpcode.Pong))
{
await messageHandler(frame, state);
_currentMessageType = opcode;
}
}
if (fin)
{
// Reset the UTF8 validator
_validator.Reset();
// If it's a non-control frame, reset the message type tracker
if (opcode.IsMessage())
{
_currentMessageType = WebSocketOpcode.Continuation;
}
}
// If there isn't a current message type, and this was a fragmented message frame, set the current message type
else if (!fin && _currentMessageType == WebSocketOpcode.Continuation && opcode.IsMessage())
{
_currentMessageType = opcode;
}
// Mark the payload as consumed
if (payloadLen > 0)
{
_inbound.Advance(payload.End);
}
}
catch
{
// Abort the socket and rethrow
Abort();
throw;
}
return WebSocketCloseResult.AbnormalClosure;
}
private async Task<WebSocketCloseResult> CloseFromProtocolError(CancellationToken cancellationToken, int payloadLen, ReadableBuffer payload, string reason, WebSocketCloseStatus statusCode = WebSocketCloseStatus.ProtocolError)
private async ValueTask<WebSocketCloseResult> ProcessCloseFrameAsync(WebSocketFrame frame)
{
// Non-continuation non-control message during fragmented message
if (payloadLen > 0)
// Allowed frame lengths:
// 0 - No body
// 2 - Code with no reason phrase
// >2 - Code and reason phrase (must be valid UTF-8)
if (frame.Payload.Length > 125)
{
_inbound.Advance(payload.End);
await CloseFromProtocolError("Close frame payload too long. Maximum size is 125 bytes");
return WebSocketCloseResult.AbnormalClosure;
}
var closeResult = new WebSocketCloseResult(
statusCode,
reason);
await CloseAsync(closeResult, cancellationToken);
Dispose();
else if ((frame.Payload.Length == 1) || (frame.Payload.Length > 2 && !Utf8Validator.ValidateUtf8(frame.Payload.Slice(2))))
{
await CloseFromProtocolError("Close frame payload invalid");
return WebSocketCloseResult.AbnormalClosure;
}
ushort? actualStatusCode;
var closeResult = ParseCloseFrame(frame.Payload, frame, out actualStatusCode);
// Verify the close result
if (actualStatusCode != null)
{
var statusCode = actualStatusCode.Value;
if (statusCode < 1000 || statusCode == 1004 || statusCode == 1005 || statusCode == 1006 || (statusCode > 1011 && statusCode < 3000))
{
await CloseFromProtocolError($"Invalid close status: {statusCode}.");
return WebSocketCloseResult.AbnormalClosure;
}
}
return closeResult;
}
private WebSocketCloseResult HandleCloseFrame(ReadableBuffer payload, WebSocketFrame frame, out ushort? actualStatusCode)
private async Task CloseFromProtocolError(string reason, WebSocketCloseStatus statusCode = WebSocketCloseStatus.ProtocolError)
{
var closeResult = new WebSocketCloseResult(
statusCode,
reason);
await CloseAsync(closeResult, CancellationToken.None);
// We can now terminate our connection, according to the spec.
Abort();
}
private WebSocketCloseResult ParseCloseFrame(ReadableBuffer payload, WebSocketFrame frame, out ushort? actualStatusCode)
{
// Update state
if (State == WebSocketConnectionState.CloseSent)
@ -529,6 +578,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
{
closeResult = WebSocketCloseResult.Empty;
}
return closeResult;
}

View File

@ -12,6 +12,7 @@
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\TaskExtensions.cs" Link="TaskExtensions.cs" />
<Compile Include="..\Microsoft.Extensions.WebSockets.Internal.Tests\WebSocketConnectionExtensions.cs;..\Microsoft.Extensions.WebSockets.Internal.Tests\WebSocketConnectionSummary.cs;..\Microsoft.Extensions.WebSockets.Internal.Tests\WebSocketPair.cs" />
</ItemGroup>

View File

@ -6,6 +6,7 @@ using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.Logging;
@ -258,7 +259,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
[Fact(Skip="Fails after updating to new Pipelines")]
[Fact]
public async Task TransportFailsWhenClientDisconnectsAbnormally()
{
var transportToApplication = Channel.CreateUnbounded<Message>();
@ -286,7 +287,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
[Fact(Skip="Fails after updating to new Pipelines")]
[Fact]
public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails()
{
var transportToApplication = Channel.CreateUnbounded<Message>();
@ -313,6 +314,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// Close from the client
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
await transport.OrTimeout();
}
}
}

View File

@ -65,7 +65,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
}
}
[Fact(Skip="Fails after updating to new Pipelines")]
[Fact]
public async Task AbnormalTerminationOfInboundChannelCausesExecuteToThrow()
{
using (var pair = WebSocketPair.Create())

View File

@ -12,8 +12,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
// Skipping tests after failures caused by updating to newer Pipelines
private class ProtocolErrors
public class ProtocolErrors
{
[Theory]
[InlineData(new byte[] { 0x11, 0x00 })]

View File

@ -13,8 +13,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
// Skipping tests after failures caused by updating to newer Pipelines
private class TheReceiveAsyncMethod
public class TheReceiveAsyncMethod
{
[Theory]
[InlineData(new byte[] { 0x81, 0x00 }, "", true)]

View File

@ -13,8 +13,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
// Skipping tests after failures caused by updating to newer Pipelines
private class TheSendAsyncMethod
public class TheSendAsyncMethod
{
// No auto-pinging for us!
private readonly static WebSocketOptions DefaultTestOptions = new WebSocketOptions().WithAllFramesPassedThrough();
@ -179,20 +178,18 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
var outbound = factory.Create();
var inbound = factory.Create();
Task executeTask;
using (var connection = new WebSocketConnection(inbound.Reader, outbound.Writer, options))
{
executeTask = connection.ExecuteAsync(f =>
{
Assert.False(true, "Did not expect to receive any messages");
return TaskCache.CompletedTask;
});
var executeTask = connection.ExecuteAndCaptureFramesAsync();
await producer(connection).OrTimeout();
connection.Abort();
inbound.Writer.Complete();
await executeTask.OrTimeout();
}
var data = (await outbound.Reader.ReadToEndAsync()).ToArray();
var buffer = await outbound.Reader.ReadToEndAsync();
var data = buffer.ToArray();
outbound.Reader.Advance(buffer.End);
inbound.Reader.Complete();
CompleteChannels(outbound);
return data;

View File

@ -8,8 +8,8 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
internal class WebSocketPair : IDisposable
{
private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking();
private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough();
private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking();
private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough();
private PipeFactory _factory;
private readonly bool _ownFactory;