diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/StreamSocketOutput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/StreamSocketOutput.cs index a5b618c8d4..92a32c2a28 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/StreamSocketOutput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Adapter/Internal/StreamSocketOutput.cs @@ -83,9 +83,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Adapter.Internal return WriteAsync(default(ArraySegment), chunk: false, cancellationToken: cancellationToken); } - public WritableBuffer Alloc() + public void Write(Action callback, T state) { - return _pipe.Writer.Alloc(); + lock (_sync) + { + if (_completed) + { + return; + } + + var buffer = _pipe.Writer.Alloc(); + callback(buffer, state); + buffer.Commit(); + } } public async Task WriteOutputAsync() diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index 8c5f032a1e..d900924e41 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -32,6 +32,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http private static readonly ArraySegment _endChunkedResponseBytes = CreateAsciiByteArraySegment("0\r\n\r\n"); private static readonly ArraySegment _continueBytes = CreateAsciiByteArraySegment("HTTP/1.1 100 Continue\r\n\r\n"); + private static readonly Action _writeHeaders = WriteResponseHeaders; private static readonly byte[] _bytesConnectionClose = Encoding.ASCII.GetBytes("\r\nConnection: close"); private static readonly byte[] _bytesConnectionKeepAlive = Encoding.ASCII.GetBytes("\r\nConnection: keep-alive"); @@ -784,9 +785,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _requestProcessingStatus = RequestProcessingStatus.ResponseStarted; - var statusBytes = ReasonPhrases.ToStatusBytes(StatusCode, ReasonPhrase); - - CreateResponseHeader(statusBytes, appCompleted); + CreateResponseHeader(appCompleted); } protected Task TryProduceInvalidRequestResponse() @@ -881,9 +880,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } } - private void CreateResponseHeader( - byte[] statusBytes, - bool appCompleted) + private void CreateResponseHeader(bool appCompleted) { var responseHeaders = FrameResponseHeaders; var hasConnection = responseHeaders.HasConnection; @@ -972,12 +969,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes); } - var writableBuffer = Output.Alloc(); + Output.Write(_writeHeaders, this); + } + + private static void WriteResponseHeaders(WritableBuffer writableBuffer, Frame frame) + { + var responseHeaders = frame.FrameResponseHeaders; writableBuffer.WriteFast(_bytesHttpVersion11); + var statusBytes = ReasonPhrases.ToStatusBytes(frame.StatusCode, frame.ReasonPhrase); writableBuffer.WriteFast(statusBytes); responseHeaders.CopyTo(ref writableBuffer); writableBuffer.WriteFast(_bytesEndHeaders); - writableBuffer.Commit(); } public void ParseRequest(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ISocketOutput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ISocketOutput.cs index fbefd3fa44..7298b42163 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ISocketOutput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/ISocketOutput.cs @@ -5,7 +5,6 @@ using System; using System.Threading; using System.Threading.Tasks; using System.IO.Pipelines; -using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { @@ -18,6 +17,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http Task WriteAsync(ArraySegment buffer, bool chunk = false, CancellationToken cancellationToken = default(CancellationToken)); void Flush(); Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken)); - WritableBuffer Alloc(); + void Write(Action write, T state); } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs index 7c4412ef23..ef74217d12 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs @@ -186,17 +186,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http return WriteAsync(_emptyData, cancellationToken); } - public WritableBuffer Alloc() + public void Write(Action callback, T state) { lock (_contextLock) { if (_completed) { - // This is broken - return default(WritableBuffer); + return; } - return _pipe.Writer.Alloc(); + var buffer = _pipe.Writer.Alloc(); + callback(buffer, state); + buffer.Commit(); } } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs index eabb1a5385..b0ae4f6c3e 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs @@ -5,6 +5,7 @@ using System; using System.Linq; using System.Net; using System.Net.Http; +using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -988,6 +989,47 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Fact] + public async Task WriteAfterConnectionCloseNoops() + { + var connectionClosed = new ManualResetEventSlim(); + var requestStarted = new ManualResetEventSlim(); + var tcs = new TaskCompletionSource(); + + using (var server = new TestServer(async httpContext => + { + try + { + requestStarted.Set(); + connectionClosed.Wait(); + httpContext.Response.ContentLength = 12; + await httpContext.Response.WriteAsync("hello, world"); + tcs.TrySetResult(null); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + } + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + + requestStarted.Wait(); + connection.Shutdown(SocketShutdown.Send); + await connection.WaitForConnectionClose(); + } + + connectionClosed.Set(); + + await tcs.Task; + } + } + [Fact] public async Task AppCanWriteOwnBadRequestResponse() { diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs index 544c9af39b..e1257a8dd8 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs @@ -320,9 +320,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.NotEmpty(completeQueue); // Add more bytes to the write-behind buffer to prevent the next write from - var writableBuffer = socketOutput.Alloc(); - writableBuffer.Write(halfWriteBehindBuffer); - writableBuffer.Commit(); + socketOutput.Write((writableBuffer, state) => + { + writableBuffer.Write(state); + }, + halfWriteBehindBuffer); // Act var writeTask2 = socketOutput.WriteAsync(halfWriteBehindBuffer, default(CancellationToken)); @@ -579,7 +581,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - [Fact(Skip = "Commit throws with a non channel backed writable buffer")] + [Fact] public async Task AllocCommitCanBeCalledAfterConnectionClose() { var mockLibuv = new MockLibuv(); @@ -608,8 +610,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.Equal(TaskStatus.RanToCompletion, connection.SocketClosed.Status); - var start = socketOutput.Alloc(); - start.Commit(); + var called = false; + + socketOutput.Write((buffer, state) => + { + called = true; + }, + null); + + Assert.False(called); } } } diff --git a/test/shared/MockSocketOutput.cs b/test/shared/MockSocketOutput.cs index fa7f8836d6..a506b388fd 100644 --- a/test/shared/MockSocketOutput.cs +++ b/test/shared/MockSocketOutput.cs @@ -12,12 +12,8 @@ namespace Microsoft.AspNetCore.Testing { public class MockSocketOutput : ISocketOutput { - private PipeFactory _factory = new PipeFactory(); - private IPipeWriter _writer; - public MockSocketOutput() { - _writer = _factory.Create().Writer; } public void Write(ArraySegment buffer, bool chunk = false) @@ -38,9 +34,9 @@ namespace Microsoft.AspNetCore.Testing return TaskCache.CompletedTask; } - public WritableBuffer Alloc() + public void Write(Action write, T state) { - return _writer.Alloc(); + } } }