From 36edadabb4a825bb429553d441d6094debbdbfab Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Fri, 6 Apr 2018 16:25:47 +1200 Subject: [PATCH] Lock when sending data to connection (#1876) --- .../HttpConnectionDispatcher.cs | 14 ++- .../HttpConnectionDispatcherTests.cs | 87 +++++++++++++++++++ 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs index c04a805f08..239b793db2 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs @@ -450,7 +450,19 @@ namespace Microsoft.AspNetCore.Http.Connections } var pipeWriterStream = new PipeWriterStream(connection.Application.Output); - await context.Request.Body.CopyToAsync(pipeWriterStream); + + // REVIEW: Consider spliting the connection lock into a read lock and a write lock + // Need to think about HttpConnectionContext.DisposeAsync and whether one or both locks would be needed + await connection.Lock.WaitAsync(); + + try + { + await context.Request.Body.CopyToAsync(pipeWriterStream); + } + finally + { + connection.Lock.Release(); + } Log.ReceivedBytes(_logger, pipeWriterStream.Length); } diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs index 0318b236aa..f7ab74e33d 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs @@ -338,6 +338,93 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + [Theory] + [InlineData(HttpTransportType.LongPolling)] + [InlineData(HttpTransportType.ServerSentEvents)] + public async Task PostSendsToConnectionInParallel(HttpTransportType transportType) + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var connection = manager.CreateConnection(); + connection.Items[ConnectionMetadataNames.Transport] = transportType; + + // Allow a maximum of one caller to use code at one time + var callerTracker = new SemaphoreSlim(1, 1); + var waitTcs = new TaskCompletionSource(); + + // This tests thread safety of sending multiple pieces of data to a connection at once + var executeTask1 = DispatcherExecuteAsync(dispatcher, connection, callerTracker, waitTcs.Task); + var executeTask2 = DispatcherExecuteAsync(dispatcher, connection, callerTracker, waitTcs.Task); + + waitTcs.SetResult(true); + + await Task.WhenAll(executeTask1, executeTask2); + } + + async Task DispatcherExecuteAsync(HttpConnectionDispatcher dispatcher, HttpConnectionContext connection, SemaphoreSlim callerTracker, Task waitTask) + { + using (var requestBody = new TrackingMemoryStream(callerTracker, waitTask)) + { + var bytes = Encoding.UTF8.GetBytes("Hello World"); + requestBody.Write(bytes, 0, bytes.Length); + requestBody.Seek(0, SeekOrigin.Begin); + + var context = new DefaultHttpContext(); + context.Request.Body = requestBody; + + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + + await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + } + } + } + + private class TrackingMemoryStream : MemoryStream + { + private readonly SemaphoreSlim _callerTracker; + private readonly Task _waitTask; + + public TrackingMemoryStream(SemaphoreSlim callerTracker, Task waitTask) + { + _callerTracker = callerTracker; + _waitTask = waitTask; + } + + public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + // Will return false if all available locks from semaphore are taken + if (!_callerTracker.Wait(0)) + { + throw new Exception("Too many callers."); + } + + try + { + await _waitTask; + + await base.CopyToAsync(destination, bufferSize, cancellationToken); + } + finally + { + _callerTracker.Release(); + } + } + } + [Fact] public async Task HttpContextFeatureForLongpollingWorksBetweenPolls() {