TestHost: Add WebSocket support.

This commit is contained in:
Master T 2015-08-22 13:44:27 +02:00 committed by Chris R
parent 25b5a42ca6
commit 2ee7384400
5 changed files with 688 additions and 15 deletions

View File

@ -67,22 +67,22 @@ namespace Microsoft.AspNet.TestHost
// Async offload, don't let the test code block the caller.
var offload = Task.Factory.StartNew(async () =>
{
try
{
try
{
await _next(state.HttpContext.Features);
state.CompleteResponse();
}
catch (Exception ex)
{
state.Abort(ex);
}
finally
{
registration.Dispose();
state.Dispose();
}
});
await _next(state.HttpContext.Features);
state.CompleteResponse();
}
catch (Exception ex)
{
state.Abort(ex);
}
finally
{
registration.Dispose();
state.Dispose();
}
});
return await state.ResponseTask;
}

View File

@ -97,6 +97,12 @@ namespace Microsoft.AspNet.TestHost
return new HttpClient(CreateHandler()) { BaseAddress = BaseAddress };
}
public WebSocketClient CreateWebSocketClient()
{
var pathBase = BaseAddress == null ? PathString.Empty : PathString.FromUriComponent(BaseAddress);
return new WebSocketClient(Invoke, pathBase);
}
/// <summary>
/// Begins constructing a request message for submission.
/// </summary>

View File

