diff --git a/samples/EchoApp/Startup.cs b/samples/EchoApp/Startup.cs index ee3fb0ee91..729716693c 100644 --- a/samples/EchoApp/Startup.cs +++ b/samples/EchoApp/Startup.cs @@ -24,7 +24,7 @@ namespace EchoApp // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. public void Configure(IApplicationBuilder app, IHostingEnvironment env, ILoggerFactory loggerFactory) { - loggerFactory.AddConsole(); + loggerFactory.AddConsole(LogLevel.Debug); if (env.IsDevelopment()) { @@ -38,7 +38,7 @@ namespace EchoApp if (context.WebSockets.IsWebSocketRequest) { var webSocket = await context.WebSockets.AcceptWebSocketAsync(); - await Echo(webSocket); + await Echo(context, webSocket, loggerFactory.CreateLogger("Echo")); } else { @@ -49,27 +49,57 @@ namespace EchoApp app.UseFileServer(); } - private async Task Echo(WebSocket webSocket) + private async Task Echo(HttpContext context, WebSocket webSocket, ILogger logger) { var buffer = new byte[1024 * 4]; var result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + LogFrame(logger, result, buffer); while (!result.CloseStatus.HasValue) { // If the client send "ServerClose", then they want a server-originated close to occur - if(result.MessageType == WebSocketMessageType.Text) + string content = "<>"; + if (result.MessageType == WebSocketMessageType.Text) { - var str = Encoding.UTF8.GetString(buffer, 0, result.Count); - if(str.Equals("ServerClose")) + content = Encoding.UTF8.GetString(buffer, 0, result.Count); + if (content.Equals("ServerClose")) { await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing from Server", CancellationToken.None); + logger.LogDebug($"Sent Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server"); return; } + else if (content.Equals("ServerAbort")) + { + context.Abort(); + } } await webSocket.SendAsync(new ArraySegment(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None); + logger.LogDebug($"Sent Frame {result.MessageType}: Len={result.Count}, Fin={result.EndOfMessage}: {content}"); + result = await webSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + LogFrame(logger, result, buffer); } await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); } + + private void LogFrame(ILogger logger, WebSocketReceiveResult frame, byte[] buffer) + { + var close = frame.CloseStatus != null; + string message; + if (close) + { + message = $"Close: {frame.CloseStatus.Value} {frame.CloseStatusDescription}"; + } + else + { + string content = "<>"; + if (frame.MessageType == WebSocketMessageType.Text) + { + content = Encoding.UTF8.GetString(buffer, 0, frame.Count); + } + message = $"{frame.MessageType}: Len={frame.Count}, Fin={frame.EndOfMessage}: {content}"; + } + logger.LogDebug("Received Frame " + message); + } } } diff --git a/samples/EchoApp/wwwroot/index.html b/samples/EchoApp/wwwroot/index.html index d80fd3b907..1663600a5e 100644 --- a/samples/EchoApp/wwwroot/index.html +++ b/samples/EchoApp/wwwroot/index.html @@ -25,7 +25,7 @@ -

Note: When connected to the default server (i.e. the server in the address bar ;)), the message "ServerClose" will cause the server to close the connection

+

Note: When connected to the default server (i.e. the server in the address bar ;)), the message "ServerClose" will cause the server to close the connection. Similarly, the message "ServerAbort" will cause the server to forcibly terminate the connection without a closing handshake

Communication Log

