port dotnet/corefx#11348 (#114)

also adds some tests and extra features to the EchoApp test sample
This commit is contained in:
Andrew Stanton-Nurse 2016-09-02 14:02:06 -07:00 committed by GitHub
parent c51aec5292
commit b996ee39a4
7 changed files with 123 additions and 42 deletions

View File

@ -24,7 +24,7 @@ namespace EchoApp
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
public void Configure(IApplicationBuilder app, IHostingEnvironment env, ILoggerFactory loggerFactory)
{
loggerFactory.AddConsole();
loggerFactory.AddConsole(LogLevel.Debug);
if (env.IsDevelopment())
{
@ -38,7 +38,7 @@ namespace EchoApp
if (context.WebSockets.IsWebSocketRequest)
{
var webSocket = await context.WebSockets.AcceptWebSocketAsync();
await Echo(webSocket);
await Echo(context, webSocket, loggerFactory.CreateLogger("Echo"));
}
else
{
@ -49,27 +49,57 @@ namespace EchoApp
app.UseFileServer();
}
private async Task Echo(WebSocket webSocket)
private async Task Echo(HttpContext context, WebSocket webSocket, ILogger logger)
{
var buffer = new byte[1024 * 4];
var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
LogFrame(logger, result, buffer);
while (!result.CloseStatus.HasValue)
{
// If the client send "ServerClose", then they want a server-originated close to occur
if(result.MessageType == WebSocketMessageType.Text)
string content = "<<binary>>";
if (result.MessageType == WebSocketMessageType.Text)
{
var str = Encoding.UTF8.GetString(buffer, 0, result.Count);
if(str.Equals("ServerClose"))
content = Encoding.UTF8.GetString(buffer, 0, result.Count);
if (content.Equals("ServerClose"))
{
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing from Server", CancellationToken.None);
logger.LogDebug($"Sent Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server");
return;
}
else if (content.Equals("ServerAbort"))
{
context.Abort();
}
}
await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None);
logger.LogDebug($"Sent Frame {result.MessageType}: Len={result.Count}, Fin={result.EndOfMessage}: {content}");
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
LogFrame(logger, result, buffer);
}
await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None);
}
private void LogFrame(ILogger logger, WebSocketReceiveResult frame, byte[] buffer)
{
var close = frame.CloseStatus != null;
string message;
if (close)
{
message = $"Close: {frame.CloseStatus.Value} {frame.CloseStatusDescription}";
}
else
{
string content = "<<binary>>";
if (frame.MessageType == WebSocketMessageType.Text)
{
content = Encoding.UTF8.GetString(buffer, 0, frame.Count);
}
message = $"{frame.MessageType}: Len={frame.Count}, Fin={frame.EndOfMessage}: {content}";
}
logger.LogDebug("Received Frame " + message);
}
}
}

View File

@ -25,7 +25,7 @@
<button id="closeButton" disabled>Close Socket</button>
</div>
<p>Note: When connected to the default server (i.e. the server in the address bar ;)), the message "ServerClose" will cause the server to close the connection</p>
<p>Note: When connected to the default server (i.e. the server in the address bar ;)), the message "ServerClose" will cause the server to close the connection. Similarly, the message "ServerAbort" will cause the server to forcibly terminate the connection without a closing handshake</p>
<h2>Communication Log</h2>
<table style="width: 800px">

View File

@ -615,12 +615,7 @@ namespace System.Net.WebSockets
// Make sure we have the first two bytes, which includes the start of the payload length.
if (_receiveBufferCount < 2)
{
await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: false).ConfigureAwait(false);
if (_receiveBufferCount < 2)
{
// The connection closed; nothing more to read.
return new WebSocketReceiveResult(0, WebSocketMessageType.Text, true);
}
await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false);
}
// Then make sure we have the full header based on the payload length.

View File