@ -0,0 +1,354 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNet.TestHost
{
internal class TestWebSocket : WebSocket
{
private ReceiverSenderBuffer _receiveBuffer;
private ReceiverSenderBuffer _sendBuffer;
private readonly string _subProtocol;
private WebSocketState _state;
private WebSocketCloseStatus? _closeStatus;
private string _closeStatusDescription;
private Message _receiveMessage;
public static Tuple<TestWebSocket, TestWebSocket> CreatePair(string subProtocol)
{
var buffers = new[] { new ReceiverSenderBuffer(), new ReceiverSenderBuffer() };
return Tuple.Create(
new TestWebSocket(subProtocol, buffers[0], buffers[1]),
new TestWebSocket(subProtocol, buffers[1], buffers[0]));
}
public override WebSocketCloseStatus? CloseStatus
{
get { return _closeStatus; }
}
public override string CloseStatusDescription
{
get { return _closeStatusDescription; }
}
public override WebSocketState State
{
get { return _state; }
}
public override string SubProtocol
{
get { return _subProtocol; }
}
public async override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
ThrowIfDisposed();
if (State == WebSocketState.Open || State == WebSocketState.CloseReceived)
{
// Send a close message.
await CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
}
if (State == WebSocketState.CloseSent)
{
// Do a receiving drain
var data = new byte[1024];
WebSocketReceiveResult result;
do
{
result = await ReceiveAsync(new ArraySegment<byte>(data), cancellationToken);
}
while (result.MessageType != WebSocketMessageType.Close);
}
}
public async override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
{
ThrowIfDisposed();
ThrowIfOutputClosed();
var message = new Message(closeStatus, statusDescription);
await _sendBuffer.SendAsync(message, cancellationToken);
if (State == WebSocketState.Open)
{
_state = WebSocketState.CloseSent;
}
else if (State == WebSocketState.CloseReceived)
{
_state = WebSocketState.Closed;
Close();
}
}
public override void Abort()
{
if (_state >= WebSocketState.Closed) // or Aborted
{
return;
}
_state = WebSocketState.Aborted;
Close();
}
public override void Dispose()
{
if (_state >= WebSocketState.Closed) // or Aborted
{
return;
}
_state = WebSocketState.Closed;
Close();
}
public override async Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
{
ThrowIfDisposed();
ThrowIfInputClosed();
ValidateSegment(buffer);
// TODO: InvalidOperationException if any receives are currently in progress.
Message receiveMessage = _receiveMessage;
_receiveMessage = null;
if (receiveMessage == null)
{
receiveMessage = await _receiveBuffer.ReceiveAsync(cancellationToken);
}
if (receiveMessage.MessageType == WebSocketMessageType.Close)
{
_closeStatus = receiveMessage.CloseStatus;
_closeStatusDescription = receiveMessage.CloseStatusDescription ?? string.Empty;
var result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, _closeStatus, _closeStatusDescription);
if (_state == WebSocketState.Open)
{
_state = WebSocketState.CloseReceived;
}
else if (_state == WebSocketState.CloseSent)
{
_state = WebSocketState.Closed;
Close();
}
return result;
}
else
{
int count = Math.Min(buffer.Count, receiveMessage.Buffer.Count);
bool endOfMessage = count == receiveMessage.Buffer.Count;
Array.Copy(receiveMessage.Buffer.Array, receiveMessage.Buffer.Offset, buffer.Array, buffer.Offset, count);
if (!endOfMessage)
{
receiveMessage.Buffer = new ArraySegment<byte>(receiveMessage.Buffer.Array, receiveMessage.Buffer.Offset + count, receiveMessage.Buffer.Count - count);
_receiveMessage = receiveMessage;
}
endOfMessage = endOfMessage && receiveMessage.EndOfMessage;
return new WebSocketReceiveResult(count, receiveMessage.MessageType, endOfMessage);
}
}
public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
ValidateSegment(buffer);
if (messageType != WebSocketMessageType.Binary && messageType != WebSocketMessageType.Text)
{
// Block control frames
throw new ArgumentOutOfRangeException(nameof(messageType), messageType, string.Empty);
}
var message = new Message(buffer, messageType, endOfMessage, cancellationToken);
return _sendBuffer.SendAsync(message, cancellationToken);
}
private void Close()
{
_receiveBuffer.SetReceiverClosed();
_sendBuffer.SetSenderClosed();
}
private void ThrowIfDisposed()
{
if (_state >= WebSocketState.Closed) // or Aborted
{
throw new ObjectDisposedException(typeof(TestWebSocket).FullName);
}
}
private void ThrowIfOutputClosed()
{
if (State == WebSocketState.CloseSent)
{
throw new InvalidOperationException("Close already sent.");
}
}
private void ThrowIfInputClosed()
{
if (State == WebSocketState.CloseReceived)
{
throw new InvalidOperationException("Close already received.");
}
}
private void ValidateSegment(ArraySegment<byte> buffer)
{
if (buffer.Array == null)
{
throw new ArgumentNullException(nameof(buffer));
}
if (buffer.Offset < 0 || buffer.Offset > buffer.Array.Length)
{
throw new ArgumentOutOfRangeException(nameof(buffer.Offset), buffer.Offset, string.Empty);
}
if (buffer.Count < 0 || buffer.Count > buffer.Array.Length - buffer.Offset)
{
throw new ArgumentOutOfRangeException(nameof(buffer.Count), buffer.Count, string.Empty);
}
}
private TestWebSocket(string subProtocol, ReceiverSenderBuffer readBuffer, ReceiverSenderBuffer writeBuffer)
{
_state = WebSocketState.Open;
_subProtocol = subProtocol;
_receiveBuffer = readBuffer;
_sendBuffer = writeBuffer;
}
private class Message
{
public Message(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken token)
{
Buffer = buffer;
CloseStatus = null;
CloseStatusDescription = null;
EndOfMessage = endOfMessage;
MessageType = messageType;
}
public Message(WebSocketCloseStatus? closeStatus, string closeStatusDescription)
{
Buffer = new ArraySegment<byte>(new byte[0]);
CloseStatus = closeStatus;
CloseStatusDescription = closeStatusDescription;
MessageType = WebSocketMessageType.Close;
EndOfMessage = true;
}
public WebSocketCloseStatus? CloseStatus { get; set; }
public string CloseStatusDescription { get; set; }
public ArraySegment<byte> Buffer { get; set; }
public bool EndOfMessage { get; set; }
public WebSocketMessageType MessageType { get; set; }
}
private class ReceiverSenderBuffer
{
private bool _receiverClosed;
private bool _senderClosed;
private bool _disposed;
private SemaphoreSlim _sem;
private Queue<Message> _messageQueue;
public ReceiverSenderBuffer()
{
_sem = new SemaphoreSlim(0);
_messageQueue = new Queue<Message>();
}
public async virtual Task<Message> ReceiveAsync(CancellationToken cancellationToken)
{
if (_disposed)
{
ThrowNoReceive();
}
await _sem.WaitAsync(cancellationToken);
lock (_messageQueue)
{
if (_messageQueue.Count == 0)
{
_disposed = true;
_sem.Dispose();
ThrowNoReceive();
}
return _messageQueue.Dequeue();
}
}
public virtual Task SendAsync(Message message, CancellationToken cancellationToken)
{
lock (_messageQueue)
{
if (_senderClosed)
{
throw new ObjectDisposedException(typeof(TestWebSocket).FullName);
}
if (_receiverClosed)
{
throw new IOException("The remote end closed the connection.", new ObjectDisposedException(typeof(TestWebSocket).FullName));
}
// we return immediately so we need to copy the buffer since the sender can re-use it
var array = new byte[message.Buffer.Count];
Array.Copy(message.Buffer.Array, message.Buffer.Offset, array, 0, message.Buffer.Count);
message.Buffer = new ArraySegment<byte>(array);
_messageQueue.Enqueue(message);
_sem.Release();
return Task.FromResult(true);
}
}
public void SetReceiverClosed()
{
lock (_messageQueue)
{
if (!_receiverClosed)
{
_receiverClosed = true;
if (!_disposed)
{
_sem.Release();
}
}
}
}
public void SetSenderClosed()
{
lock (_messageQueue)
{
if (!_senderClosed)
{
_senderClosed = true;
if (!_disposed)
{
_sem.Release();
}
}
}
}
private void ThrowNoReceive()
{
if (_receiverClosed)
{
throw new ObjectDisposedException(typeof(TestWebSocket).FullName);
}
else // _senderClosed
{
throw new IOException("The remote end closed the connection.", new ObjectDisposedException(typeof(TestWebSocket).FullName));
}
}
}
}
}

