Changing the Closed Event to be a Task (#1080)

This commit is contained in:
Mikael Mengistu 2017-11-09 17:51:13 -08:00 committed by GitHub
parent 06475270ec
commit 1a21fd49b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 80 additions and 180 deletions

View File

@ -36,31 +36,23 @@ namespace ClientSample
try
{
var sendCts = new CancellationTokenSource();
var cts = new CancellationTokenSource();
Console.CancelKeyPress += (sender, a) =>
Console.CancelKeyPress += async (sender, a) =>
{
a.Cancel = true;
Console.WriteLine("Stopping loops...");
cts.Cancel();
sendCts.Cancel();
await connection.DisposeAsync();
};
// Set up handler
connection.On<string>("Send", Console.WriteLine);
connection.Closed += e =>
while (!connection.Closed.IsCompleted)
{
Console.WriteLine("Connection closed.");
cts.Cancel();
return Task.CompletedTask;
};
var ctsTask = Task.Delay(-1, cts.Token);
while (!cts.Token.IsCancellationRequested)
{
var completedTask = await Task.WhenAny(Task.Run(() => Console.ReadLine(), cts.Token), ctsTask);
if (completedTask == ctsTask)
var completedTask = await Task.WhenAny(Task.Run(() => Console.ReadLine()), connection.Closed);
if (completedTask == connection.Closed)
{
break;
}
@ -72,7 +64,7 @@ namespace ClientSample
break;
}
await connection.InvokeAsync<object>("Send", line, cts.Token);
await connection.InvokeAsync<object>("Send", line, sendCts.Token);
}
}
catch (AggregateException aex) when (aex.InnerExceptions.All(e => e is OperationCanceledException))

View File

@ -39,25 +39,19 @@ namespace ClientSample
var connection = new HttpConnection(new Uri(baseUrl), loggerFactory);
try
{
var cts = new CancellationTokenSource();
connection.OnReceived(data => Console.Out.WriteLineAsync($"{Encoding.UTF8.GetString(data)}"));
connection.Closed += e =>
{
cts.Cancel();
return Task.CompletedTask;
};
await connection.StartAsync();
Console.WriteLine($"Connected to {baseUrl}");
Console.CancelKeyPress += (sender, a) =>
var cts = new CancellationTokenSource();
Console.CancelKeyPress += async (sender, a) =>
{
a.Cancel = true;
cts.Cancel();
await connection.DisposeAsync();
};
while (!cts.Token.IsCancellationRequested)
while (!connection.Closed.IsCompleted)
{
var line = await Task.Run(() => Console.ReadLine(), cts.Token);

View File

@ -39,11 +39,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
private int _nextId = 0;
private volatile bool _startCalled;
public event Func<Exception, Task> Closed
{
add { _connection.Closed += value; }
remove { _connection.Closed -= value; }
}
public Task Closed { get; }
public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory)
{
@ -63,7 +59,11 @@ namespace Microsoft.AspNetCore.SignalR.Client
_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
_logger = _loggerFactory.CreateLogger<HubConnection>();
_connection.OnReceived((data, state) => ((HubConnection)state).OnDataReceivedAsync(data), this);
_connection.Closed += Shutdown;
Closed = _connection.Closed.ContinueWith(task =>
{
Shutdown(task.Exception);
return task;
}).Unwrap();
}
public async Task StartAsync()

View File

@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
IDisposable OnReceived(Func<byte[], object, Task> callback, object state);
event Func<Exception, Task> Closed;
Task Closed { get; }
IFeatureCollection Features { get; }
}

View File

