diff --git a/samples/ClientSample/HubSample.cs b/samples/ClientSample/HubSample.cs index 1dd478122e..5f5697915d 100644 --- a/samples/ClientSample/HubSample.cs +++ b/samples/ClientSample/HubSample.cs @@ -72,7 +72,7 @@ namespace ClientSample break; } - await connection.InvokeAsync("Send", cts.Token, line); + await connection.InvokeAsync("Send", line, cts.Token); } } catch (AggregateException aex) when (aex.InnerExceptions.All(e => e is OperationCanceledException)) diff --git a/specs/HubProtocol.md b/specs/HubProtocol.md index 811c2d5610..8cf20a9a6d 100644 --- a/specs/HubProtocol.md +++ b/specs/HubProtocol.md @@ -20,10 +20,11 @@ This document describes three encodings of the SignalR protocol: [JSON](http://w In the SignalR protocol, the following types of messages can be sent: -* `Negotiation` Message - Sent by the client to negotiate the message format +* `Negotiation` Message - Sent by the client to negotiate the message format. * `Invocation` Message - Indicates a request to invoke a particular method (the Target) with provided Arguments on the remote endpoint. * `StreamItem` Message - Indicates individual items of streamed response data from a previous Invocation message. * `Completion` Message - Indicates a previous Invocation has completed, and no further `StreamItem` messages will be received. Contains an error if the invocation concluded with an error, or the result if the invocation is not a streaming invocation. +* `CancelInvocation` Message - Sent by the client to cancel a streaming invocation on the server. After opening a connection to the server the client must send a `Negotiation` message to the server as its first message. The negotiation message is **always** a JSON message and contains the name of the format (protocol) that will be used for the duration of the connection. If the server does not support the protocol requested by the client or the first message received from the client is not a `Negotiation` message the server must close the connection. @@ -41,9 +42,9 @@ Example: ## Communication between the Caller and the Callee -There a three kinds of interactions between the Caller and the Calle: +There are three kinds of interactions between the Caller and the Callee: -* Invocations - the Caller sends a message to the Calle and expects a message indicating that the invocation has been completed and optionally a result of the invocation +* Invocations - the Caller sends a message to the Callee and expects a message indicating that the invocation has been completed and optionally a result of the invocation * Non-Blocking Invocations - the Caller sends a message to the Callee and does not expect any further messages for this invocation * Streaming Invocations - the Caller sends a message to the Callee and expects one or more results returned by the Callee followed by a message indicating the end of invocation @@ -74,7 +75,7 @@ The SignalR protocol allows for multiple `StreamItem` messages to be transmitted On the Callee side, it is up to the Callee's Binder to determine if a method call will yield multiple results. For example, in .NET certain return types may indicate multiple results, while others may indicate a single result. Even then, applications may wish for multiple results to be buffered and returned in a single `Completion` frame. It is up to the Binder to decide how to map this. The Callee's Binder must encode each result in separate `StreamItem` messages, indicating the end of results by sending a `Completion` message. -On the Caller side, the user code which performs the invocation indicates how it would like to receive the results and it is up the Caller's Binder to determine how to handle the result. If the Caller expects only a single result, but multiple results are returned, the Caller's Binder should yield an error indicating that multiple results were returned. However, if a Caller expects multiple results, but only a single result is returned, the Caller's Binder should yield that single result and indicate there are no further results. +On the Caller side, the user code which performs the invocation indicates how it would like to receive the results and it is up the Caller's Binder to determine how to handle the result. If the Caller expects only a single result, but multiple results are returned, the Caller's Binder should yield an error indicating that multiple results were returned. However, if a Caller expects multiple results, but only a single result is returned, the Caller's Binder should yield that single result and indicate there are no further results. If the Caller wants to stop receiving `StreamItem` messages before the Callee sends a `Completion` message, the Caller can send a `CancelInvocation` message with the same `Invocation ID` used for the `Invocation` message that started the stream. It is possible to receive `StreamItem` messages or a `Completion` message after a `CancelInvocation` message has been sent, these can be ignored. ## Completion and results @@ -223,6 +224,16 @@ S->C: Completion { Id = 42, Error = "Ran out of data!" } This should manifest to the Calling code as a sequence which emits `0`, `1`, `2`, `3`, `4`, but then fails with the error `Ran out of data!`. +### Streamed Result closed early (`Stream` example above) + +``` +C->S: Invocation { Id = 42, Target = "Stream", Arguments = [ 5 ] } +S->C: StreamItem { Id = 42, Item = 0 } +S->C: StreamItem { Id = 42, Item = 1 } +C->S: CancelInvocation { Id = 42 } +S->C: StreamItem { Id = 42, Item = 2} // This can be ignored +``` + ### Non-Blocking Call (`NonBlocking` example above) ``` @@ -248,7 +259,7 @@ Example: ```json { "type": 1, - "invocationId": 123, + "invocationId": "123", "target": "Send", "arguments": [ 42, @@ -261,7 +272,7 @@ Example (Non-Blocking): ```json { "type": 1, - "invocationId": 123, + "invocationId": "123", "nonblocking": true, "target": "Send", "arguments": [ @@ -284,7 +295,7 @@ Example ```json { "type": 2, - "invocationId": 123, + "invocationId": "123", "item": 42 } ``` @@ -305,7 +316,7 @@ Example - A `Completion` message with no result or error ```json { "type": 3, - "invocationId": 123 + "invocationId": "123" } ``` @@ -314,7 +325,7 @@ Example - A `Completion` message with a result ```json { "type": 3, - "invocationId": 123, + "invocationId": "123", "result": 42 } ``` @@ -324,7 +335,7 @@ Example - A `Completion` message with an error ```json { "type": 3, - "invocationId": 123, + "invocationId": "123", "error": "It didn't work!" } ``` @@ -334,12 +345,26 @@ Example - The following `Completion` message is a protocol error because it has ```json { "type": 3, - "invocationId": 123, + "invocationId": "123", "result": 42, "error": "It didn't work!" } ``` +### CancelInvocation Message Encoding +A `CancelInvocation` message is a JSON object with the following properties + +* `type` - A `Number` with the literal value `5`, indicationg that this is a `CancelInvocation`. +* `invocationId` - A `String` encoding the `Invocation ID` for a message. + +Example +```json +{ + "type": 5, + "invocationId": "123" +} +``` + ### JSON Payload Encoding Items in the arguments array within the `Invocation` message type, as well as the `item` value of the `StreamItem` message and the `result` value of the `Completion` message, encode values which have meaning to each particular Binder. A general guideline for encoding/decoding these values is provided in the "Type Mapping" section at the end of this document, but Binders should provide configuration to applications to allow them to customize these mappings. These mappings need not be self-describing, because when decoding the value, the Binder is expected to know the destination type (by looking up the definition of the method indicated by the Target). @@ -378,7 +403,7 @@ is decoded as follows: * `0x95` - 5-element array * `0x01` - `1` (Message Type - `Invocation` message) -* `0xa3` - string of length 3 (Target) +* `0xa3` - string of length 3 (InvocationId) * `0x78` - `x` * `0x79` - `y` * `0x7a` - `z` @@ -397,7 +422,9 @@ is decoded as follows: `StreamItem` messages have the following structure: +``` [2, InvocationId, Item] +``` * `2` - Message Type - `2` indicates this is a `StreamItem` message * InvocationId - A `String` encoding the Invocation ID for the message @@ -414,7 +441,7 @@ is decoded as follows: * `0x93` - 3-element array * `0x02` - `2` (Message Type - `StreamItem` message) -* `0xa3` - string of length 3 (Target) +* `0xa3` - string of length 3 (InvocationId) * `0x78` - `x` * `0x79` - `y` * `0x7a` - `z` @@ -449,7 +476,7 @@ is decoded as follows: * `0x94` - 4-element array * `0x03` - `3` (Message Type - `Result` message) -* `0xa3` - string of length 3 (Target) +* `0xa3` - string of length 3 (InvocationId) * `0x78` - `x` * `0x79` - `y` * `0x7a` - `z` @@ -472,7 +499,7 @@ is decoded as follows: * `0x93` - 3-element array * `0x03` - `3` (Message Type - `Result` message) -* `0xa3` - string of length 3 (Target) +* `0xa3` - string of length 3 (InvocationId) * `0x78` - `x` * `0x79` - `y` * `0x7a` - `z` @@ -489,13 +516,40 @@ is decoded as follows: * `0x94` - 4-element array * `0x03` - `3` (Message Type - `Result` message) -* `0xa3` - string of length 3 (Target) +* `0xa3` - string of length 3 (InvocationId) * `0x78` - `x` * `0x79` - `y` * `0x7a` - `z` * `0x03` - `3` (ResultKind - Non-Void result) * `0x2a` - `42` (Result) +### CancelInvocation Message Encoding + +`CancelInvocation` messages have the following structure + +``` +[5, InvocationId] +``` + +* `5` - Message Type - `5` indicates this is a `CancelInvocation` message +* InvocationId - A `String` encoding the Invocation ID for the message + +Example: + +The following payload: +``` +0x92 0x05 0xa3 0x78 0x79 0x7a +``` + +is decoded as follows: + +* `0x92` - 2-element array +* `0x05` - `5` (Message Type `CancelInvocation` message) +* `0xa3` - string of length 3 (InvocationId) +* `0x78` - `x` +* `0x79` - `y` +* `0x7a` - `z` + ## Protocol Buffers (ProtoBuf) Encoding **Protobuf encoding is currently not implemented** diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 31297b17ca..d20df8ac2c 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -132,8 +132,35 @@ namespace Microsoft.AspNetCore.SignalR.Client private async Task> StreamAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) { - var irq = InvocationRequest.Stream(cancellationToken, returnType, GetNextId(), _loggerFactory, out var channel); - await InvokeCore(methodName, irq, args, nonBlocking: false); + var invokeCts = new CancellationTokenSource(); + var irq = InvocationRequest.Stream(invokeCts.Token, returnType, GetNextId(), _loggerFactory, this, out var channel); + // After InvokeCore we don't want the irq cancellation token to be triggered. + // The stream invocation will be canceled by the CancelInvocationMessage, connection closing, or channel finishing. + using (cancellationToken.Register(token => ((CancellationTokenSource)token).Cancel(), invokeCts)) + { + await InvokeCore(methodName, irq, args, nonBlocking: false); + } + + if (cancellationToken.CanBeCanceled) + { + cancellationToken.Register(state => + { + var invocationReq = (InvocationRequest)state; + if (!invocationReq.HubConnection._connectionActive.IsCancellationRequested) + { + // Fire and forget, if it fails that means we aren't connected anymore. + _ = invocationReq.HubConnection.SendHubMessage(new CancelInvocationMessage(invocationReq.InvocationId), invocationReq); + + if (invocationReq.HubConnection.TryRemoveInvocation(invocationReq.InvocationId, out _)) + { + invocationReq.Complete(null); + } + + invocationReq.Dispose(); + } + }, irq); + } + return channel; } @@ -142,7 +169,7 @@ namespace Microsoft.AspNetCore.SignalR.Client private async Task InvokeAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) { - var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, out var task); + var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, this, out var task); await InvokeCore(methodName, irq, args, nonBlocking: false); return await task; } @@ -152,7 +179,7 @@ namespace Microsoft.AspNetCore.SignalR.Client private Task SendAsyncCore(string methodName, object[] args, CancellationToken cancellationToken) { - var irq = InvocationRequest.Invoke(cancellationToken, typeof(void), GetNextId(), _loggerFactory, out _); + var irq = InvocationRequest.Invoke(cancellationToken, typeof(void), GetNextId(), _loggerFactory, this, out _); return InvokeCore(methodName, irq, args, nonBlocking: true); } @@ -184,24 +211,24 @@ namespace Microsoft.AspNetCore.SignalR.Client _logger.IssueInvocation(invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args); // We don't need to wait for this to complete. It will signal back to the invocation request. - return SendInvocation(invocationMessage, irq); + return SendHubMessage(invocationMessage, irq); } - private async Task SendInvocation(InvocationMessage invocationMessage, InvocationRequest irq) + private async Task SendHubMessage(HubMessage hubMessage, InvocationRequest irq) { try { - var payload = _protocolReaderWriter.WriteMessage(invocationMessage); - _logger.SendInvocation(invocationMessage.InvocationId); + var payload = _protocolReaderWriter.WriteMessage(hubMessage); + _logger.SendInvocation(hubMessage.InvocationId); await _connection.SendAsync(payload, irq.CancellationToken); - _logger.SendInvocationCompleted(invocationMessage.InvocationId); + _logger.SendInvocationCompleted(hubMessage.InvocationId); } catch (Exception ex) { - _logger.SendInvocationFailed(invocationMessage.InvocationId, ex); + _logger.SendInvocationFailed(hubMessage.InvocationId, ex); irq.Fail(ex); - TryRemoveInvocation(invocationMessage.InvocationId, out _); + TryRemoveInvocation(hubMessage.InvocationId, out _); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs index 61f4b12b44..dfe02a5c51 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs @@ -19,8 +19,9 @@ namespace Microsoft.AspNetCore.SignalR.Client public Type ResultType { get; } public CancellationToken CancellationToken { get; } public string InvocationId { get; } + public HubConnection HubConnection { get; private set; } - protected InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILogger logger) + protected InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILogger logger, HubConnection hubConnection) { _cancellationTokenRegistration = cancellationToken.Register(self => ((InvocationRequest)self).Cancel(), this); @@ -28,21 +29,23 @@ namespace Microsoft.AspNetCore.SignalR.Client CancellationToken = cancellationToken; ResultType = resultType; Logger = logger; + HubConnection = hubConnection; Logger.InvocationCreated(InvocationId); } - public static InvocationRequest Invoke(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, out Task result) + public static InvocationRequest Invoke(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, out Task result) { - var req = new NonStreaming(cancellationToken, resultType, invocationId, loggerFactory); + var req = new NonStreaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection); result = req.Result; return req; } - public static InvocationRequest Stream(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, out ReadableChannel result) + public static InvocationRequest Stream(CancellationToken cancellationToken, Type resultType, string invocationId, + ILoggerFactory loggerFactory, HubConnection hubConnection, out ReadableChannel result) { - var req = new Streaming(cancellationToken, resultType, invocationId, loggerFactory); + var req = new Streaming(cancellationToken, resultType, invocationId, loggerFactory, hubConnection); result = req.Result; return req; } @@ -67,8 +70,8 @@ namespace Microsoft.AspNetCore.SignalR.Client { private readonly Channel _channel = Channel.CreateUnbounded(); - public Streaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory) - : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger()) + public Streaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection) + : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger(), hubConnection) { } @@ -115,7 +118,7 @@ namespace Microsoft.AspNetCore.SignalR.Client protected override void Cancel() { - _channel.Out.TryComplete(new OperationCanceledException("Connection terminated")); + _channel.Out.TryComplete(new OperationCanceledException("Invocation terminated")); } } @@ -123,8 +126,8 @@ namespace Microsoft.AspNetCore.SignalR.Client { private readonly TaskCompletionSource _completionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - public NonStreaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory) - : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger()) + public NonStreaming(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection) + : base(cancellationToken, resultType, invocationId, loggerFactory.CreateLogger(), hubConnection) { } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CancelInvocationMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CancelInvocationMessage.cs new file mode 100644 index 0000000000..2240d8f569 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CancelInvocationMessage.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public class CancelInvocationMessage : HubMessage + { + public CancelInvocationMessage(string invocationId) : base(invocationId) + { + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs index fa941b3fe4..4f8a4c738f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.Linq; namespace Microsoft.AspNetCore.SignalR.Internal.Protocol diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index fb5aea9f8f..e45d5ff471 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -25,6 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private const int InvocationMessageType = 1; private const int ResultMessageType = 2; private const int CompletionMessageType = 3; + private const int CancelInvocationMessageType = 5; // ONLY to be used for application payloads (args, return values, etc.) private JsonSerializer _payloadSerializer; @@ -111,6 +112,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return BindResultMessage(json, binder); case CompletionMessageType: return BindCompletionMessage(json, binder); + case CancelInvocationMessageType: + return BindCancelInvocationMessage(json); default: throw new FormatException($"Unknown message type: {type}"); } @@ -137,6 +140,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol case CompletionMessage m: WriteCompletionMessage(m, writer); break; + case CancelInvocationMessage m: + WriteCancelInvocationMessage(m, writer); + break; default: throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); } @@ -160,6 +166,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol writer.WriteEndObject(); } + private void WriteCancelInvocationMessage(CancelInvocationMessage message, JsonTextWriter writer) + { + writer.WriteStartObject(); + WriteHubMessageCommon(message, writer, CancelInvocationMessageType); + writer.WriteEndObject(); + } + private void WriteStreamItemMessage(StreamItemMessage message, JsonTextWriter writer) { writer.WriteStartObject(); @@ -260,6 +273,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol } } + private CancelInvocationMessage BindCancelInvocationMessage(JObject json) + { + var invocationId = JsonUtils.GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); + return new CancelInvocationMessage(invocationId); + } + public static JsonSerializerSettings CreateDefaultSerializerSettings() { return new JsonSerializerSettings { ContractResolver = new CamelCasePropertyNamesContractResolver() }; diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs index 92fd694839..d06ef8346b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs @@ -15,6 +15,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private const int InvocationMessageType = 1; private const int StreamItemMessageType = 2; private const int CompletionMessageType = 3; + private const int CancelInvocationMessageType = 5; private const int ErrorResult = 1; private const int VoidResult = 2; @@ -64,6 +65,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return CreateStreamItemMessage(unpacker, binder); case CompletionMessageType: return CreateCompletionMessage(unpacker, binder); + case CancelInvocationMessageType: + return CreateCancelInvocationMessage(unpacker); default: throw new FormatException($"Invalid message type: {messageType}."); } @@ -129,6 +132,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return new CompletionMessage(invocationId, error, result, hasResult); } + private static CancelInvocationMessage CreateCancelInvocationMessage(Unpacker unpacker) + { + var invocationId = ReadInvocationId(unpacker); + return new CancelInvocationMessage(invocationId); + } + public void WriteMessage(HubMessage message, Stream output) { using (var memoryStream = new MemoryStream()) @@ -146,20 +155,23 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol switch (message) { case InvocationMessage invocationMessage: - WriteInvocationMessage(invocationMessage, packer, output); + WriteInvocationMessage(invocationMessage, packer); break; case StreamItemMessage streamItemMessage: - WriteStreamingItemMessage(streamItemMessage, packer, output); + WriteStreamingItemMessage(streamItemMessage, packer); break; case CompletionMessage completionMessage: - WriteCompletionMessage(completionMessage, packer, output); + WriteCompletionMessage(completionMessage, packer); + break; + case CancelInvocationMessage cancelInvocationMessage: + WriteCancelInvocationMessage(cancelInvocationMessage, packer); break; default: throw new FormatException($"Unexpected message type: {message.GetType().Name}"); } } - private void WriteInvocationMessage(InvocationMessage invocationMessage, Packer packer, Stream output) + private void WriteInvocationMessage(InvocationMessage invocationMessage, Packer packer) { packer.PackArrayHeader(5); packer.Pack(InvocationMessageType); @@ -169,7 +181,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol packer.PackObject(invocationMessage.Arguments, _serializationContext); } - private void WriteStreamingItemMessage(StreamItemMessage streamItemMessage, Packer packer, Stream output) + private void WriteStreamingItemMessage(StreamItemMessage streamItemMessage, Packer packer) { packer.PackArrayHeader(3); packer.Pack(StreamItemMessageType); @@ -177,7 +189,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol packer.PackObject(streamItemMessage.Item, _serializationContext); } - private void WriteCompletionMessage(CompletionMessage completionMessage, Packer packer, Stream output) + private void WriteCompletionMessage(CompletionMessage completionMessage, Packer packer) { var resultKind = completionMessage.Error != null ? ErrorResult : @@ -199,6 +211,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol } } + private void WriteCancelInvocationMessage(CancelInvocationMessage cancelInvocationMessage, Packer packer) + { + packer.PackArrayHeader(2); + packer.Pack(CancelInvocationMessageType); + packer.PackString(cancelInvocationMessage.InvocationId); + } + private static string ReadInvocationId(Unpacker unpacker) { return ReadString(unpacker, "invocationId"); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index f016986d34..56f118dd49 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Runtime.ExceptionServices; using System.Security.Claims; @@ -54,6 +55,9 @@ namespace Microsoft.AspNetCore.SignalR public virtual WritableChannel Output => _output; + // Currently used only for streaming methods + internal ConcurrentDictionary ActiveRequestCancellationSources { get; } = new ConcurrentDictionary(); + public virtual void Abort() { // If we already triggered the token then noop, this isn't thread safe but it's good enough diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index fbf9151e41..653fae4c8f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -281,6 +281,20 @@ namespace Microsoft.AspNetCore.SignalR _ = ProcessInvocation(connection, invocationMessage); break; + case CancelInvocationMessage cancelInvocationMessage: + // Check if there is an associated active stream and cancel it if it exists. + if (connection.ActiveRequestCancellationSources.TryRemove(cancelInvocationMessage.InvocationId, out var cts)) + { + _logger.CancelStream(cancelInvocationMessage.InvocationId); + cts.Cancel(); + } + else + { + // Stream can be canceled on the server while client is canceling stream. + _logger.UnexpectedCancel(); + } + break; + // Other kind of message we weren't expecting default: _logger.UnsupportedMessageReceived(hubMessage.GetType().FullName); @@ -384,7 +398,7 @@ namespace Microsoft.AspNetCore.SignalR result = methodExecutor.Execute(hub, invocationMessage.Arguments); } - if (IsStreamed(connection, methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator)) + if (IsStreamed(connection, invocationMessage.InvocationId, methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator)) { _logger.StreamingResult(invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName); await StreamResultsAsync(invocationMessage.InvocationId, connection, enumerator); @@ -456,9 +470,16 @@ namespace Microsoft.AspNetCore.SignalR { await SendMessageAsync(connection, CompletionMessage.WithError(invocationId, ex.Message)); } + finally + { + if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts)) + { + cts.Dispose(); + } + } } - private bool IsStreamed(HubConnectionContext connection, ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator enumerator) + private bool IsStreamed(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator enumerator) { if (result == null) { @@ -466,20 +487,17 @@ namespace Microsoft.AspNetCore.SignalR return false; } - - // TODO: We need to support cancelling the stream without a client disconnect as well. - var observableInterface = IsIObservable(resultType) ? resultType : resultType.GetInterfaces().FirstOrDefault(IsIObservable); if (observableInterface != null) { - enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface, connection.ConnectionAbortedToken); + enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface, CreateCancellation()); return true; } else if (IsChannel(resultType, out var payloadType)) { - enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType, connection.ConnectionAbortedToken); + enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType, CreateCancellation()); return true; } else @@ -488,6 +506,13 @@ namespace Microsoft.AspNetCore.SignalR enumerator = null; return false; } + + CancellationToken CreateCancellation() + { + var streamCts = new CancellationTokenSource(); + connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts); + return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAbortedToken, streamCts.Token).Token; + } } private static bool IsIObservable(Type iface) diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs index f387b22329..f066e18048 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs @@ -89,7 +89,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal public void OnNext(T value) { - _cancellationToken.ThrowIfCancellationRequested(); + if (_cancellationToken.IsCancellationRequested) + { + // Noop, someone else is handling the cancellation + return; + } // This will block the thread emitting the object if the channel is bounded and full // I think this is OK, since we want to push the backpressure up. However, we may need diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SignalRCoreLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SignalRCoreLoggerExtensions.cs index ee41f75f82..1dbdc07c48 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SignalRCoreLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SignalRCoreLoggerExtensions.cs @@ -49,6 +49,12 @@ namespace Microsoft.AspNetCore.SignalR.Core.Internal private static readonly Action _hubMethodBound = LoggerMessage.Define(LogLevel.Trace, new EventId(12, nameof(HubMethodBound)), "Hub method '{hubMethod}' is bound."); + private static readonly Action _cancelStream = + LoggerMessage.Define(LogLevel.Debug, new EventId(13, nameof(CancelStream)), "Canceling stream for invocation {invocationId}."); + + private static readonly Action _unexpectedCancel = + LoggerMessage.Define(LogLevel.Debug, new EventId(14, nameof(UnexpectedCancel)), "CancelInvocationMessage received unexpectedly."); + public static void UsingHubProtocol(this ILogger logger, string hubProtocol) { _usingHubProtocol(logger, hubProtocol, null); @@ -113,5 +119,15 @@ namespace Microsoft.AspNetCore.SignalR.Core.Internal { _hubMethodBound(logger, hubMethod, null); } + + public static void CancelStream(this ILogger logger, string invocationId) + { + _cancelStream(logger, invocationId, null); + } + + public static void UnexpectedCancel(this ILogger logger) + { + _unexpectedCancel(logger, null); + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs index 58b5132f07..a95c59a353 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs @@ -27,11 +27,12 @@ namespace Microsoft.AspNetCore.Sockets.Client { // Grab as many messages as we can from the channel messages = new List(); - while (!transportCts.Token.IsCancellationRequested && application.In.TryRead(out SendMessage message)) + while (!transportCts.IsCancellationRequested && application.In.TryRead(out SendMessage message)) { messages.Add(message); } + transportCts.Token.ThrowIfCancellationRequested(); if (messages.Count > 0) { logger.SendingMessages(connectionId, messages.Count, sendUrl); @@ -57,7 +58,7 @@ namespace Microsoft.AspNetCore.Sockets.Client // Set the, now filled, stream as the content request.Content = new StreamContent(memoryStream); - var response = await httpClient.SendAsync(request); + var response = await httpClient.SendAsync(request, transportCts.Token); response.EnsureSuccessStatusCode(); logger.SentSuccessfully(connectionId); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index a9a02a0f72..26ca8eefa0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; @@ -175,12 +176,76 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { await connection.StartAsync().OrTimeout(); - var tcs = new TaskCompletionSource(); - - var channel = await connection.StreamAsync("Stream"); + var channel = await connection.StreamAsync("Stream", 5).OrTimeout(); var results = await channel.ReadAllAsync().OrTimeout(); - Assert.Equal(new[] { "a", "b", "c" }, results.ToArray()); + Assert.Equal(new[] { 0, 1, 2, 3, 4 }, results.ToArray()); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task CanCloseStreamMethodEarly(IHubProtocol protocol, TransportType transportType, string path) + { + using (StartLog(out var loggerFactory)) + { + var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + path), transportType, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); + try + { + await connection.StartAsync().OrTimeout(); + + var cts = new CancellationTokenSource(); + + var channel = await connection.StreamAsync("Stream", 1000, cts.Token).OrTimeout(); + + await channel.WaitToReadAsync().OrTimeout(); + cts.Cancel(); + + var results = await channel.ReadAllAsync().OrTimeout(); + + Assert.True(results.Count > 0 && results.Count < 1000); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task StreamDoesNotStartIfTokenAlreadyCanceled(IHubProtocol protocol, TransportType transportType, string path) + { + using (StartLog(out var loggerFactory)) + { + var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + path), transportType, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); + try + { + await connection.StartAsync().OrTimeout(); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + var channel = await connection.StreamAsync("Stream", 5, cts.Token).OrTimeout(); + + await Assert.ThrowsAnyAsync(() => channel.WaitToReadAsync().OrTimeout()); } catch (Exception ex) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index b35ae46d7e..2c02e941fe 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -24,9 +24,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await Clients.Client(Context.ConnectionId).InvokeAsync("Echo", message); } - public IObservable Stream() + public IObservable Stream(int count) { - return new[] { "a", "b", "c" }.ToObservable(); + return Observable.Interval(TimeSpan.FromMilliseconds(1)) + .Select((_, index) => index) + .Take(count); } } @@ -47,9 +49,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await Clients.Client(Context.ConnectionId).Echo(message); } - public IObservable Stream() + public IObservable Stream(int count) { - return new[] { "a", "b", "c" }.ToObservable(); + return Observable.Interval(TimeSpan.FromMilliseconds(1)) + .Select((_, index) => index) + .Take(count); } public Task SendMessage(string message) @@ -75,9 +79,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await Clients.Client(Context.ConnectionId).Echo(message); } - public IObservable Stream() + public IObservable Stream(int count) { - return new[] { "a", "b", "c" }.ToObservable(); + return Observable.Interval(TimeSpan.FromMilliseconds(1)) + .Select((_, index) => index) + .Take(count); } public Task SendMessage(string message) diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs index fe2a5e4d07..ecf971530b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -112,7 +112,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [InlineData("{'type':3,'invocationId':42}", "Expected 'invocationId' to be of type String.")] [InlineData("{'type':3,'invocationId':'42','error':[]}", "Expected 'error' to be of type String.")] - [InlineData("{'type':4}", "Unknown message type: 4")] + [InlineData("{'type':9}", "Unknown message type: 9")] [InlineData("{'type':'foo'}", "Expected 'type' to be of type Integer.")] public void InvalidMessages(string input, string expectedMessage) { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 04ef10118b..b147de72a7 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -210,6 +210,47 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task ObservableHubRemovesSubscriptionWhenCanceledFromClient() + { + var observable = new Observable(); + var serviceProvider = CreateServiceProvider(s => s.AddSingleton(observable)); + var endPoint = serviceProvider.GetService>(); + + var waitForSubscribe = new TaskCompletionSource(); + observable.OnSubscribe = o => + { + waitForSubscribe.TrySetResult(null); + }; + + var waitForDispose = new TaskCompletionSource(); + observable.OnDispose = o => + { + waitForDispose.TrySetResult(null); + }; + + using (var client = new TestClient()) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + var invocationId = await client.SendInvocationAsync(nameof(ObservableHub.Subscribe), nonBlocking: false).OrTimeout(); + + await waitForSubscribe.Task.OrTimeout(); + + observable.OnNext(1); + + await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).OrTimeout(); + + await waitForDispose.Task.OrTimeout(); + + Assert.Equal(1L, ((StreamItemMessage)await client.ReadAsync().OrTimeout()).Item); + + client.Dispose(); + + await endPointTask.OrTimeout(); + } + } + [Fact] public async Task MissingNegotiateAndMessageSentFromHubConnectionCanBeDisposedCleanly() { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 7835863279..b50a1e8fe8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -129,13 +129,17 @@ namespace Microsoft.AspNetCore.SignalR.Tests return SendInvocationAsync(methodName, nonBlocking: false, args: args); } - public async Task SendInvocationAsync(string methodName, bool nonBlocking, params object[] args) + public Task SendInvocationAsync(string methodName, bool nonBlocking, params object[] args) { var invocationId = GetInvocationId(); - var payload = _protocolReaderWriter.WriteMessage(new InvocationMessage(invocationId, nonBlocking, methodName, args)); - await Application.Out.WriteAsync(payload); + return SendHubMessageAsync(new InvocationMessage(invocationId, nonBlocking, methodName, args)); + } - return invocationId; + public async Task SendHubMessageAsync(HubMessage message) + { + var payload = _protocolReaderWriter.WriteMessage(message); + await Application.Out.WriteAsync(payload); + return message.InvocationId; } public async Task ReadAsync()