@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.WebSockets.Test
{
private bool _disposed;
private bool _aborted;
private bool _terminated;
private Exception _abortException;
private ConcurrentQueue<byte[]> _bufferedData;
private ArraySegment<byte> _topBuffer;
@ -71,6 +72,14 @@ namespace Microsoft.AspNetCore.WebSockets.Test
#endregion NotSupported
/// <summary>
/// Ends the stream, meaning all future reads will return '0'.
/// </summary>
public void End()
{
_terminated = true;
}
public override void Flush()
{
CheckDisposed();
@ -95,6 +104,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override int Read(byte[] buffer, int offset, int count)
{
if(_terminated)
{
return 0;
}
VerifyBuffer(buffer, offset, count, allowEmpty: false);
_readLock.Wait();
try
@ -154,6 +168,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public async override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
if(_terminated)
{
return 0;
}
VerifyBuffer(buffer, offset, count, allowEmpty: false);
CancellationTokenRegistration registration = cancellationToken.Register(Abort);
await _readLock.WaitAsync(cancellationToken);

View File

@ -10,30 +10,31 @@ namespace Microsoft.AspNetCore.WebSockets.Test
// A duplex wrapper around a read and write stream.
public class DuplexStream : Stream
{
private readonly Stream _readStream;
private readonly Stream _writeStream;
public BufferStream ReadStream { get; }
public BufferStream WriteStream { get; }
public DuplexStream()
: this (new BufferStream(), new BufferStream())
{
}
public DuplexStream(Stream readStream, Stream writeStream)
public DuplexStream(BufferStream readStream, BufferStream writeStream)
{
_readStream = readStream;
_writeStream = writeStream;
ReadStream = readStream;
WriteStream = writeStream;
}
public DuplexStream CreateReverseDuplexStream()
{
return new DuplexStream(_writeStream, _readStream);
return new DuplexStream(WriteStream, ReadStream);
}
#region Properties
public override bool CanRead
{
get { return _readStream.CanRead; }
get { return ReadStream.CanRead; }
}
public override bool CanSeek
@ -43,12 +44,12 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override bool CanTimeout
{
get { return _readStream.CanTimeout || _writeStream.CanTimeout; }
get { return ReadStream.CanTimeout || WriteStream.CanTimeout; }
}
public override bool CanWrite
{
get { return _writeStream.CanWrite; }
get { return WriteStream.CanWrite; }
}
public override long Length
@ -64,14 +65,14 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override int ReadTimeout
{
get { return _readStream.ReadTimeout; }
set { _readStream.ReadTimeout = value; }
get { return ReadStream.ReadTimeout; }
set { ReadStream.ReadTimeout = value; }
}
public override int WriteTimeout
{
get { return _writeStream.WriteTimeout; }
set { _writeStream.WriteTimeout = value; }
get { return WriteStream.WriteTimeout; }
set { WriteStream.WriteTimeout = value; }
}
#endregion Properties
@ -90,33 +91,33 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override int Read(byte[] buffer, int offset, int count)
{
return _readStream.Read(buffer, offset, count);
return ReadStream.Read(buffer, offset, count);
}
#if !NETCOREAPP1_0
public override int ReadByte()
{
return _readStream.ReadByte();
return ReadStream.ReadByte();
}
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
return _readStream.BeginRead(buffer, offset, count, callback, state);
return ReadStream.BeginRead(buffer, offset, count, callback, state);
}
public override int EndRead(IAsyncResult asyncResult)
{
return _readStream.EndRead(asyncResult);
return ReadStream.EndRead(asyncResult);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _readStream.ReadAsync(buffer, offset, count, cancellationToken);
return ReadStream.ReadAsync(buffer, offset, count, cancellationToken);
}
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
return _readStream.CopyToAsync(destination, bufferSize, cancellationToken);
return ReadStream.CopyToAsync(destination, bufferSize, cancellationToken);
}
#endif
@ -126,39 +127,39 @@ namespace Microsoft.AspNetCore.WebSockets.Test
public override void Write(byte[] buffer, int offset, int count)
{
_writeStream.Write(buffer, offset, count);
WriteStream.Write(buffer, offset, count);
}
#if !NETCOREAPP1_0
public override void WriteByte(byte value)
{
_writeStream.WriteByte(value);
WriteStream.WriteByte(value);
}
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
return _writeStream.BeginWrite(buffer, offset, count, callback, state);
return WriteStream.BeginWrite(buffer, offset, count, callback, state);
}
public override void EndWrite(IAsyncResult asyncResult)
{
_writeStream.EndWrite(asyncResult);
WriteStream.EndWrite(asyncResult);
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _writeStream.WriteAsync(buffer, offset, count, cancellationToken);
return WriteStream.WriteAsync(buffer, offset, count, cancellationToken);
}
public override Task FlushAsync(CancellationToken cancellationToken)
{
return _writeStream.FlushAsync(cancellationToken);
return WriteStream.FlushAsync(cancellationToken);
}
#endif
public override void Flush()
{
_writeStream.Flush();
WriteStream.Flush();
}
#endregion Write
@ -167,8 +168,8 @@ namespace Microsoft.AspNetCore.WebSockets.Test
{
if (disposing)
{
_readStream.Dispose();
_writeStream.Dispose();
ReadStream.Dispose();
WriteStream.Dispose();
}
base.Dispose(disposing);
}

View File

@ -73,5 +73,35 @@ namespace Microsoft.AspNetCore.WebSockets.Test
Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
Assert.Equal(sendBuffer, receiveBuffer.Take(result.Count).ToArray());
}
[Fact]
public async Task ThrowsWhenUnderlyingStreamClosed()
{
var pair = WebSocketPair.Create();
var sendBuffer = new byte[] { 0xde, 0xad, 0xbe, 0xef };
await pair.ServerSocket.SendAsync(new ArraySegment<byte>(sendBuffer), WebSocketMessageType.Binary, endOfMessage: true, cancellationToken: CancellationToken.None);
var receiveBuffer = new byte[32];
var result = await pair.ClientSocket.ReceiveAsync(new ArraySegment<byte>(receiveBuffer), CancellationToken.None);
Assert.Equal(WebSocketMessageType.Binary, result.MessageType);
// Close the client socket's read end
pair.ClientStream.ReadStream.End();
// Assert.Throws doesn't support async :(
try
{
await pair.ClientSocket.ReceiveAsync(new ArraySegment<byte>(receiveBuffer), CancellationToken.None);
// The exception should prevent this line from running
Assert.False(true, "Expected an exception to be thrown!");
}
catch (WebSocketException ex)
{
Assert.Equal(WebSocketError.ConnectionClosedPrematurely, ex.WebSocketErrorCode);
}
}
}
}

View File

@ -8,9 +8,13 @@ namespace Microsoft.AspNetCore.WebSockets.Test
{
public WebSocket ClientSocket { get; }
public WebSocket ServerSocket { get; }
public DuplexStream ServerStream { get; }
public DuplexStream ClientStream { get; }
public WebSocketPair(WebSocket clientSocket, WebSocket serverSocket)
public WebSocketPair(DuplexStream serverStream, DuplexStream clientStream, WebSocket clientSocket, WebSocket serverSocket)
{
ClientStream = clientStream;
ServerStream = serverStream;
ClientSocket = clientSocket;
ServerSocket = serverSocket;
}
@ -22,6 +26,8 @@ namespace Microsoft.AspNetCore.WebSockets.Test
var clientStream = serverStream.CreateReverseDuplexStream();
return new WebSocketPair(
serverStream,
clientStream,
clientSocket: WebSocketFactory.CreateClientWebSocket(clientStream, null, TimeSpan.FromMinutes(2), 1024),
serverSocket: WebSocketFactory.CreateServerWebSocket(serverStream, null, TimeSpan.FromMinutes(2), 1024));
}