port dotnet/corefx#11348 (#114)

also adds some tests and extra features to the EchoApp test sample
This commit is contained in:
Andrew Stanton-Nurse 2016-09-02 14:02:06 -07:00 committed by GitHub
parent c51aec5292
commit b996ee39a4
7 changed files with 123 additions and 42 deletions

View File

@ -24,7 +24,7 @@ namespace EchoApp
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline. // This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
public void Configure(IApplicationBuilder app, IHostingEnvironment env, ILoggerFactory loggerFactory) public void Configure(IApplicationBuilder app, IHostingEnvironment env, ILoggerFactory loggerFactory)
{ {
loggerFactory.AddConsole(); loggerFactory.AddConsole(LogLevel.Debug);
if (env.IsDevelopment()) if (env.IsDevelopment())
{ {
@ -38,7 +38,7 @@ namespace EchoApp
if (context.WebSockets.IsWebSocketRequest) if (context.WebSockets.IsWebSocketRequest)
{ {
var webSocket = await context.WebSockets.AcceptWebSocketAsync(); var webSocket = await context.WebSockets.AcceptWebSocketAsync();
await Echo(webSocket); await Echo(context, webSocket, loggerFactory.CreateLogger("Echo"));
} }
else else
{ {
@ -49,27 +49,57 @@ namespace EchoApp
app.UseFileServer(); app.UseFileServer();
} }
private async Task Echo(WebSocket webSocket) private async Task Echo(HttpContext context, WebSocket webSocket, ILogger logger)
{ {
var buffer = new byte[1024 * 4]; var buffer = new byte[1024 * 4];
var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None); var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
LogFrame(logger, result, buffer);
while (!result.CloseStatus.HasValue) while (!result.CloseStatus.HasValue)
{ {
// If the client send "ServerClose", then they want a server-originated close to occur // If the client send "ServerClose", then they want a server-originated close to occur
if(result.MessageType == WebSocketMessageType.Text) string content = "<<binary>>";
if (result.MessageType == WebSocketMessageType.Text)
{ {
var str = Encoding.UTF8.GetString(buffer, 0, result.Count); content = Encoding.UTF8.GetString(buffer, 0, result.Count);
if(str.Equals("ServerClose")) if (content.Equals("ServerClose"))
{ {
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing from Server", CancellationToken.None); await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing from Server", CancellationToken.None);
logger.LogDebug($"Sent Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server");
return; return;
} }
else if (content.Equals("ServerAbort"))
{
context.Abort();
}
} }
await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None); await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None);
logger.LogDebug($"Sent Frame {result.MessageType}: Len={result.Count}, Fin={result.EndOfMessage}: {content}");
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None); result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
LogFrame(logger, result, buffer);
} }
await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None);
} }
private void LogFrame(ILogger logger, WebSocketReceiveResult frame, byte[] buffer)
{
var close = frame.CloseStatus != null;
string message;
if (close)
{
message = $"Close: {frame.CloseStatus.Value} {frame.CloseStatusDescription}";
}
else
{
string content = "<<binary>>";
if (frame.MessageType == WebSocketMessageType.Text)
{
content = Encoding.UTF8.GetString(buffer, 0, frame.Count);
}
message = $"{frame.MessageType}: Len={frame.Count}, Fin={frame.EndOfMessage}: {content}";
}
logger.LogDebug("Received Frame " + message);
}
} }
} }

View File

@ -25,7 +25,7 @@
<button id="closeButton" disabled>Close Socket</button> <button id="closeButton" disabled>Close Socket</button>
</div> </div>
<p>Note: When connected to the default server (i.e. the server in the address bar ;)), the message "ServerClose" will cause the server to close the connection</p> <p>Note: When connected to the default server (i.e. the server in the address bar ;)), the message "ServerClose" will cause the server to close the connection. Similarly, the message "ServerAbort" will cause the server to forcibly terminate the connection without a closing handshake</p>
<h2>Communication Log</h2> <h2>Communication Log</h2>
<table style="width: 800px"> <table style="width: 800px">

View File

@ -615,12 +615,7 @@ namespace System.Net.WebSockets
// Make sure we have the first two bytes, which includes the start of the payload length. // Make sure we have the first two bytes, which includes the start of the payload length.
if (_receiveBufferCount < 2) if (_receiveBufferCount < 2)
{ {
await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: false).ConfigureAwait(false); await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false);
if (_receiveBufferCount < 2)
{
// The connection closed; nothing more to read.
return new WebSocketReceiveResult(0, WebSocketMessageType.Text, true);
}
} }
// Then make sure we have the full header based on the payload length. // Then make sure we have the full header based on the payload length.

View File

