Lock when sending data to connection (#1876)

This commit is contained in:
James Newton-King 2018-04-06 16:25:47 +12:00 committed by GitHub
parent cb5ece8a24
commit 36edadabb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 100 additions and 1 deletions

View File

@ -450,7 +450,19 @@ namespace Microsoft.AspNetCore.Http.Connections
}
var pipeWriterStream = new PipeWriterStream(connection.Application.Output);
await context.Request.Body.CopyToAsync(pipeWriterStream);
// REVIEW: Consider spliting the connection lock into a read lock and a write lock
// Need to think about HttpConnectionContext.DisposeAsync and whether one or both locks would be needed
await connection.Lock.WaitAsync();
try
{
await context.Request.Body.CopyToAsync(pipeWriterStream);
}
finally
{
connection.Lock.Release();
}
Log.ReceivedBytes(_logger, pipeWriterStream.Length);
}

View File

@ -338,6 +338,93 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
}
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]
public async Task PostSendsToConnectionInParallel(HttpTransportType transportType)
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var manager = CreateConnectionManager(loggerFactory);
var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory);
var connection = manager.CreateConnection();
connection.Items[ConnectionMetadataNames.Transport] = transportType;
// Allow a maximum of one caller to use code at one time
var callerTracker = new SemaphoreSlim(1, 1);
var waitTcs = new TaskCompletionSource<bool>();
// This tests thread safety of sending multiple pieces of data to a connection at once
var executeTask1 = DispatcherExecuteAsync(dispatcher, connection, callerTracker, waitTcs.Task);
var executeTask2 = DispatcherExecuteAsync(dispatcher, connection, callerTracker, waitTcs.Task);
waitTcs.SetResult(true);
await Task.WhenAll(executeTask1, executeTask2);
}
async Task DispatcherExecuteAsync(HttpConnectionDispatcher dispatcher, HttpConnectionContext connection, SemaphoreSlim callerTracker, Task waitTask)
{
using (var requestBody = new TrackingMemoryStream(callerTracker, waitTask))
{
var bytes = Encoding.UTF8.GetBytes("Hello World");
requestBody.Write(bytes, 0, bytes.Length);
requestBody.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
context.Request.Body = requestBody;
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
services.AddOptions();
context.Request.Path = "/foo";
context.Request.Method = "POST";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app);
}
}
}
private class TrackingMemoryStream : MemoryStream
{
private readonly SemaphoreSlim _callerTracker;
private readonly Task _waitTask;
public TrackingMemoryStream(SemaphoreSlim callerTracker, Task waitTask)
{
_callerTracker = callerTracker;
_waitTask = waitTask;
}
public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
// Will return false if all available locks from semaphore are taken
if (!_callerTracker.Wait(0))
{
throw new Exception("Too many callers.");
}
try
{
await _waitTask;
await base.CopyToAsync(destination, bufferSize, cancellationToken);
}
finally
{
_callerTracker.Release();
}
}
}
[Fact]
public async Task HttpContextFeatureForLongpollingWorksBetweenPolls()
{