View File

@ -0,0 +1,179 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Net.WebSockets;
using System.Security.Cryptography;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.Http.Features;
using Microsoft.AspNet.Http;
using Microsoft.AspNet.Http.Internal;
using Microsoft.Framework.Internal;
namespace Microsoft.AspNet.TestHost
{
public class WebSocketClient
{
private readonly Func<IFeatureCollection, Task> _next;
private readonly PathString _pathBase;
internal WebSocketClient([NotNull] Func<IFeatureCollection, Task> next, PathString pathBase)
{
_next = next;
// PathString.StartsWithSegments that we use below requires the base path to not end in a slash.
if (pathBase.HasValue && pathBase.Value.EndsWith("/"))
{
pathBase = new PathString(pathBase.Value.Substring(0, pathBase.Value.Length - 1));
}
_pathBase = pathBase;
SubProtocols = new List<string>();
}
public IList<string> SubProtocols
{
get;
private set;
}
public Action<HttpRequest> ConfigureRequest
{
get;
set;
}
public async Task<WebSocket> ConnectAsync(Uri uri, CancellationToken cancellationToken)
{
var state = new RequestState(uri, _pathBase, cancellationToken);
if (ConfigureRequest != null)
{
ConfigureRequest(state.HttpContext.Request);
}
// Async offload, don't let the test code block the caller.
var offload = Task.Factory.StartNew(async () =>
{
try
{
await _next(state.FeatureCollection);
state.PipelineComplete();
}
catch (Exception ex)
{
state.PipelineFailed(ex);
}
finally
{
state.Dispose();
}
});
return await state.WebSocketTask;
}
private class RequestState : IDisposable, IHttpWebSocketFeature
{
private TaskCompletionSource<WebSocket> _clientWebSocketTcs;
private WebSocket _serverWebSocket;
public IFeatureCollection FeatureCollection { get; private set; }
public HttpContext HttpContext { get; private set; }
public Task<WebSocket> WebSocketTask { get { return _clientWebSocketTcs.Task; } }
public RequestState(Uri uri, PathString pathBase, CancellationToken cancellationToken)
{
_clientWebSocketTcs = new TaskCompletionSource<WebSocket>();
// HttpContext
FeatureCollection = new FeatureCollection();
HttpContext = new DefaultHttpContext(FeatureCollection);
// Request
HttpContext.SetFeature<IHttpRequestFeature>(new RequestFeature());
var request = HttpContext.Request;
request.Protocol = "HTTP/1.1";
var scheme = uri.Scheme;
scheme = (scheme == "ws") ? "http" : scheme;
scheme = (scheme == "wss") ? "https" : scheme;
request.Scheme = scheme;
request.Method = "GET";
var fullPath = PathString.FromUriComponent(uri);
PathString remainder;
if (fullPath.StartsWithSegments(pathBase, out remainder))
{
request.PathBase = pathBase;
request.Path = remainder;
}
else
{
request.PathBase = PathString.Empty;
request.Path = fullPath;
}
request.QueryString = QueryString.FromUriComponent(uri);
request.Headers.Add("Connection", new string[] { "Upgrade" });
request.Headers.Add("Upgrade", new string[] { "websocket" });
request.Headers.Add("Sec-WebSocket-Version", new string[] { "13" });
request.Headers.Add("Sec-WebSocket-Key", new string[] { CreateRequestKey() });
request.Body = Stream.Null;
// Response
HttpContext.SetFeature<IHttpResponseFeature>(new ResponseFeature());
var response = HttpContext.Response;
response.Body = Stream.Null;
response.StatusCode = 200;
// WebSocket
HttpContext.SetFeature<IHttpWebSocketFeature>(this);
}
public void PipelineComplete()
{
PipelineFailed(new InvalidOperationException("Incomplete handshake, status code: " + HttpContext.Response.StatusCode));
}
public void PipelineFailed(Exception ex)
{
_clientWebSocketTcs.TrySetException(new InvalidOperationException("The websocket was not accepted.", ex));
}
public void Dispose()
{
if (_serverWebSocket != null)
{
_serverWebSocket.Dispose();
}
}
private string CreateRequestKey()
{
byte[] data = new byte[16];
var rng = RandomNumberGenerator.Create();
rng.GetBytes(data);
return Convert.ToBase64String(data);
}
bool IHttpWebSocketFeature.IsWebSocketRequest
{
get
{
return true;
}
}
Task<WebSocket> IHttpWebSocketFeature.AcceptAsync(WebSocketAcceptContext context)
{
HttpContext.Response.StatusCode = 101; // Switching Protocols
var websockets = TestWebSocket.CreatePair(context.SubProtocol);
_clientWebSocketTcs.SetResult(websockets.Item1);
_serverWebSocket = websockets.Item2;
return Task.FromResult<WebSocket>(_serverWebSocket);
}
}
}
}

