From aec52670b43a9340ed739a0748a07dc985c5f185 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Fri, 13 Jan 2017 11:09:20 -0800 Subject: [PATCH] React to Channel API changes --- .../IChannelConnection.cs | 4 +- .../Internal/ChannelConnection.cs | 22 ++-- .../Transports/LongPollingTransport.cs | 4 +- .../Transports/ServerSentEventsTransport.cs | 4 +- .../HubEndpointTests.cs | 106 ++---------------- .../LongPollingTests.cs | 6 +- .../ServerSentEventsTests.cs | 6 +- .../WebSocketsTests.cs | 26 ++--- 8 files changed, 50 insertions(+), 128 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets/IChannelConnection.cs b/src/Microsoft.AspNetCore.Sockets/IChannelConnection.cs index bf65bb9afe..3e918d5898 100644 --- a/src/Microsoft.AspNetCore.Sockets/IChannelConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets/IChannelConnection.cs @@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.Sockets // access to two separate channels, the read end for one and the write end for the other. public interface IChannelConnection : IDisposable { - IReadableChannel Input { get; } - IWritableChannel Output { get; } + ReadableChannel Input { get; } + WritableChannel Output { get; } } } diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs index 07eb6260ff..35f9a126d9 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs @@ -6,15 +6,23 @@ using System.Threading.Tasks.Channels; namespace Microsoft.AspNetCore.Sockets.Internal { + public static class ChannelConnection + { + public static ChannelConnection Create(Channel input, Channel output) + { + return new ChannelConnection(input, output); + } + } + public class ChannelConnection : IChannelConnection { - public IChannel Input { get; } - public IChannel Output { get; } + public Channel Input { get; } + public Channel Output { get; } - IReadableChannel IChannelConnection.Input => Input; - IWritableChannel IChannelConnection.Output => Output; + ReadableChannel IChannelConnection.Input => Input; + WritableChannel IChannelConnection.Output => Output; - public ChannelConnection(IChannel input, IChannel output) + public ChannelConnection(Channel input, Channel output) { Input = input; Output = output; @@ -22,8 +30,8 @@ namespace Microsoft.AspNetCore.Sockets.Internal public void Dispose() { - Output.TryComplete(); - Input.TryComplete(); + Output.Out.TryComplete(); + Input.Out.TryComplete(); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs index da9875138f..379c9b3aee 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs @@ -14,10 +14,10 @@ namespace Microsoft.AspNetCore.Sockets.Transports public class LongPollingTransport : IHttpTransport { public static readonly string Name = "longPolling"; - private readonly IReadableChannel _application; + private readonly ReadableChannel _application; private readonly ILogger _logger; - public LongPollingTransport(IReadableChannel application, ILoggerFactory loggerFactory) + public LongPollingTransport(ReadableChannel application, ILoggerFactory loggerFactory) { _application = application; _logger = loggerFactory.CreateLogger(); diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs index 221cb704c8..ec34631695 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs @@ -12,10 +12,10 @@ namespace Microsoft.AspNetCore.Sockets.Transports public class ServerSentEventsTransport : IHttpTransport { public static readonly string Name = "serverSentEvents"; - private readonly IReadableChannel _application; + private readonly ReadableChannel _application; private readonly ILogger _logger; - public ServerSentEventsTransport(IReadableChannel application, ILoggerFactory loggerFactory) + public ServerSentEventsTransport(ReadableChannel application, ILoggerFactory loggerFactory) { _application = application; _logger = loggerFactory.CreateLogger(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 084fa15a8c..a29bf360d5 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -4,7 +4,6 @@ using System; using System.IO; using System.IO.Pipelines; -using System.Runtime.CompilerServices; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; @@ -32,8 +31,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - await connectionWrapper.ApplicationStartedReading; - // kill the connection connectionWrapper.Dispose(); @@ -136,8 +133,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - await connectionWrapper.ApplicationStartedReading; - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -165,8 +160,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - await connectionWrapper.ApplicationStartedReading; - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -194,8 +187,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - await connectionWrapper.ApplicationStartedReading; - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -222,8 +213,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - await connectionWrapper.ApplicationStartedReading; - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -250,8 +239,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - await connectionWrapper.ApplicationStartedReading; - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -277,8 +264,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - await connectionWrapper.ApplicationStartedReading; - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -302,8 +287,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); - await Task.WhenAll(firstConnection.ApplicationStartedReading, secondConnection.ApplicationStartedReading); - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -339,22 +322,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); - await Task.WhenAll(firstConnection.ApplicationStartedReading, secondConnection.ApplicationStartedReading); - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); await SendRequest_IgnoreReceive(firstConnection, adapter, "GroupSendMethod", "testGroup", "test"); // check that 'secondConnection' hasn't received the group send Message message; - Assert.False(secondConnection.Transport.Output.TryRead(out message)); + Assert.False(secondConnection.Application.Input.TryRead(out message)); await SendRequest_IgnoreReceive(secondConnection, adapter, "GroupAddMethod", "testGroup"); await SendRequest(firstConnection, adapter, "GroupSendMethod", "testGroup", "test"); // check that 'firstConnection' hasn't received the group send - Assert.False(firstConnection.Transport.Output.TryRead(out message)); + Assert.False(firstConnection.Application.Input.TryRead(out message)); // check that 'secondConnection' has received the group send var res = await ReadConnectionOutputAsync(secondConnection); @@ -381,8 +362,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var endPointTask = endPoint.OnConnectedAsync(connection.Connection); - await connection.ApplicationStartedReading; - var invocationAdapter = serviceProvider.GetService(); var writer = invocationAdapter.GetInvocationAdapter("json"); @@ -408,8 +387,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); - await Task.WhenAll(firstConnection.ApplicationStartedReading, secondConnection.ApplicationStartedReading); - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -442,8 +419,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests var firstEndPointTask = endPoint.OnConnectedAsync(firstConnection.Connection); var secondEndPointTask = endPoint.OnConnectedAsync(secondConnection.Connection); - await Task.WhenAll(firstConnection.ApplicationStartedReading, secondConnection.ApplicationStartedReading); - var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); @@ -490,7 +465,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests stream); var buffer = ReadableBuffer.Create(stream.ToArray()).Preserve(); - await connection.Transport.Input.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true)); + await connection.Application.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true)); } public async Task SendRequest_IgnoreReceive(ConnectionWrapper connection, IInvocationAdapter writer, string method, params object[] args) @@ -498,13 +473,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests await SendRequest(connection, writer, method, args); // Consume the result - await connection.Transport.Output.ReadAsync(); + await connection.Application.Input.ReadAsync(); } private async Task ReadConnectionOutputAsync(ConnectionWrapper connection) { // TODO: other formats? - var message = await connection.Transport.Output.ReadAsync(); + var message = await connection.Application.Input.ReadAsync(); var serializer = new JsonSerializer(); return serializer.Deserialize(new JsonTextReader(new StreamReader(new MemoryStream(message.Payload.Buffer.ToArray())))); } @@ -629,24 +604,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests public class ConnectionWrapper : IDisposable { private static int _id; - private readonly TestChannel _input; - + public Connection Connection { get; } - public ChannelConnection Transport { get; } - - public Task ApplicationStartedReading => _input.ReadingStarted; + public IChannelConnection Application { get; } public ConnectionWrapper(string format = "json") { var transportToApplication = Channel.CreateUnbounded(); var applicationToTransport = Channel.CreateUnbounded(); - _input = new TestChannel(transportToApplication); + Application = ChannelConnection.Create(input: applicationToTransport, output: transportToApplication); + var transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); - Transport = new ChannelConnection(_input, applicationToTransport); - - Connection = new Connection(Guid.NewGuid().ToString(), Transport); + Connection = new Connection(Guid.NewGuid().ToString(), transport); Connection.Metadata["formatType"] = format; Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) })); } @@ -655,63 +626,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests { Connection.Dispose(); } - - private class TestChannel : IChannel - { - private IChannel _channel; - private TaskCompletionSource _tcs = new TaskCompletionSource(); - - public TestChannel(IChannel channel) - { - _channel = channel; - } - - public Task Completion => _channel.Completion; - - public Task ReadingStarted => _tcs.Task; - - public ValueAwaiter GetAwaiter() - { - return _channel.GetAwaiter(); - } - - public ValueTask ReadAsync(CancellationToken cancellationToken = default(CancellationToken)) - { - _tcs.TrySetResult(null); - return _channel.ReadAsync(cancellationToken); - } - - public bool TryComplete(Exception error = null) - { - return _channel.TryComplete(error); - } - - public bool TryRead(out T item) - { - return _channel.TryRead(out item); - } - - public bool TryWrite(T item) - { - return _channel.TryWrite(item); - } - - public Task WaitToReadAsync(CancellationToken cancellationToken = default(CancellationToken)) - { - _tcs.TrySetResult(null); - return _channel.WaitToReadAsync(cancellationToken); - } - - public Task WaitToWriteAsync(CancellationToken cancellationToken = default(CancellationToken)) - { - return _channel.WaitToWriteAsync(cancellationToken); - } - - public Task WriteAsync(T item, CancellationToken cancellationToken = default(CancellationToken)) - { - return _channel.WriteAsync(item, cancellationToken); - } - } } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs index 145744a937..f85cc529f8 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs @@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var poll = new LongPollingTransport(channel, new LoggerFactory()); - Assert.True(channel.TryComplete()); + Assert.True(channel.Out.TryComplete()); await poll.ProcessRequestAsync(context); @@ -41,12 +41,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests var ms = new MemoryStream(); context.Response.Body = ms; - await channel.WriteAsync(new Message( + await channel.Out.WriteAsync(new Message( ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(), Format.Text, endOfMessage: true)); - Assert.True(channel.TryComplete()); + Assert.True(channel.Out.TryComplete()); await poll.ProcessRequestAsync(context); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs index 8be741d912..04c1a67245 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs @@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var context = new DefaultHttpContext(); var sse = new ServerSentEventsTransport(channel, new LoggerFactory()); - Assert.True(channel.TryComplete()); + Assert.True(channel.Out.TryComplete()); await sse.ProcessRequestAsync(context); @@ -39,12 +39,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests var ms = new MemoryStream(); context.Response.Body = ms; - await channel.WriteAsync(new Message( + await channel.Out.WriteAsync(new Message( ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello World")).Preserve(), Format.Text, endOfMessage: true)); - Assert.True(channel.TryComplete()); + Assert.True(channel.Out.TryComplete()); await sse.ProcessRequestAsync(context); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index eca7c3c9d3..8e29d113f5 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -46,14 +46,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello")))); await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); - using (var message = await applicationSide.Input.ReadAsync()) + using (var message = await applicationSide.Input.In.ReadAsync()) { Assert.True(message.EndOfMessage); Assert.Equal(format, message.MessageFormat); Assert.Equal("Hello", Encoding.UTF8.GetString(message.Payload.Buffer.ToArray())); } - Assert.True(applicationSide.Output.TryComplete()); + Assert.True(applicationSide.Output.Out.TryComplete()); // The transport should finish now await transport; @@ -98,21 +98,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("World")))); await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); - using (var message1 = await applicationSide.Input.ReadAsync()) + using (var message1 = await applicationSide.Input.In.ReadAsync()) { Assert.False(message1.EndOfMessage); Assert.Equal(format, message1.MessageFormat); Assert.Equal("Hello", Encoding.UTF8.GetString(message1.Payload.Buffer.ToArray())); } - using (var message2 = await applicationSide.Input.ReadAsync()) + using (var message2 = await applicationSide.Input.In.ReadAsync()) { Assert.True(message2.EndOfMessage); Assert.Equal(format, message2.MessageFormat); Assert.Equal("World", Encoding.UTF8.GetString(message2.Payload.Buffer.ToArray())); } - Assert.True(applicationSide.Output.TryComplete()); + Assert.True(applicationSide.Output.Out.TryComplete()); // The transport should finish now await transport; @@ -147,15 +147,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); // Write multi-frame message to the output channel, and then complete it - await applicationSide.Output.WriteAsync(new Message( + await applicationSide.Output.Out.WriteAsync(new Message( ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello")).Preserve(), format, endOfMessage: false)); - await applicationSide.Output.WriteAsync(new Message( + await applicationSide.Output.Out.WriteAsync(new Message( ReadableBuffer.Create(Encoding.UTF8.GetBytes("World")).Preserve(), format, endOfMessage: true)); - Assert.True(applicationSide.Output.TryComplete()); + Assert.True(applicationSide.Output.Out.TryComplete()); // The client should finish now, as should the server var clientSummary = await client; @@ -195,11 +195,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); // Write to the output channel, and then complete it - await applicationSide.Output.WriteAsync(new Message( + await applicationSide.Output.Out.WriteAsync(new Message( ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello")).Preserve(), format, endOfMessage: true)); - Assert.True(applicationSide.Output.TryComplete()); + Assert.True(applicationSide.Output.Out.TryComplete()); // The client should finish now, as should the server var clientSummary = await client; @@ -236,7 +236,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); // Close the output and wait for the close frame - Assert.True(applicationSide.Output.TryComplete()); + Assert.True(applicationSide.Output.Out.TryComplete()); await client; // Send another frame. Then close @@ -247,7 +247,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); // Read that frame from the input - using (var message = await applicationSide.Input.ReadAsync()) + using (var message = await applicationSide.Input.In.ReadAsync()) { Assert.True(message.EndOfMessage); Assert.Equal(format, message.MessageFormat); @@ -307,7 +307,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); // Fail in the app - Assert.True(applicationSide.Output.TryComplete(new InvalidOperationException())); + Assert.True(applicationSide.Output.Out.TryComplete(new InvalidOperationException())); var clientSummary = await client; Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.Status);