@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.WebSockets.Test
{ {
private bool _disposed; private bool _disposed;
private bool _aborted; private bool _aborted;
private bool _terminated;
private Exception _abortException; private Exception _abortException;
private ConcurrentQueue<byte[]> _bufferedData; private ConcurrentQueue<byte[]> _bufferedData;
private ArraySegment<byte> _topBuffer; private ArraySegment<byte> _topBuffer;
@ -71,6 +72,14 @@ namespace Microsoft.AspNetCore.WebSockets.Test
#endregion NotSupported #endregion NotSupported
/// <summary>
/// Ends the stream, meaning all future reads will return '0'.
/// </summary>
public void End()
{
_terminated = true;
}
public override void Flush() public override void Flush()
{ {
CheckDisposed(); CheckDisposed();
@ -95,6 +104,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override int Read(byte[] buffer, int offset, int count) public override int Read(byte[] buffer, int offset, int count)
{ {
if(_terminated)
{
return 0;
}
VerifyBuffer(buffer, offset, count, allowEmpty: false); VerifyBuffer(buffer, offset, count, allowEmpty: false);
_readLock.Wait(); _readLock.Wait();
try try
@ -154,6 +168,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public async override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) public async override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{ {
if(_terminated)
{
return 0;
}
VerifyBuffer(buffer, offset, count, allowEmpty: false); VerifyBuffer(buffer, offset, count, allowEmpty: false);
CancellationTokenRegistration registration = cancellationToken.Register(Abort); CancellationTokenRegistration registration = cancellationToken.Register(Abort);
await _readLock.WaitAsync(cancellationToken); await _readLock.WaitAsync(cancellationToken);

View File

@ -10,30 +10,31 @@ namespace Microsoft.AspNetCore.WebSockets.Test
// A duplex wrapper around a read and write stream. // A duplex wrapper around a read and write stream.
public class DuplexStream : Stream public class DuplexStream : Stream
{ {
private readonly Stream _readStream; public BufferStream ReadStream { get; }
private readonly Stream _writeStream; public BufferStream WriteStream { get; }
public DuplexStream() public DuplexStream()
: this (new BufferStream(), new BufferStream()) : this (new BufferStream(), new BufferStream())
{ {
} }
public DuplexStream(Stream readStream, Stream writeStream) public DuplexStream(BufferStream readStream, BufferStream writeStream)
{ {
_readStream = readStream; ReadStream = readStream;
_writeStream = writeStream; WriteStream = writeStream;
} }
public DuplexStream CreateReverseDuplexStream() public DuplexStream CreateReverseDuplexStream()
{ {
return new DuplexStream(_writeStream, _readStream); return new DuplexStream(WriteStream, ReadStream);
} }
#region Properties #region Properties
public override bool CanRead public override bool CanRead
{ {
get { return _readStream.CanRead; } get { return ReadStream.CanRead; }
} }
public override bool CanSeek public override bool CanSeek
@ -43,12 +44,12 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override bool CanTimeout public override bool CanTimeout
{ {
get { return _readStream.CanTimeout || _writeStream.CanTimeout; } get { return ReadStream.CanTimeout || WriteStream.CanTimeout; }
} }
public override bool CanWrite public override bool CanWrite
{ {
get { return _writeStream.CanWrite; } get { return WriteStream.CanWrite; }
} }
public override long Length public override long Length
@ -64,14 +65,14 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override int ReadTimeout public override int ReadTimeout
{ {
get { return _readStream.ReadTimeout; } get { return ReadStream.ReadTimeout; }
set { _readStream.ReadTimeout = value; } set { ReadStream.ReadTimeout = value; }
} }
public override int WriteTimeout public override int WriteTimeout
{ {
get { return _writeStream.WriteTimeout; } get { return WriteStream.WriteTimeout; }
set { _writeStream.WriteTimeout = value; } set { WriteStream.WriteTimeout = value; }
} }
#endregion Properties #endregion Properties
@ -90,33 +91,33 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override int Read(byte[] buffer, int offset, int count) public override int Read(byte[] buffer, int offset, int count)
{ {
return _readStream.Read(buffer, offset, count); return ReadStream.Read(buffer, offset, count);
} }
#if !NETCOREAPP1_0 #if !NETCOREAPP1_0
public override int ReadByte() public override int ReadByte()
{ {
return _readStream.ReadByte(); return ReadStream.ReadByte();
} }
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{ {
return _readStream.BeginRead(buffer, offset, count, callback, state); return ReadStream.BeginRead(buffer, offset, count, callback, state);
} }
public override int EndRead(IAsyncResult asyncResult) public override int EndRead(IAsyncResult asyncResult)
{ {
return _readStream.EndRead(asyncResult); return ReadStream.EndRead(asyncResult);
} }
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{ {
return _readStream.ReadAsync(buffer, offset, count, cancellationToken); return ReadStream.ReadAsync(buffer, offset, count, cancellationToken);
} }
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{ {
return _readStream.CopyToAsync(destination, bufferSize, cancellationToken); return ReadStream.CopyToAsync(destination, bufferSize, cancellationToken);
} }
#endif #endif
@ -126,39 +127,39 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override void Write(byte[] buffer, int offset, int count) public override void Write(byte[] buffer, int offset, int count)
{ {
_writeStream.Write(buffer, offset, count); WriteStream.Write(buffer, offset, count);
} }
#if !NETCOREAPP1_0 #if !NETCOREAPP1_0
public override void WriteByte(byte value) public override void WriteByte(byte value)
{ {
_writeStream.WriteByte(value); WriteStream.WriteByte(value);
} }
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{ {
return _writeStream.BeginWrite(buffer, offset, count, callback, state); return WriteStream.BeginWrite(buffer, offset, count, callback, state);
} }
public override void EndWrite(IAsyncResult asyncResult) public override void EndWrite(IAsyncResult asyncResult)
{ {
_writeStream.EndWrite(asyncResult); WriteStream.EndWrite(asyncResult);
} }
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{ {
return _writeStream.WriteAsync(buffer, offset, count, cancellationToken); return WriteStream.WriteAsync(buffer, offset, count, cancellationToken);
} }
public override Task FlushAsync(CancellationToken cancellationToken) public override Task FlushAsync(CancellationToken cancellationToken)
{ {
return _writeStream.FlushAsync(cancellationToken); return WriteStream.FlushAsync(cancellationToken);
} }
#endif #endif
public override void Flush() public override void Flush()
{ {
_writeStream.Flush(); WriteStream.Flush();
} }
#endregion Write #endregion Write
@ -167,8 +168,8 @@ namespace Microsoft.AspNetCore.WebSockets.Test
{ {
if (disposing) if (disposing)
{ {
_readStream.Dispose(); ReadStream.Dispose();
_writeStream.Dispose(); WriteStream.Dispose();
} }
base.Dispose(disposing); base.Dispose(disposing);
} }

View File

@ -73,5 +73,35 @@ namespace Microsoft.AspNetCore.WebSockets.Test
Assert.Equal(WebSocketMessageType.Binary, result.MessageType); Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
Assert.Equal(sendBuffer, receiveBuffer.Take(result.Count).ToArray()); Assert.Equal(sendBuffer, receiveBuffer.Take(result.Count).ToArray());
} }
[Fact]
public async Task ThrowsWhenUnderlyingStreamClosed()
{
var pair = WebSocketPair.Create();
var sendBuffer = new byte[] { 0xde, 0xad, 0xbe, 0xef };
await pair.ServerSocket.SendAsync(new ArraySegment<byte>(sendBuffer), WebSocketMessageType.Binary, endOfMessage: true, cancellationToken: CancellationToken.None);
var receiveBuffer = new byte[32];
var result = await pair.ClientSocket.ReceiveAsync(new ArraySegment<byte>(receiveBuffer), CancellationToken.None);
Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
// Close the client socket's read end
pair.ClientStream.ReadStream.End();
// Assert.Throws doesn't support async :(
try
{
await pair.ClientSocket.ReceiveAsync(new ArraySegment<byte>(receiveBuffer), CancellationToken.None);
// The exception should prevent this line from running
Assert.False(true, "Expected an exception to be thrown!");
}
catch (WebSocketException ex)
{
Assert.Equal(WebSocketError.ConnectionClosedPrematurely, ex.WebSocketErrorCode);
}
}
} }
} }

View File

@ -8,9 +8,13 @@ namespace Microsoft.AspNetCore.WebSockets.Test
{ {
public WebSocket ClientSocket { get; } public WebSocket ClientSocket { get; }
public WebSocket ServerSocket { get; } public WebSocket ServerSocket { get; }
public DuplexStream ServerStream { get; }
public DuplexStream ClientStream { get; }
public WebSocketPair(WebSocket clientSocket, WebSocket serverSocket) public WebSocketPair(DuplexStream serverStream, DuplexStream clientStream, WebSocket clientSocket, WebSocket serverSocket)
{ {
ClientStream = clientStream;
ServerStream = serverStream;
ClientSocket = clientSocket; ClientSocket = clientSocket;
ServerSocket = serverSocket; ServerSocket = serverSocket;
} }
@ -22,6 +26,8 @@ namespace Microsoft.AspNetCore.WebSockets.Test
var clientStream = serverStream.CreateReverseDuplexStream(); var clientStream = serverStream.CreateReverseDuplexStream();
return new WebSocketPair( return new WebSocketPair(
serverStream,
clientStream,
clientSocket: WebSocketFactory.CreateClientWebSocket(clientStream, null, TimeSpan.FromMinutes(2), 1024), clientSocket: WebSocketFactory.CreateClientWebSocket(clientStream, null, TimeSpan.FromMinutes(2), 1024),
serverSocket: WebSocketFactory.CreateServerWebSocket(serverStream, null, TimeSpan.FromMinutes(2), 1024)); serverSocket: WebSocketFactory.CreateServerWebSocket(serverStream, null, TimeSpan.FromMinutes(2), 1024));
} }