From ee9feedc278e68239502ec32c821ebcf42631bbb Mon Sep 17 00:00:00 2001 From: Nate McMaster Date: Mon, 17 Apr 2017 17:09:48 -0700 Subject: [PATCH] Improve implementation of IHttpUpgradeFeature After upgrade has been accepted by the server: - Reads to HttpRequest.Body always return 0 - Writes to HttpResponse.Body always throw - The only valid way to communicate is to use the stream returned by IHttpUpgradeFeature.UpgradeAsync() Also, Kestrel returns HTTP 400 if requests attempt to send a request body along with Connection: Upgrade --- .gitignore | 1 + .../BadHttpRequestException.cs | 3 + .../CoreStrings.resx | 123 +++++++++++++ .../Internal/Http/Frame.FeatureCollection.cs | 2 +- .../Internal/Http/Frame.cs | 30 +-- .../Internal/Http/FrameDuplexStream.cs | 2 +- .../Internal/Http/FrameRequestStream.cs | 29 +-- .../Internal/Http/FrameResponseStream.cs | 29 +-- .../Internal/Http/MessageBody.cs | 30 ++- .../Internal/Http/RequestRejectionReason.cs | 1 + .../Internal/Infrastructure/ReadOnlyStream.cs | 29 +++ .../Internal/Infrastructure/Streams.cs | 77 +++++++- .../Infrastructure/ThrowingWriteOnlyStream.cs | 45 +++++ .../Internal/Infrastructure/WrappingStream.cs | 140 ++++++++++++++ .../Infrastructure/WriteOnlyStream.cs | 29 +++ ...soft.AspNetCore.Server.Kestrel.Core.csproj | 7 + .../Properties/CoreStrings.Designer.cs | 44 +++++ .../FrameRequestStreamTests.cs | 1 + .../FrameTests.cs | 3 - .../StreamsTests.cs | 89 +++++++++ .../ThrowingWriteOnlyStreamTests.cs | 29 +++ .../UpgradeTests.cs | 174 ++++++++++++++++++ 22 files changed, 824 insertions(+), 93 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Server.Kestrel.Core/CoreStrings.resx create mode 100644 src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/ReadOnlyStream.cs create mode 100644 src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/ThrowingWriteOnlyStream.cs create mode 100644 src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/WrappingStream.cs create mode 100644 src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/WriteOnlyStream.cs create mode 100644 src/Microsoft.AspNetCore.Server.Kestrel.Core/Properties/CoreStrings.Designer.cs create mode 100644 test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/StreamsTests.cs create mode 100644 test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/ThrowingWriteOnlyStreamTests.cs create mode 100644 test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/UpgradeTests.cs diff --git a/.gitignore b/.gitignore index 7a34bbac8a..708c4155fa 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ TestResults/ .nuget/ *.sln.ide/ _ReSharper.*/ +.idea/ packages/ artifacts/ PublishProfiles/ diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/BadHttpRequestException.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/BadHttpRequestException.cs index 6bc10f460d..c7d2e69042 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/BadHttpRequestException.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/BadHttpRequestException.cs @@ -89,6 +89,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core case RequestRejectionReason.InvalidHostHeader: ex = new BadHttpRequestException("Invalid Host header.", StatusCodes.Status400BadRequest); break; + case RequestRejectionReason.UpgradeRequestCannotHavePayload: + ex = new BadHttpRequestException("Requests with 'Connection: Upgrade' cannot have content in the request body.", StatusCodes.Status400BadRequest); + break; default: ex = new BadHttpRequestException("Bad request.", StatusCodes.Status400BadRequest); break; diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/CoreStrings.resx b/src/Microsoft.AspNetCore.Server.Kestrel.Core/CoreStrings.resx new file mode 100644 index 0000000000..f1605f6630 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/CoreStrings.resx @@ -0,0 +1,123 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Cannot write to response body after connection has been upgraded. + + \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.FeatureCollection.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.FeatureCollection.cs index 58180c0f69..bd1d8e2a56 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.FeatureCollection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.FeatureCollection.cs @@ -244,7 +244,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http await FlushAsync(default(CancellationToken)); - return DuplexStream; + return _frameStreams.Upgrade(); } IEnumerator> IEnumerable>.GetEnumerator() => FastEnumerable().GetEnumerator(); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.cs index 5353364724..a63fec9407 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.cs @@ -241,8 +241,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public IHeaderDictionary ResponseHeaders { get; set; } public Stream ResponseBody { get; set; } - public Stream DuplexStream { get; set; } - public CancellationToken RequestAborted { get @@ -323,31 +321,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _frameStreams = new Streams(this); } - RequestBody = _frameStreams.RequestBody; - ResponseBody = _frameStreams.ResponseBody; - DuplexStream = _frameStreams.DuplexStream; - - _frameStreams.RequestBody.StartAcceptingReads(messageBody); - _frameStreams.ResponseBody.StartAcceptingWrites(); + (RequestBody, ResponseBody) = _frameStreams.Start(messageBody); } - public void PauseStreams() - { - _frameStreams.RequestBody.PauseAcceptingReads(); - _frameStreams.ResponseBody.PauseAcceptingWrites(); - } + public void PauseStreams() => _frameStreams.Pause(); - public void ResumeStreams() - { - _frameStreams.RequestBody.ResumeAcceptingReads(); - _frameStreams.ResponseBody.ResumeAcceptingWrites(); - } + public void ResumeStreams() => _frameStreams.Resume(); - public void StopStreams() - { - _frameStreams.RequestBody.StopAcceptingReads(); - _frameStreams.ResponseBody.StopAcceptingWrites(); - } + public void StopStreams() => _frameStreams.Stop(); public void Reset() { @@ -455,8 +436,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { _requestProcessingStopping = true; - _frameStreams?.RequestBody.Abort(error); - _frameStreams?.ResponseBody.Abort(); + _frameStreams?.Abort(error); LifetimeControl.End(ProduceEndType.SocketDisconnect); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameDuplexStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameDuplexStream.cs index b427651405..d1b7cab14d 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameDuplexStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameDuplexStream.cs @@ -10,7 +10,7 @@ using System.Threading.Tasks; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { - class FrameDuplexStream : Stream + internal class FrameDuplexStream : Stream { private readonly Stream _requestStream; private readonly Stream _responseStream; diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameRequestStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameRequestStream.cs index 93f0db1018..35bbf73302 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameRequestStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameRequestStream.cs @@ -5,11 +5,12 @@ using System; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { - class FrameRequestStream : Stream + internal class FrameRequestStream : ReadOnlyStream { private MessageBody _body; private FrameStreamState _state; @@ -20,30 +21,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _state = FrameStreamState.Closed; } - public override bool CanRead => true; - public override bool CanSeek => false; - public override bool CanWrite => false; - public override long Length - { - get - { - throw new NotSupportedException(); - } - } + => throw new NotSupportedException(); public override long Position { - get - { - throw new NotSupportedException(); - } - set - { - throw new NotSupportedException(); - } + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); } public override void Flush() @@ -145,11 +131,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return task; } - public override void Write(byte[] buffer, int offset, int count) - { - throw new NotSupportedException(); - } - public void StartAcceptingReads(MessageBody body) { // Only start if not aborted diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameResponseStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameResponseStream.cs index 699aa90f92..5f0931759b 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameResponseStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameResponseStream.cs @@ -5,10 +5,11 @@ using System; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { - class FrameResponseStream : Stream + internal class FrameResponseStream : WriteOnlyStream { private IFrameControl _frameControl; private FrameStreamState _state; @@ -19,30 +20,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _state = FrameStreamState.Closed; } - public override bool CanRead => false; - public override bool CanSeek => false; - public override bool CanWrite => true; - public override long Length - { - get - { - throw new NotSupportedException(); - } - } + => throw new NotSupportedException(); public override long Position { - get - { - throw new NotSupportedException(); - } - set - { - throw new NotSupportedException(); - } + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); } public override void Flush() @@ -72,11 +58,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http throw new NotSupportedException(); } - public override int Read(byte[] buffer, int offset, int count) - { - throw new NotSupportedException(); - } - public override void Write(byte[] buffer, int offset, int count) { ValidateState(default(CancellationToken)); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs index 97fff771fa..ca26fe9095 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs @@ -24,6 +24,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _context = context; } + public static MessageBody ZeroContentLengthClose => _zeroContentLengthClose; + public bool RequestKeepAlive { get; protected set; } public bool RequestUpgrade { get; protected set; } @@ -237,15 +239,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http var keepAlive = httpVersion != HttpVersion.Http10; var connection = headers.HeaderConnection; + var upgrade = false; if (connection.Count > 0) { var connectionOptions = FrameHeaders.ParseConnection(connection); - if ((connectionOptions & ConnectionOptions.Upgrade) == ConnectionOptions.Upgrade) - { - return new ForRemainingData(true, context); - } - + upgrade = (connectionOptions & ConnectionOptions.Upgrade) == ConnectionOptions.Upgrade; keepAlive = (connectionOptions & ConnectionOptions.KeepAlive) == ConnectionOptions.KeepAlive; } @@ -265,16 +264,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http context.RejectRequest(RequestRejectionReason.FinalTransferCodingNotChunked, transferEncoding.ToString()); } + if (upgrade) + { + context.RejectRequest(RequestRejectionReason.UpgradeRequestCannotHavePayload); + } + return new ForChunkedEncoding(keepAlive, headers, context); } if (headers.ContentLength.HasValue) { var contentLength = headers.ContentLength.Value; + if (contentLength == 0) { return keepAlive ? _zeroContentLengthKeepAlive : _zeroContentLengthClose; } + else if (upgrade) + { + context.RejectRequest(RequestRejectionReason.UpgradeRequestCannotHavePayload); + } return new ForContentLength(keepAlive, contentLength, context); } @@ -291,15 +300,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } + if (upgrade) + { + return new ForUpgrade(context); + } + return keepAlive ? _zeroContentLengthKeepAlive : _zeroContentLengthClose; } - private class ForRemainingData : MessageBody + private class ForUpgrade : MessageBody { - public ForRemainingData(bool upgrade, Frame context) + public ForUpgrade(Frame context) : base(context) { - RequestUpgrade = upgrade; + RequestUpgrade = true; } protected override ValueTask> PeekAsync(CancellationToken cancellationToken) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/RequestRejectionReason.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/RequestRejectionReason.cs index ea03aab6fa..730badcff8 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/RequestRejectionReason.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/RequestRejectionReason.cs @@ -30,5 +30,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http MissingHostHeader, MultipleHostHeaders, InvalidHostHeader, + UpgradeRequestCannotHavePayload, } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/ReadOnlyStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/ReadOnlyStream.cs new file mode 100644 index 0000000000..cf4b02a41f --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/ReadOnlyStream.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public abstract class ReadOnlyStream : Stream + { + public override bool CanRead => true; + + public override bool CanWrite => false; + + public override int WriteTimeout + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw new NotSupportedException(); + } +} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/Streams.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/Streams.cs index 61654b3e56..c81f9c3985 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/Streams.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/Streams.cs @@ -1,21 +1,84 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.IO; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure { - class Streams + internal class Streams { + private static readonly ThrowingWriteOnlyStream _throwingResponseStream + = new ThrowingWriteOnlyStream(new InvalidOperationException(CoreStrings.ResponseStreamWasUpgraded)); + private readonly FrameResponseStream _response; + private readonly FrameRequestStream _request; + private readonly WrappingStream _upgradeableResponse; + private readonly FrameRequestStream _emptyRequest; + private readonly Stream _upgradeStream; + public Streams(IFrameControl frameControl) { - RequestBody = new FrameRequestStream(); - ResponseBody = new FrameResponseStream(frameControl); - DuplexStream = new FrameDuplexStream(RequestBody, ResponseBody); + _request = new FrameRequestStream(); + _emptyRequest = new FrameRequestStream(); + _response = new FrameResponseStream(frameControl); + _upgradeableResponse = new WrappingStream(_response); + _upgradeStream = new FrameDuplexStream(_request, _response); } - public FrameRequestStream RequestBody { get; } - public FrameResponseStream ResponseBody { get; } - public FrameDuplexStream DuplexStream { get; } + public Stream Upgrade() + { + // causes writes to context.Response.Body to throw + _upgradeableResponse.SetInnerStream(_throwingResponseStream); + // _upgradeStream always uses _response + return _upgradeStream; + } + + public (Stream request, Stream response) Start(MessageBody body) + { + _request.StartAcceptingReads(body); + _emptyRequest.StartAcceptingReads(MessageBody.ZeroContentLengthClose); + _response.StartAcceptingWrites(); + + if (body.RequestUpgrade) + { + // until Upgrade() is called, context.Response.Body should use the normal output stream + _upgradeableResponse.SetInnerStream(_response); + // upgradeable requests should never have a request body + return (_emptyRequest, _upgradeableResponse); + } + else + { + return (_request, _response); + } + } + + public void Pause() + { + _request.PauseAcceptingReads(); + _emptyRequest.PauseAcceptingReads(); + _response.PauseAcceptingWrites(); + } + + public void Resume() + { + _request.ResumeAcceptingReads(); + _emptyRequest.ResumeAcceptingReads(); + _response.ResumeAcceptingWrites(); + } + + public void Stop() + { + _request.StopAcceptingReads(); + _emptyRequest.StopAcceptingReads(); + _response.StopAcceptingWrites(); + } + + public void Abort(Exception error) + { + _request.Abort(error); + _emptyRequest.Abort(error); + _response.Abort(); + } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/ThrowingWriteOnlyStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/ThrowingWriteOnlyStream.cs new file mode 100644 index 0000000000..aff2b7a6d4 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/ThrowingWriteOnlyStream.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public class ThrowingWriteOnlyStream : WriteOnlyStream + { + private readonly Exception _exception; + + public ThrowingWriteOnlyStream(Exception exception) + { + _exception = exception; + } + + public override bool CanSeek => false; + + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + => throw _exception; + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw _exception; + + public override void Flush() + => throw _exception; + + public override long Seek(long offset, SeekOrigin origin) + => throw new NotSupportedException(); + + public override void SetLength(long value) + => throw new NotSupportedException(); + } +} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/WrappingStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/WrappingStream.cs new file mode 100644 index 0000000000..64fbd85526 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/WrappingStream.cs @@ -0,0 +1,140 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +#if NET46 +using System.Runtime.Remoting; +#endif +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + internal class WrappingStream : Stream + { + private Stream _inner; + private bool _disposed; + + public WrappingStream(Stream inner) + { + _inner = inner; + } + + public void SetInnerStream(Stream inner) + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(WrappingStream)); + } + + _inner = inner; + } + + public override bool CanRead => _inner.CanRead; + + public override bool CanSeek => _inner.CanSeek; + + public override bool CanWrite => _inner.CanWrite; + + public override bool CanTimeout => _inner.CanTimeout; + + public override long Length => _inner.Length; + + public override long Position + { + get => _inner.Position; + set => _inner.Position = value; + } + + public override int ReadTimeout + { + get => _inner.ReadTimeout; + set => _inner.ReadTimeout = value; + } + + public override int WriteTimeout + { + get => _inner.WriteTimeout; + set => _inner.WriteTimeout = value; + } + + public override void Flush() + => _inner.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) + => _inner.FlushAsync(cancellationToken); + + public override int Read(byte[] buffer, int offset, int count) + => _inner.Read(buffer, offset, count); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => _inner.ReadAsync(buffer, offset, count, cancellationToken); + + public override int ReadByte() + => _inner.ReadByte(); + + public override long Seek(long offset, SeekOrigin origin) + => _inner.Seek(offset, origin); + + public override void SetLength(long value) + => _inner.SetLength(value); + + public override void Write(byte[] buffer, int offset, int count) + => _inner.Write(buffer, offset, count); + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => _inner.WriteAsync(buffer, offset, count, cancellationToken); + + public override void WriteByte(byte value) + => _inner.WriteByte(value); + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + => _inner.CopyToAsync(destination, bufferSize, cancellationToken); + +#if NET46 + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + => _inner.BeginRead(buffer, offset, count, callback, state); + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + => _inner.BeginWrite(buffer, offset, count, callback, state); + + public override int EndRead(IAsyncResult asyncResult) + => _inner.EndRead(asyncResult); + + public override void EndWrite(IAsyncResult asyncResult) + => _inner.EndWrite(asyncResult); + + public override ObjRef CreateObjRef(Type requestedType) + => _inner.CreateObjRef(requestedType); + + public override object InitializeLifetimeService() + => _inner.InitializeLifetimeService(); + + public override void Close() + => _inner.Close(); + +#elif NETSTANDARD1_3 +#else +#error Target framework should be updated +#endif + + public override bool Equals(object obj) + => _inner.Equals(obj); + + public override int GetHashCode() + => _inner.GetHashCode(); + + public override string ToString() + => _inner.ToString(); + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _disposed = true; + _inner.Dispose(); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/WriteOnlyStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/WriteOnlyStream.cs new file mode 100644 index 0000000000..c7042e2bb0 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Infrastructure/WriteOnlyStream.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure +{ + public abstract class WriteOnlyStream : Stream + { + public override bool CanRead => false; + + public override bool CanWrite => true; + + public override int ReadTimeout + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + => throw new NotSupportedException(); + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw new NotSupportedException(); + } +} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Microsoft.AspNetCore.Server.Kestrel.Core.csproj b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Microsoft.AspNetCore.Server.Kestrel.Core.csproj index 20102b9aba..3e9d1e703d 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Microsoft.AspNetCore.Server.Kestrel.Core.csproj +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Microsoft.AspNetCore.Server.Kestrel.Core.csproj @@ -20,6 +20,7 @@ + @@ -33,4 +34,10 @@ + + + + + + diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Properties/CoreStrings.Designer.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Properties/CoreStrings.Designer.cs new file mode 100644 index 0000000000..0deecaa6d1 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Properties/CoreStrings.Designer.cs @@ -0,0 +1,44 @@ +// +namespace Microsoft.AspNetCore.Server.Kestrel.Core +{ + using System.Globalization; + using System.Reflection; + using System.Resources; + + internal static class CoreStrings + { + private static readonly ResourceManager _resourceManager + = new ResourceManager("Microsoft.AspNetCore.Server.Kestrel.Core.CoreStrings", typeof(CoreStrings).GetTypeInfo().Assembly); + + /// + /// Cannot write to response body after connection has been upgraded. + /// + internal static string ResponseStreamWasUpgraded + { + get => GetString("ResponseStreamWasUpgraded"); + } + + /// + /// Cannot write to response body after connection has been upgraded. + /// + internal static string FormatResponseStreamWasUpgraded() + => GetString("ResponseStreamWasUpgraded"); + + private static string GetString(string name, params string[] formatterNames) + { + var value = _resourceManager.GetString(name); + + System.Diagnostics.Debug.Assert(value != null); + + if (formatterNames != null) + { + for (var i = 0; i < formatterNames.Length; i++) + { + value = value.Replace("{" + formatterNames[i] + "}", "{" + i + "}"); + } + } + + return value; + } + } +} diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameRequestStreamTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameRequestStreamTests.cs index 9f723aa903..8adad18cca 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameRequestStreamTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameRequestStreamTests.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Moq; diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameTests.cs index b6471a0228..d050e7f43d 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameTests.cs @@ -256,10 +256,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var originalRequestBody = _frame.RequestBody; var originalResponseBody = _frame.ResponseBody; - var originalDuplexStream = _frame.DuplexStream; _frame.RequestBody = new MemoryStream(); _frame.ResponseBody = new MemoryStream(); - _frame.DuplexStream = new MemoryStream(); // Act _frame.InitializeStreams(messageBody); @@ -267,7 +265,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests // Assert Assert.Same(originalRequestBody, _frame.RequestBody); Assert.Same(originalResponseBody, _frame.ResponseBody); - Assert.Same(originalDuplexStream, _frame.DuplexStream); } [Theory] diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/StreamsTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/StreamsTests.cs new file mode 100644 index 0000000000..d9f8854a14 --- /dev/null +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/StreamsTests.cs @@ -0,0 +1,89 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class StreamsTests + { + [Fact] + public async Task StreamsThrowAfterAbort() + { + var streams = new Streams(Mock.Of()); + var (request, response) = streams.Start(new MockMessageBody()); + + var ex = new Exception("My error"); + streams.Abort(ex); + + await response.WriteAsync(new byte[1], 0, 1); + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + } + + [Fact] + public async Task StreamsThrowOnAbortAfterUpgrade() + { + var streams = new Streams(Mock.Of()); + var (request, response) = streams.Start(new MockMessageBody(upgradeable: true)); + + var upgrade = streams.Upgrade(); + var ex = new Exception("My error"); + streams.Abort(ex); + + var writeEx = await Assert.ThrowsAsync(() => response.WriteAsync(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); + + await upgrade.WriteAsync(new byte[1], 0, 1); + } + + [Fact] + public async Task StreamsThrowOnUpgradeAfterAbort() + { + var streams = new Streams(Mock.Of()); + + var (request, response) = streams.Start(new MockMessageBody(upgradeable: true)); + var ex = new Exception("My error"); + streams.Abort(ex); + + var upgrade = streams.Upgrade(); + + var writeEx = await Assert.ThrowsAsync(() => response.WriteAsync(new byte[1], 0, 1)); + Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + + Assert.Same(ex, + await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); + + await upgrade.WriteAsync(new byte[1], 0, 1); + } + + private class MockMessageBody : MessageBody + { + public MockMessageBody(bool upgradeable = false) + : base(null) + { + RequestUpgrade = upgradeable; + } + + protected override ValueTask> PeekAsync(CancellationToken cancellationToken) + { + return new ValueTask>(new ArraySegment(new byte[1])); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/ThrowingWriteOnlyStreamTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/ThrowingWriteOnlyStreamTests.cs new file mode 100644 index 0000000000..7e6490e4f3 --- /dev/null +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/ThrowingWriteOnlyStreamTests.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests +{ + public class ThrowingWriteOnlyStreamTests + { + [Fact] + public async Task ThrowsOnWrite() + { + var ex = new Exception("my error"); + var stream = new ThrowingWriteOnlyStream(ex); + + Assert.True(stream.CanWrite); + Assert.False(stream.CanRead); + Assert.False(stream.CanSeek); + Assert.False(stream.CanTimeout); + Assert.Same(ex, Assert.Throws(() => stream.Write(new byte[1], 0, 1))); + Assert.Same(ex, await Assert.ThrowsAsync(() => stream.WriteAsync(new byte[1], 0, 1))); + Assert.Same(ex, Assert.Throws(() => stream.Flush())); + Assert.Same(ex, await Assert.ThrowsAsync(() => stream.FlushAsync())); + } + } +} diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/UpgradeTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/UpgradeTests.cs new file mode 100644 index 0000000000..d56ecf4d00 --- /dev/null +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/UpgradeTests.cs @@ -0,0 +1,174 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests +{ + public class UpgradeTests + { + [Fact] + public async Task ResponseThrowsAfterUpgrade() + { + var upgrade = new TaskCompletionSource(); + using (var server = new TestServer(async context => + { + var feature = context.Features.Get(); + var stream = await feature.UpgradeAsync(); + + var ex = Assert.Throws(() => context.Response.Body.WriteByte((byte)' ')); + Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, ex.Message); + + using (var writer = new StreamWriter(stream)) + { + writer.WriteLine("New protocol data"); + } + + upgrade.TrySetResult(true); + })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send("GET / HTTP/1.1", + "Host:", + "Connection: Upgrade", + "", + ""); + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + + await connection.Receive("New protocol data"); + await upgrade.Task.TimeoutAfter(TimeSpan.FromSeconds(30)); + } + } + } + + [Fact] + public async Task RequestBodyAlwaysEmptyAfterUpgrade() + { + const string send = "Custom protocol send"; + const string recv = "Custom protocol recv"; + + var upgrade = new TaskCompletionSource(); + using (var server = new TestServer(async context => + { + try + { + var feature = context.Features.Get(); + var stream = await feature.UpgradeAsync(); + + var buffer = new byte[128]; + var read = await context.Request.Body.ReadAsync(buffer, 0, 128).TimeoutAfter(TimeSpan.FromSeconds(10)); + Assert.Equal(0, read); + + using (var reader = new StreamReader(stream)) + using (var writer = new StreamWriter(stream)) + { + var line = await reader.ReadLineAsync(); + Assert.Equal(send, line); + await writer.WriteLineAsync(recv); + } + + upgrade.TrySetResult(true); + } + catch (Exception ex) + { + upgrade.SetException(ex); + throw; + } + })) + { + using (var connection = server.CreateConnection()) + { + await connection.Send("GET / HTTP/1.1", + "Host:", + "Connection: Upgrade", + "", + ""); + + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + + await connection.Send(send + "\r\n"); + await connection.Receive(recv); + + await upgrade.Task.TimeoutAfter(TimeSpan.FromSeconds(30)); + } + } + } + + [Fact] + public async Task RejectsRequestWithContentLengthAndUpgrade() + { + using (var server = new TestServer(context => TaskCache.CompletedTask)) + using (var connection = server.CreateConnection()) + { + await connection.Send("POST / HTTP/1.1", + "Host:", + "Content-Length: 1", + "Connection: Upgrade", + "", + "1"); + + await connection.Receive("HTTP/1.1 400 Bad Request"); + } + } + + [Fact] + public async Task AcceptsRequestWithNoContentLengthAndUpgrade() + { + using (var server = new TestServer(context => TaskCache.CompletedTask)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send("POST / HTTP/1.1", + "Host:", + "Content-Length: 0", + "Connection: Upgrade, keep-alive", + "", + ""); + await connection.Receive("HTTP/1.1 200 OK"); + } + + using (var connection = server.CreateConnection()) + { + await connection.Send("GET / HTTP/1.1", + "Host:", + "Connection: Upgrade", + "", + ""); + await connection.Receive("HTTP/1.1 200 OK"); + } + } + } + + [Fact] + public async Task RejectsRequestWithChunkedEncodingAndUpgrade() + { + using (var server = new TestServer(context => TaskCache.CompletedTask)) + using (var connection = server.CreateConnection()) + { + await connection.Send("POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "Connection: Upgrade", + "", + ""); + await connection.Receive("HTTP/1.1 400 Bad Request"); + } + } + } +}