diff --git a/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs index 322997f1d1..8253aebde5 100644 --- a/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -615,12 +615,7 @@ namespace System.Net.WebSockets // Make sure we have the first two bytes, which includes the start of the payload length. if (_receiveBufferCount < 2) { - await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: false).ConfigureAwait(false); - if (_receiveBufferCount < 2) - { - // The connection closed; nothing more to read. - return new WebSocketReceiveResult(0, WebSocketMessageType.Text, true); - } + await EnsureBufferContainsAsync(2, cancellationToken, throwOnPrematureClosure: true).ConfigureAwait(false); } // Then make sure we have the full header based on the payload length. diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/BufferStream.cs b/test/Microsoft.AspNetCore.WebSockets.Test/BufferStream.cs index 61a6699034..acbaa723cc 100644 --- a/test/Microsoft.AspNetCore.WebSockets.Test/BufferStream.cs +++ b/test/Microsoft.AspNetCore.WebSockets.Test/BufferStream.cs @@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.WebSockets.Test { private bool _disposed; private bool _aborted; + private bool _terminated; private Exception _abortException; private ConcurrentQueue _bufferedData; private ArraySegment _topBuffer; @@ -71,6 +72,14 @@ namespace Microsoft.AspNetCore.WebSockets.Test #endregion NotSupported + /// + /// Ends the stream, meaning all future reads will return '0'. + /// + public void End() + { + _terminated = true; + } + public override void Flush() { CheckDisposed(); @@ -95,6 +104,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test public override int Read(byte[] buffer, int offset, int count) { + if(_terminated) + { + return 0; + } + VerifyBuffer(buffer, offset, count, allowEmpty: false); _readLock.Wait(); try @@ -154,6 +168,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { + if(_terminated) + { + return 0; + } + VerifyBuffer(buffer, offset, count, allowEmpty: false); CancellationTokenRegistration registration = cancellationToken.Register(Abort); await _readLock.WaitAsync(cancellationToken); diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/DuplexStream.cs b/test/Microsoft.AspNetCore.WebSockets.Test/DuplexStream.cs index 528b20466e..0c3c4e0877 100644 --- a/test/Microsoft.AspNetCore.WebSockets.Test/DuplexStream.cs +++ b/test/Microsoft.AspNetCore.WebSockets.Test/DuplexStream.cs @@ -10,30 +10,31 @@ namespace Microsoft.AspNetCore.WebSockets.Test // A duplex wrapper around a read and write stream. public class DuplexStream : Stream { - private readonly Stream _readStream; - private readonly Stream _writeStream; + public BufferStream ReadStream { get; } + public BufferStream WriteStream { get; } public DuplexStream() : this (new BufferStream(), new BufferStream()) { } - public DuplexStream(Stream readStream, Stream writeStream) + public DuplexStream(BufferStream readStream, BufferStream writeStream) { - _readStream = readStream; - _writeStream = writeStream; + ReadStream = readStream; + WriteStream = writeStream; } public DuplexStream CreateReverseDuplexStream() { - return new DuplexStream(_writeStream, _readStream); + return new DuplexStream(WriteStream, ReadStream); } + #region Properties public override bool CanRead { - get { return _readStream.CanRead; } + get { return ReadStream.CanRead; } } public override bool CanSeek @@ -43,12 +44,12 @@ namespace Microsoft.AspNetCore.WebSockets.Test public override bool CanTimeout { - get { return _readStream.CanTimeout || _writeStream.CanTimeout; } + get { return ReadStream.CanTimeout || WriteStream.CanTimeout; } } public override bool CanWrite { - get { return _writeStream.CanWrite; } + get { return WriteStream.CanWrite; } } public override long Length @@ -64,14 +65,14 @@ namespace Microsoft.AspNetCore.WebSockets.Test public override int ReadTimeout { - get { return _readStream.ReadTimeout; } - set { _readStream.ReadTimeout = value; } + get { return ReadStream.ReadTimeout; } + set { ReadStream.ReadTimeout = value; } } public override int WriteTimeout { - get { return _writeStream.WriteTimeout; } - set { _writeStream.WriteTimeout = value; } + get { return WriteStream.WriteTimeout; } + set { WriteStream.WriteTimeout = value; } } #endregion Properties @@ -90,33 +91,33 @@ namespace Microsoft.AspNetCore.WebSockets.Test public override int Read(byte[] buffer, int offset, int count) { - return _readStream.Read(buffer, offset, count); + return ReadStream.Read(buffer, offset, count); } #if !NETCOREAPP1_0 public override int ReadByte() { - return _readStream.ReadByte(); + return ReadStream.ReadByte(); } public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { - return _readStream.BeginRead(buffer, offset, count, callback, state); + return ReadStream.BeginRead(buffer, offset, count, callback, state); } public override int EndRead(IAsyncResult asyncResult) { - return _readStream.EndRead(asyncResult); + return ReadStream.EndRead(asyncResult); } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return _readStream.ReadAsync(buffer, offset, count, cancellationToken); + return ReadStream.ReadAsync(buffer, offset, count, cancellationToken); } public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) { - return _readStream.CopyToAsync(destination, bufferSize, cancellationToken); + return ReadStream.CopyToAsync(destination, bufferSize, cancellationToken); } #endif @@ -126,39 +127,39 @@ namespace Microsoft.AspNetCore.WebSockets.Test public override void Write(byte[] buffer, int offset, int count) { - _writeStream.Write(buffer, offset, count); + WriteStream.Write(buffer, offset, count); } #if !NETCOREAPP1_0 public override void WriteByte(byte value) { - _writeStream.WriteByte(value); + WriteStream.WriteByte(value); } public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { - return _writeStream.BeginWrite(buffer, offset, count, callback, state); + return WriteStream.BeginWrite(buffer, offset, count, callback, state); } public override void EndWrite(IAsyncResult asyncResult) { - _writeStream.EndWrite(asyncResult); + WriteStream.EndWrite(asyncResult); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - return _writeStream.WriteAsync(buffer, offset, count, cancellationToken); + return WriteStream.WriteAsync(buffer, offset, count, cancellationToken); } public override Task FlushAsync(CancellationToken cancellationToken) { - return _writeStream.FlushAsync(cancellationToken); + return WriteStream.FlushAsync(cancellationToken); } #endif public override void Flush() { - _writeStream.Flush(); + WriteStream.Flush(); } #endregion Write @@ -167,8 +168,8 @@ namespace Microsoft.AspNetCore.WebSockets.Test { if (disposing) { - _readStream.Dispose(); - _writeStream.Dispose(); + ReadStream.Dispose(); + WriteStream.Dispose(); } base.Dispose(disposing); } diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/SendReceiveTests.cs b/test/Microsoft.AspNetCore.WebSockets.Test/SendReceiveTests.cs index 9580b7faad..afa9f70741 100644 --- a/test/Microsoft.AspNetCore.WebSockets.Test/SendReceiveTests.cs +++ b/test/Microsoft.AspNetCore.WebSockets.Test/SendReceiveTests.cs @@ -73,5 +73,35 @@ namespace Microsoft.AspNetCore.WebSockets.Test Assert.Equal(WebSocketMessageType.Binary, result.MessageType); Assert.Equal(sendBuffer, receiveBuffer.Take(result.Count).ToArray()); } + + [Fact] + public async Task ThrowsWhenUnderlyingStreamClosed() + { + var pair = WebSocketPair.Create(); + var sendBuffer = new byte[] { 0xde, 0xad, 0xbe, 0xef }; + + await pair.ServerSocket.SendAsync(new ArraySegment(sendBuffer), WebSocketMessageType.Binary, endOfMessage: true, cancellationToken: CancellationToken.None); + + var receiveBuffer = new byte[32]; + var result = await pair.ClientSocket.ReceiveAsync(new ArraySegment(receiveBuffer), CancellationToken.None); + + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + + // Close the client socket's read end + pair.ClientStream.ReadStream.End(); + + // Assert.Throws doesn't support async :( + try + { + await pair.ClientSocket.ReceiveAsync(new ArraySegment(receiveBuffer), CancellationToken.None); + + // The exception should prevent this line from running + Assert.False(true, "Expected an exception to be thrown!"); + } + catch (WebSocketException ex) + { + Assert.Equal(WebSocketError.ConnectionClosedPrematurely, ex.WebSocketErrorCode); + } + } } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketPair.cs b/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketPair.cs index 936371acdd..aded688e42 100644 --- a/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketPair.cs +++ b/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketPair.cs @@ -8,9 +8,13 @@ namespace Microsoft.AspNetCore.WebSockets.Test { public WebSocket ClientSocket { get; } public WebSocket ServerSocket { get; } + public DuplexStream ServerStream { get; } + public DuplexStream ClientStream { get; } - public WebSocketPair(WebSocket clientSocket, WebSocket serverSocket) + public WebSocketPair(DuplexStream serverStream, DuplexStream clientStream, WebSocket clientSocket, WebSocket serverSocket) { + ClientStream = clientStream; + ServerStream = serverStream; ClientSocket = clientSocket; ServerSocket = serverSocket; } @@ -22,6 +26,8 @@ namespace Microsoft.AspNetCore.WebSockets.Test var clientStream = serverStream.CreateReverseDuplexStream(); return new WebSocketPair( + serverStream, + clientStream, clientSocket: WebSocketFactory.CreateClientWebSocket(clientStream, null, TimeSpan.FromMinutes(2), 1024), serverSocket: WebSocketFactory.CreateServerWebSocket(serverStream, null, TimeSpan.FromMinutes(2), 1024)); }