Fix issue where multiple calls to dispose don't wait properly (#360)

* Fix issue where multiple calls to dispose don't wait properly
- DisposeAsync returned immediately to anyone but the first caller.
This means that it was possible to end the request before properly
waiting on the transport task which means writing after dispose was possible.
- Added a test
This commit is contained in:
David Fowler 2017-04-03 15:25:45 -07:00 committed by GitHub
parent f6f0007c12
commit 8da2dddd49
3 changed files with 69 additions and 25 deletions

View File

@ -10,6 +10,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal
{
public class ConnectionState
{
// This tcs exists so that multiple calls to DisposeAsync all wait asynchronously
// on the same task
private TaskCompletionSource<object> _disposeTcs = new TaskCompletionSource<object>();
public Connection Connection { get; set; }
public IChannelConnection<Message> Application { get; }
@ -34,8 +38,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal
public async Task DisposeAsync()
{
Task applicationTask = TaskCache.CompletedTask;
Task transportTask = TaskCache.CompletedTask;
Task disposeTask = TaskCache.CompletedTask;
try
{
@ -43,30 +46,34 @@ namespace Microsoft.AspNetCore.Sockets.Internal
if (Status == ConnectionStatus.Disposed)
{
return;
disposeTask = _disposeTcs.Task;
}
Status = ConnectionStatus.Disposed;
RequestId = null;
// If the application task is faulted, propagate the error to the transport
if (ApplicationTask?.IsFaulted == true)
else
{
Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException);
Status = ConnectionStatus.Disposed;
RequestId = null;
// If the application task is faulted, propagate the error to the transport
if (ApplicationTask?.IsFaulted == true)
{
Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException);
}
// If the transport task is faulted, propagate the error to the application
if (TransportTask?.IsFaulted == true)
{
Application.Output.TryComplete(TransportTask.Exception.InnerException);
}
Connection.Dispose();
Application.Dispose();
var applicationTask = ApplicationTask ?? TaskCache.CompletedTask;
var transportTask = TransportTask ?? TaskCache.CompletedTask;
disposeTask = Task.WhenAll(applicationTask, transportTask);
}
// If the transport task is faulted, propagate the error to the application
if (TransportTask?.IsFaulted == true)
{
Application.Output.TryComplete(TransportTask.Exception.InnerException);
}
Connection.Dispose();
Application.Dispose();
applicationTask = ApplicationTask ?? applicationTask;
transportTask = TransportTask ?? transportTask;
}
finally
{
@ -74,7 +81,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal
}
// REVIEW: Add a timeout so we don't wait forever
await Task.WhenAll(applicationTask, transportTask);
await disposeTask;
// Notify all waiters that we're done disposing
_disposeTcs.TrySetResult(null);
}
public enum ConnectionStatus

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
using Xunit;
@ -100,6 +101,26 @@ namespace Microsoft.AspNetCore.Sockets.Tests
await state.DisposeAsync();
}
[Fact]
public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose()
{
var connectionManager = CreateConnectionManager();
var state = connectionManager.CreateConnection();
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
state.ApplicationTask = tcs.Task;
state.TransportTask = tcs.Task;
var firstTask = state.DisposeAsync();
var secondTask = state.DisposeAsync();
Assert.False(firstTask.IsCompleted);
Assert.False(secondTask.IsCompleted);
tcs.TrySetResult(null);
await Task.WhenAll(firstTask, secondTask).OrTimeout();
}
[Fact]
public async Task DisposeInactiveConnection()
{

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Internal;
@ -14,6 +15,7 @@ using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using Microsoft.Extensions.WebSockets.Internal;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
@ -234,9 +236,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode);
var webSocketTask = Task.CompletedTask;
if (isWebSocketRequest)
{
var ws = (TestWebSocketConnectionFeature)context1.Features.Get<IHttpWebSocketConnectionFeature>();
webSocketTask = ws.Client.ExecuteAsync(frame => Task.CompletedTask);
await ws.Client.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure), CancellationToken.None);
}
manager.CloseConnections();
await request1;
await webSocketTask.OrTimeout();
await request1.OrTimeout();
}
[Fact]