View File

@ -2,7 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.Builder;
using Microsoft.AspNet.Http;
@ -111,5 +115,135 @@ namespace Microsoft.AspNet.TestHost
// Assert
Assert.Equal("Hello world POST Response", await response.Content.ReadAsStringAsync());
}
[Fact]
public async Task WebSocketWorks()
{
// Arrange
RequestDelegate appDelegate = async ctx =>
{
if (ctx.WebSockets.IsWebSocketRequest)
{
var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
var receiveArray = new byte[1024];
while (true)
{
var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment<byte>(receiveArray), CancellationToken.None);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
await websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None);
break;
}
else
{
var sendBuffer = new System.ArraySegment<byte>(receiveArray, 0, receiveResult.Count);
await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None);
}
}
}
};
var server = TestServer.Create(app =>
{
app.Run(appDelegate);
});
// Act
var client = server.CreateWebSocketClient();
var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
var hello = Encoding.UTF8.GetBytes("hello");
await clientSocket.SendAsync(new System.ArraySegment<byte>(hello), WebSocketMessageType.Text, true, CancellationToken.None);
var world = Encoding.UTF8.GetBytes("world!");
await clientSocket.SendAsync(new System.ArraySegment<byte>(world), WebSocketMessageType.Binary, true, CancellationToken.None);
await clientSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None);
// Assert
Assert.Equal(WebSocketState.CloseSent, clientSocket.State);
var buffer = new byte[1024];
var result = await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None);
Assert.Equal(hello.Length, result.Count);
Assert.True(hello.SequenceEqual(buffer.Take(hello.Length)));
Assert.Equal(WebSocketMessageType.Text, result.MessageType);
result = await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None);
Assert.Equal(world.Length, result.Count);
Assert.True(world.SequenceEqual(buffer.Take(world.Length)));
Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
result = await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None);
Assert.Equal(WebSocketMessageType.Close, result.MessageType);
Assert.Equal(WebSocketState.Closed, clientSocket.State);
clientSocket.Dispose();
}
[Fact]
public async Task WebSocketDisposalThrowsOnPeer()
{
// Arrange
RequestDelegate appDelegate = async ctx =>
{
if (ctx.WebSockets.IsWebSocketRequest)
{
var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
websocket.Dispose();
}
};
var server = TestServer.Create(app =>
{
app.Run(appDelegate);
});
// Act
var client = server.CreateWebSocketClient();
var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
var buffer = new byte[1024];
await Assert.ThrowsAsync<IOException>(async () => await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None));
clientSocket.Dispose();
}
[Fact]
public async Task WebSocketTinyReceiveGeneratesEndOfMessage()
{
// Arrange
RequestDelegate appDelegate = async ctx =>
{
if (ctx.WebSockets.IsWebSocketRequest)
{
var websocket = await ctx.WebSockets.AcceptWebSocketAsync();
var receiveArray = new byte[1024];
while (true)
{
var receiveResult = await websocket.ReceiveAsync(new System.ArraySegment<byte>(receiveArray), CancellationToken.None);
var sendBuffer = new System.ArraySegment<byte>(receiveArray, 0, receiveResult.Count);
await websocket.SendAsync(sendBuffer, receiveResult.MessageType, receiveResult.EndOfMessage, CancellationToken.None);
}
}
};
var server = TestServer.Create(app =>
{
app.Run(appDelegate);
});
// Act
var client = server.CreateWebSocketClient();
var clientSocket = await client.ConnectAsync(new System.Uri("http://localhost"), CancellationToken.None);
var hello = Encoding.UTF8.GetBytes("hello");
await clientSocket.SendAsync(new System.ArraySegment<byte>(hello), WebSocketMessageType.Text, true, CancellationToken.None);
// Assert
var buffer = new byte[1];
for (int i = 0; i < hello.Length; i++)
{
bool last = i == (hello.Length - 1);
var result = await clientSocket.ReceiveAsync(new System.ArraySegment<byte>(buffer), CancellationToken.None);
Assert.Equal(buffer.Length, result.Count);
Assert.Equal(buffer[0], hello[i]);
Assert.Equal(last, result.EndOfMessage);
}
clientSocket.Dispose();
}
}
}