@ -34,6 +34,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
private volatile ITransport _transport;
private volatile Task _receiveLoopTask;
private TaskCompletionSource<object> _startTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource<object> _closedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
private TaskQueue _eventQueue = new TaskQueue();
private readonly ITransportFactory _transportFactory;
private string _connectionId;
@ -47,7 +48,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
public IFeatureCollection Features { get; } = new FeatureCollection();
public event Func<Exception, Task> Closed;
public Task Closed => _closedTcs.Task;
public HttpConnection(Uri url)
: this(url, TransportType.All)
@ -116,10 +117,12 @@ namespace Microsoft.AspNetCore.Sockets.Client
if (t.IsFaulted)
{
_startTcs.SetException(t.Exception.InnerException);
_closedTcs.TrySetException(t.Exception.InnerException);
}
else if (t.IsCanceled)
{
_startTcs.SetCanceled();
_closedTcs.SetCanceled();
}
else
{
@ -190,19 +193,18 @@ namespace Microsoft.AspNetCore.Sockets.Client
await Task.WhenAny(_eventQueue.Drain().NoThrow(), Task.Delay(_eventQueueDrainTimeout));
_httpClient?.Dispose();
_logger.RaiseClosed(_connectionId);
var closedEventHandler = Closed;
if (closedEventHandler != null)
_logger.CompleteClosed(_connectionId);
if (t.IsFaulted)
{
try
{
await closedEventHandler.Invoke(t.IsFaulted ? t.Exception.InnerException : null);
}
catch (Exception ex)
{
_logger.ExceptionThrownFromCallback(_connectionId, nameof(Closed), ex);
}
_closedTcs.TrySetException(t.Exception.InnerException);
}
if (t.IsCanceled)
{
_closedTcs.TrySetCanceled();
}
else
{
_closedTcs.TrySetResult(null);
}
});
@ -463,6 +465,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
await _receiveLoopTask;
}
_closedTcs.TrySetResult(null);
_httpClient?.Dispose();
}

View File

@ -111,8 +111,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal
private static readonly Action<ILogger, DateTime, string, Exception> _drainEvents =
LoggerMessage.Define<DateTime, string>(LogLevel.Debug, new EventId(5, nameof(DrainEvents)), "{time}: Connection Id {connectionId}: Draining event queue.");
private static readonly Action<ILogger, DateTime, string, Exception> _raiseClosed =
LoggerMessage.Define<DateTime, string>(LogLevel.Debug, new EventId(6, nameof(RaiseClosed)), "{time}: Connection Id {connectionId}: Raising Closed event.");
private static readonly Action<ILogger, DateTime, string, Exception> _completeClosed =
LoggerMessage.Define<DateTime, string>(LogLevel.Debug, new EventId(6, nameof(CompleteClosed)), "{time}: Connection Id {connectionId}: Completing Closed task.");
private static readonly Action<ILogger, DateTime, Uri, Exception> _establishingConnection =
LoggerMessage.Define<DateTime, Uri>(LogLevel.Debug, new EventId(7, nameof(EstablishingConnection)), "{time}: Establishing Connection at: {url}.");
@ -410,11 +410,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal
}
}
public static void RaiseClosed(this ILogger logger, string connectionId)
public static void CompleteClosed(this ILogger logger, string connectionId)
{
if (logger.IsEnabled(LogLevel.Debug))
{
_raiseClosed(logger, DateTime.Now, connectionId, null);
_completeClosed(logger, DateTime.Now, connectionId, null);
}
}

View File

@ -174,22 +174,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
try
{
await connection.StartAsync().OrTimeout();
var closeTcs = new TaskCompletionSource<object>();
connection.Closed += ex =>
{
if (ex != null)
{
closeTcs.SetException(ex);
}
else
{
closeTcs.SetResult(null);
}
return Task.CompletedTask;
};
await connection.InvokeAsync("CallHandlerThatDoesntExist").OrTimeout();
await connection.DisposeAsync().OrTimeout();
await closeTcs.Task.OrTimeout();
await connection.Closed.OrTimeout();
}
finally
{

View File

@ -204,18 +204,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
var closedEventTcs = new TaskCompletionSource<Exception>();
connection.Closed += e =>
{
closedEventTcs.SetResult(e);
return Task.CompletedTask;
};
await connection.StartAsync();
await connection.DisposeAsync();
await connection.StartAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
await connection.Closed.OrTimeout();
// in case of clean disconnect error should be null
Assert.Null(await closedEventTcs.Task.OrTimeout());
}
[Fact]
@ -236,17 +229,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
});
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
var closedEventTcs = new TaskCompletionSource<Exception>();
connection.Closed += e =>
{
closedEventTcs.TrySetResult(e);
return Task.CompletedTask;
};
try
{
await connection.StartAsync();
Assert.IsType<HttpRequestException>(await closedEventTcs.Task.OrTimeout());
await connection.StartAsync().OrTimeout();
await Assert.ThrowsAsync<HttpRequestException>(() => connection.Closed.OrTimeout());
}
finally
{
@ -388,16 +375,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text);
var blockReceiveCallbackTcs = new TaskCompletionSource<object>();
var closedTcs = new TaskCompletionSource<object>();
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
connection.OnReceived(_ => blockReceiveCallbackTcs.Task);
connection.Closed += _ => {
closedTcs.SetResult(null);
return Task.CompletedTask;
};
await connection.StartAsync();
channel.Out.TryWrite(Array.Empty<byte>());
@ -438,7 +419,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text);
var callbackInvokedTcs = new TaskCompletionSource<object>();
var closedTcs = new TaskCompletionSource<object>();
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
connection.OnReceived( _ =>
@ -460,15 +440,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
{
var connection = new HttpConnection(new Uri("http://fakeuri.org/"));
bool closedEventRaised = false;
connection.Closed += e =>
{
closedEventRaised = true;
return Task.CompletedTask;
};
await connection.DisposeAsync();
Assert.False(closedEventRaised);
Assert.False(connection.Closed.IsCompleted);
}
[Fact]
@ -644,21 +617,20 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
return Task.CompletedTask;
}, receiveTcs);
connection.Closed += e =>
_ = connection.Closed.ContinueWith(task =>
{
if (task.Exception != null)
{
if (e != null)
{
receiveTcs.TrySetException(e);
}
else
{
receiveTcs.TrySetCanceled();
}
return Task.CompletedTask;
};
await connection.StartAsync();
receiveTcs.TrySetException(task.Exception);
}
else
{
receiveTcs.TrySetCanceled();
}
return Task.CompletedTask;
});
await connection.StartAsync().OrTimeout();
Assert.Equal("42", await receiveTcs.Task.OrTimeout());
}
finally
@ -707,18 +679,18 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
return Task.CompletedTask;
});
connection.Closed += e =>
_ = connection.Closed.ContinueWith(task =>
{
if (e != null)
if (task.Exception != null)
{
receiveTcs.TrySetException(e);
receiveTcs.TrySetException(task.Exception);
}
else
{
receiveTcs.TrySetCanceled();
}
return Task.CompletedTask;
};
});
await connection.StartAsync();
@ -770,18 +742,18 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
return Task.CompletedTask;
});
connection.Closed += e =>
_ = connection.Closed.ContinueWith(task =>
{
if (e != null)
if (task.Exception != null)
{
receiveTcs.TrySetException(e);
receiveTcs.TrySetException(task.Exception);
}
else
{
receiveTcs.TrySetCanceled();
}
return Task.CompletedTask;
};
});
await connection.StartAsync();
@ -813,20 +785,13 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
try
{
var closeTcs = new TaskCompletionSource<Exception>();
connection.Closed += e =>
{
closeTcs.TrySetResult(e);
return Task.CompletedTask;
};
await connection.StartAsync();
await connection.StartAsync().OrTimeout();
// Exception in send should shutdown the connection
await closeTcs.Task.OrTimeout();
await Assert.ThrowsAsync<HttpRequestException>(() => connection.Closed.OrTimeout());
var exception = await Assert.ThrowsAsync<InvalidOperationException>(
async () => await connection.SendAsync(new byte[0]));
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => connection.SendAsync(new byte[0]));
Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message);
}

View File

@ -72,16 +72,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var hubConnection = new HubConnection(new TestConnection(), Mock.Of<IHubProtocol>(), null);
var closedEventTcs = new TaskCompletionSource<Exception>();
hubConnection.Closed += e =>
{
closedEventTcs.SetResult(e);
return Task.CompletedTask;
};
await hubConnection.StartAsync();
await hubConnection.DisposeAsync();
Assert.Null(await closedEventTcs.Task.OrTimeout());
await hubConnection.StartAsync().OrTimeout();
await hubConnection.DisposeAsync().OrTimeout();
await hubConnection.Closed.OrTimeout();
}
[Fact]
@ -177,12 +171,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
[Fact]
public async Task PendingInvocationsAreTerminatedWithExceptionWhenConnectionClosesDueToError()
{
var exception = new InvalidOperationException();
var mockConnection = new Mock<IConnection>();
mockConnection.SetupGet(p => p.Features).Returns(new FeatureCollection());
mockConnection
.Setup(m => m.DisposeAsync())
.Callback(() => mockConnection.Raise(c => c.Closed += null, exception))
.Returns(Task.FromResult<object>(null));
var hubConnection = new HubConnection(mockConnection.Object, Mock.Of<IHubProtocol>(), new LoggerFactory());
@ -191,8 +183,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var invokeTask = hubConnection.InvokeAsync<int>("testMethod");
await hubConnection.DisposeAsync();
var thrown = await Assert.ThrowsAsync(exception.GetType(), async () => await invokeTask);
Assert.Same(exception, thrown);
await Assert.ThrowsAsync<InvalidOperationException>(async () => await invokeTask);
}
// Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse

View File

@ -29,9 +29,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
private Task _receiveLoop;
private TransferMode? _transferMode;
private readonly TaskCompletionSource<object> _closeTcs = new TaskCompletionSource<object>();
public event Func<Exception, Task> Closed;
public Task Closed => _closeTcs.Task;
public Task Started => _started.Task;
public Task Disposed => _disposed.Task;
public ReadableChannel<byte[]> SentMessages => _sentMessages.In;
@ -133,16 +133,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
}
}
Closed?.Invoke(null);
_closeTcs.TrySetResult(null);
}
catch (OperationCanceledException)
{
// Do nothing, we were just asked to shut down.
Closed?.Invoke(null);
_closeTcs.TrySetResult(null);
}
catch (Exception ex)
{
Closed?.Invoke(ex);
_closeTcs.TrySetException(ex);
}
}

View File

@ -165,7 +165,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
try
{
var receiveTcs = new TaskCompletionSource<string>();
var closeTcs = new TaskCompletionSource<object>();
connection.OnReceived((data, state) =>
{
logger.LogInformation("Received {length} byte message", data.Length);
@ -179,22 +178,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return Task.CompletedTask;
}, receiveTcs);
connection.Closed += e =>
{
logger.LogInformation("Connection closed");
if (e != null)
{
receiveTcs.TrySetException(e);
closeTcs.TrySetException(e);
}
else
{
receiveTcs.TrySetResult(null);
closeTcs.TrySetResult(null);
}
return Task.CompletedTask;
};
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
logger.LogInformation("Started connection to {url}", url);
@ -225,8 +208,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
logger.LogInformation("Receiving message");
Assert.Equal(message, await receiveTcs.Task.OrTimeout());
logger.LogInformation("Completed receive");
await closeTcs.Task.OrTimeout();
await connection.Closed.OrTimeout();
}
catch (Exception ex)
{
@ -344,24 +326,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var closeTcs = new TaskCompletionSource<object>();
connection.Closed += e =>
{
logger.LogInformation("Connection closed");
if (e != null)
{
closeTcs.TrySetException(e);
}
else
{
closeTcs.TrySetResult(null);
}
return Task.CompletedTask;
};
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
await closeTcs.Task.OrTimeout();
await connection.Closed.OrTimeout();
}
catch (Exception ex)
{