diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 9642ceaa97..f8d1f927e0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -163,7 +163,7 @@ namespace Microsoft.AspNetCore.SignalR.Client if (invocationReq.HubConnection.TryRemoveInvocation(invocationReq.InvocationId, out _)) { - invocationReq.Complete(null); + invocationReq.Complete(new StreamCompletionMessage(irq.InvocationId, error: null)); } invocationReq.Dispose(); @@ -273,6 +273,15 @@ namespace Microsoft.AspNetCore.SignalR.Client } DispatchInvocationStreamItemAsync(streamItem, irq); break; + case StreamCompletionMessage streamCompletion: + if (!TryRemoveInvocation(streamCompletion.InvocationId, out irq)) + { + _logger.DropStreamCompletionMessage(streamCompletion.InvocationId); + return; + } + DispatchStreamCompletion(streamCompletion, irq); + irq.Dispose(); + break; default: throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}"); } @@ -362,18 +371,25 @@ namespace Microsoft.AspNetCore.SignalR.Client if (irq.CancellationToken.IsCancellationRequested) { - _logger.CancelingCompletion(irq.InvocationId); + _logger.CancelingInvocationCompletion(irq.InvocationId); } else { - if (!string.IsNullOrEmpty(completion.Error)) - { - irq.Fail(new HubException(completion.Error)); - } - else - { - irq.Complete(completion.Result); - } + irq.Complete(completion); + } + } + + private void DispatchStreamCompletion(StreamCompletionMessage completion, InvocationRequest irq) + { + _logger.ReceivedStreamCompletion(completion.InvocationId); + + if (irq.CancellationToken.IsCancellationRequested) + { + _logger.CancelingStreamCompletion(irq.InvocationId); + } + else + { + irq.Complete(completion); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.StreamAsync.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.StreamAsync.cs index 267ea0e001..c62c814051 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.StreamAsync.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.StreamAsync.cs @@ -2,8 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.Text; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs index af32121b5f..5966fa78cd 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs @@ -40,44 +40,53 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal private static readonly Action _dropStreamMessage = LoggerMessage.Define(LogLevel.Warning, new EventId(9, nameof(DropStreamMessage)), "Dropped unsolicited StreamItem message for invocation '{invocationId}'."); + private static readonly Action _dropStreamCompletionMessage = + LoggerMessage.Define(LogLevel.Warning, new EventId(10, nameof(DropStreamCompletionMessage)), "Dropped unsolicited Stream Completion message for invocation '{invocationId}'."); + private static readonly Action _shutdownConnection = - LoggerMessage.Define(LogLevel.Trace, new EventId(10, nameof(ShutdownConnection)), "Shutting down connection."); + LoggerMessage.Define(LogLevel.Trace, new EventId(11, nameof(ShutdownConnection)), "Shutting down connection."); private static readonly Action _shutdownWithError = - LoggerMessage.Define(LogLevel.Error, new EventId(11, nameof(ShutdownWithError)), "Connection is shutting down due to an error."); + LoggerMessage.Define(LogLevel.Error, new EventId(12, nameof(ShutdownWithError)), "Connection is shutting down due to an error."); private static readonly Action _removeInvocation = - LoggerMessage.Define(LogLevel.Trace, new EventId(12, nameof(RemoveInvocation)), "Removing pending invocation {invocationId}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(13, nameof(RemoveInvocation)), "Removing pending invocation {invocationId}."); private static readonly Action _missingHandler = - LoggerMessage.Define(LogLevel.Warning, new EventId(13, nameof(MissingHandler)), "Failed to find handler for '{target}' method."); + LoggerMessage.Define(LogLevel.Warning, new EventId(14, nameof(MissingHandler)), "Failed to find handler for '{target}' method."); private static readonly Action _receivedStreamItem = - LoggerMessage.Define(LogLevel.Trace, new EventId(14, nameof(ReceivedStreamItem)), "Received StreamItem for Invocation {invocationId}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(15, nameof(ReceivedStreamItem)), "Received StreamItem for Invocation {invocationId}."); private static readonly Action _cancelingStreamItem = - LoggerMessage.Define(LogLevel.Trace, new EventId(15, nameof(CancelingStreamItem)), "Canceling dispatch of StreamItem message for Invocation {invocationId}. The invocation was canceled."); + LoggerMessage.Define(LogLevel.Trace, new EventId(16, nameof(CancelingStreamItem)), "Canceling dispatch of StreamItem message for Invocation {invocationId}. The invocation was canceled."); private static readonly Action _receivedStreamItemAfterClose = - LoggerMessage.Define(LogLevel.Warning, new EventId(16, nameof(ReceivedStreamItemAfterClose)), "Invocation {invocationId} received stream item after channel was closed."); + LoggerMessage.Define(LogLevel.Warning, new EventId(17, nameof(ReceivedStreamItemAfterClose)), "Invocation {invocationId} received stream item after channel was closed."); private static readonly Action _receivedInvocationCompletion = - LoggerMessage.Define(LogLevel.Trace, new EventId(17, nameof(ReceivedInvocationCompletion)), "Received Completion for Invocation {invocationId}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(18, nameof(ReceivedInvocationCompletion)), "Received Completion for Invocation {invocationId}."); - private static readonly Action _cancelingCompletion = - LoggerMessage.Define(LogLevel.Trace, new EventId(18, nameof(CancelingCompletion)), "Canceling dispatch of Completion message for Invocation {invocationId}. The invocation was canceled."); + private static readonly Action _cancelingInvocationCompletion = + LoggerMessage.Define(LogLevel.Trace, new EventId(19, nameof(CancelingInvocationCompletion)), "Canceling dispatch of Completion message for Invocation {invocationId}. The invocation was canceled."); + + private static readonly Action _receivedStreamCompletion = + LoggerMessage.Define(LogLevel.Trace, new EventId(20, nameof(ReceivedStreamCompletion)), "Received StreamCompletion for Invocation {invocationId}."); + + private static readonly Action _cancelingStreamCompletion = + LoggerMessage.Define(LogLevel.Trace, new EventId(21, nameof(CancelingStreamCompletion)), "Canceling dispatch of StreamCompletion message for Invocation {invocationId}. The invocation was canceled."); private static readonly Action _invokeAfterTermination = - LoggerMessage.Define(LogLevel.Error, new EventId(19, nameof(InvokeAfterTermination)), "Invoke for Invocation '{invocationId}' was called after the connection was terminated."); + LoggerMessage.Define(LogLevel.Error, new EventId(22, nameof(InvokeAfterTermination)), "Invoke for Invocation '{invocationId}' was called after the connection was terminated."); private static readonly Action _invocationAlreadyInUse = - LoggerMessage.Define(LogLevel.Critical, new EventId(20, nameof(InvocationAlreadyInUse)), "Invocation ID '{invocationId}' is already in use."); + LoggerMessage.Define(LogLevel.Critical, new EventId(23, nameof(InvocationAlreadyInUse)), "Invocation ID '{invocationId}' is already in use."); private static readonly Action _receivedUnexpectedResponse = - LoggerMessage.Define(LogLevel.Error, new EventId(21, nameof(ReceivedUnexpectedResponse)), "Unsolicited response received for invocation '{invocationId}'."); + LoggerMessage.Define(LogLevel.Error, new EventId(24, nameof(ReceivedUnexpectedResponse)), "Unsolicited response received for invocation '{invocationId}'."); private static readonly Action _hubProtocol = - LoggerMessage.Define(LogLevel.Information, new EventId(22, nameof(HubProtocol)), "Using HubProtocol '{protocol}'."); + LoggerMessage.Define(LogLevel.Information, new EventId(25, nameof(HubProtocol)), "Using HubProtocol '{protocol}'."); // Category: Streaming and NonStreaming private static readonly Action _invocationCreated = @@ -93,11 +102,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal LoggerMessage.Define(LogLevel.Trace, new EventId(3, nameof(InvocationFailed)), "Invocation {invocationId} marked as failed."); // Category: Streaming - private static readonly Action _receivedUnexpectedComplete = - LoggerMessage.Define(LogLevel.Error, new EventId(4, nameof(ReceivedUnexpectedComplete)), "Invocation {invocationId} received a completion result, but was invoked as a streaming invocation."); - private static readonly Action _errorWritingStreamItem = - LoggerMessage.Define(LogLevel.Error, new EventId(5, nameof(ErrorWritingStreamItem)), "Invocation {invocationId} caused an error trying to write a stream item."); + LoggerMessage.Define(LogLevel.Error, new EventId(4, nameof(ErrorWritingStreamItem)), "Invocation {invocationId} caused an error trying to write a stream item."); + + private static readonly Action _receivedUnexpectedMessageTypeForStreamCompletion = + LoggerMessage.Define(LogLevel.Error, new EventId(5, nameof(ReceivedUnexpectedMessageTypeForStreamCompletion)), "Invocation {invocationId} was invoked as a streaming hub method but completed with '{messageType}' message."); // Category: NonStreaming private static readonly Action _streamItemOnNonStreamInvocation = @@ -106,6 +115,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal private static readonly Action _exceptionThrownFromCallback = LoggerMessage.Define(LogLevel.Error, new EventId(5, nameof(ExceptionThrownFromCallback)), "An exception was thrown from the '{callback}' callback"); + private static readonly Action _receivedUnexpectedMessageTypeForInvokeCompletion = + LoggerMessage.Define(LogLevel.Error, new EventId(6, nameof(ReceivedUnexpectedMessageTypeForInvokeCompletion)), "Invocation {invocationId} was invoked as a non-streaming hub method but completed with '{messageType}' message."); + public static void PreparingNonBlockingInvocation(this ILogger logger, string invocationId, string target, int count) { _preparingNonBlockingInvocation(logger, invocationId, target, count, null); @@ -164,6 +176,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal _dropStreamMessage(logger, invocationId, null); } + public static void DropStreamCompletionMessage(this ILogger logger, string invocationId) + { + _dropStreamCompletionMessage(logger, invocationId, null); + } + public static void ShutdownConnection(this ILogger logger) { _shutdownConnection(logger, null); @@ -204,9 +221,19 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal _receivedInvocationCompletion(logger, invocationId, null); } - public static void CancelingCompletion(this ILogger logger, string invocationId) + public static void CancelingInvocationCompletion(this ILogger logger, string invocationId) { - _cancelingCompletion(logger, invocationId, null); + _cancelingInvocationCompletion(logger, invocationId, null); + } + + public static void ReceivedStreamCompletion(this ILogger logger, string invocationId) + { + _receivedStreamCompletion(logger, invocationId, null); + } + + public static void CancelingStreamCompletion(this ILogger logger, string invocationId) + { + _cancelingStreamCompletion(logger, invocationId, null); } public static void InvokeAfterTermination(this ILogger logger, string invocationId) @@ -249,11 +276,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal _invocationFailed(logger, invocationId, null); } - public static void ReceivedUnexpectedComplete(this ILogger logger, string invocationId) - { - _receivedUnexpectedComplete(logger, invocationId, null); - } - public static void ErrorWritingStreamItem(this ILogger logger, string invocationId, Exception exception) { _errorWritingStreamItem(logger, invocationId, exception); @@ -268,5 +290,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal { _exceptionThrownFromCallback(logger, callbackName, exception); } + + public static void ReceivedUnexpectedMessageTypeForStreamCompletion(this ILogger logger, string invocationId, string messageType) + { + _receivedUnexpectedMessageTypeForStreamCompletion(logger, invocationId, messageType, null); + } + + public static void ReceivedUnexpectedMessageTypeForInvokeCompletion(this ILogger logger, string invocationId, string messageType) + { + _receivedUnexpectedMessageTypeForStreamCompletion(logger, invocationId, messageType, null); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs index dfe02a5c51..97dddead81 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs @@ -2,10 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Client.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.SignalR.Client @@ -41,7 +43,6 @@ namespace Microsoft.AspNetCore.SignalR.Client return req; } - public static InvocationRequest Stream(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory, HubConnection hubConnection, out ReadableChannel result) { @@ -51,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } public abstract void Fail(Exception exception); - public abstract void Complete(object result); + public abstract void Complete(HubMessage message); public abstract ValueTask StreamItem(object item); protected abstract void Cancel(); @@ -77,18 +78,27 @@ namespace Microsoft.AspNetCore.SignalR.Client public ReadableChannel Result => _channel.In; - public override void Complete(object result) + public override void Complete(HubMessage message) { + Debug.Assert(message != null, "message is null"); + + if (!(message is StreamCompletionMessage streamCompletionMessage)) + { + Logger.ReceivedUnexpectedMessageTypeForStreamCompletion(InvocationId, message.GetType().Name); + // This is not 100% accurate but it is the only case that can be encountered today when running end-to-end + // and this is the most useful message to show to the user. + Fail(new InvalidOperationException($"Streaming hub methods must be invoked with the '{nameof(HubConnection)}.{nameof(HubConnection.StreamAsync)}' method.")); + return; + } + + if (!string.IsNullOrEmpty(streamCompletionMessage.Error)) + { + Fail(new HubException(streamCompletionMessage.Error)); + return; + } + Logger.InvocationCompleted(InvocationId); - if (result != null) - { - Logger.ReceivedUnexpectedComplete(InvocationId); - _channel.Out.TryComplete(new InvalidOperationException("Server provided a result in a completion response to a streamed invocation.")); - } - else - { - _channel.Out.TryComplete(); - } + _channel.Out.TryComplete(); } public override void Fail(Exception exception) @@ -133,10 +143,28 @@ namespace Microsoft.AspNetCore.SignalR.Client public Task Result => _completionSource.Task; - public override void Complete(object result) + public override void Complete(HubMessage message) { + Debug.Assert(message != null, "message is null"); + + if (!(message is CompletionMessage completionMessage)) + { + Logger.ReceivedUnexpectedMessageTypeForStreamCompletion(InvocationId, message.GetType().Name); + // This is not 100% accurate but it is the only case that can be encountered today when running end-to-end + // and this is the most useful message to show to the user. + Fail(new InvalidOperationException( + $"Non-streaming hub methods must be invoked with the '{nameof(HubConnection)}.{nameof(HubConnection.InvokeAsync)}' method.")); + return; + } + + if (!string.IsNullOrEmpty(completionMessage.Error)) + { + Fail(new HubException(completionMessage.Error)); + return; + } + Logger.InvocationCompleted(InvocationId); - _completionSource.TrySetResult(result); + _completionSource.TrySetResult(completionMessage.Result); } public override void Fail(Exception exception) @@ -148,7 +176,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public override ValueTask StreamItem(object item) { Logger.StreamItemOnNonStreamInvocation(InvocationId); - _completionSource.TrySetException(new InvalidOperationException("Streaming methods must be invoked using HubConnection.Stream")); + _completionSource.TrySetException(new InvalidOperationException($"Streaming hub methods must be invoked with the '{nameof(HubConnection)}.{nameof(HubConnection.StreamAsync)}' method.")); // We "delivered" the stream item successfully as far as the caller cares return new ValueTask(true); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs index 78c39c9a9f..911dc7b6f8 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/CompletionMessage.cs @@ -11,12 +11,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public object Result { get; } public bool HasResult { get; } - public CompletionMessage(string invocationId, string error, object result, bool hasResult) : base(invocationId) + public CompletionMessage(string invocationId, string error, object result, bool hasResult) + : base(invocationId) { if (error != null && result != null) { throw new ArgumentException($"Expected either '{nameof(error)}' or '{nameof(result)}' to be provided, but not both"); } + Error = error; Result = result; HasResult = hasResult; @@ -31,10 +33,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol // Static factory methods. Don't want to use constructor overloading because it will break down // if you need to send a payload statically-typed as a string. And because a static factory is clearer here - public static CompletionMessage WithError(string invocationId, string error) => new CompletionMessage(invocationId, error, result: null, hasResult: false); + public static CompletionMessage WithError(string invocationId, string error) + => new CompletionMessage(invocationId, error, result: null, hasResult: false); - public static CompletionMessage WithResult(string invocationId, object payload) => new CompletionMessage(invocationId, error: null, result: payload, hasResult: true); + public static CompletionMessage WithResult(string invocationId, object payload) + => new CompletionMessage(invocationId, error: null, result: payload, hasResult: true); - public static CompletionMessage Empty(string invocationId) => new CompletionMessage(invocationId, error: null, result: null, hasResult: false); + public static CompletionMessage Empty(string invocationId) + => new CompletionMessage(invocationId, error: null, result: null, hasResult: false); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index e45d5ff471..4643e75f58 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Newtonsoft.Json; @@ -25,6 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private const int InvocationMessageType = 1; private const int ResultMessageType = 2; private const int CompletionMessageType = 3; + private const int StreamCompletionMessageType = 4; private const int CancelInvocationMessageType = 5; // ONLY to be used for application payloads (args, return values, etc.) @@ -112,6 +114,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return BindResultMessage(json, binder); case CompletionMessageType: return BindCompletionMessage(json, binder); + case StreamCompletionMessageType: + return BindStreamCompletionMessage(json); case CancelInvocationMessageType: return BindCancelInvocationMessage(json); default: @@ -140,6 +144,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol case CompletionMessage m: WriteCompletionMessage(m, writer); break; + case StreamCompletionMessage m: + WriteStreamCompletionMessage(m, writer); + break; case CancelInvocationMessage m: WriteCancelInvocationMessage(m, writer); break; @@ -166,6 +173,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol writer.WriteEndObject(); } + private void WriteStreamCompletionMessage(StreamCompletionMessage message, JsonTextWriter writer) + { + writer.WriteStartObject(); + WriteHubMessageCommon(message, writer, StreamCompletionMessageType); + if (!string.IsNullOrEmpty(message.Error)) + { + writer.WritePropertyName(ErrorPropertyName); + writer.WriteValue(message.Error); + } + writer.WriteEndObject(); + } + private void WriteCancelInvocationMessage(CancelInvocationMessage message, JsonTextWriter writer) { writer.WriteStartObject(); @@ -265,12 +284,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { return new CompletionMessage(invocationId, error, result: null, hasResult: false); } - else - { - var returnType = binder.GetReturnType(invocationId); - var payload = resultProp.Value?.ToObject(returnType, _payloadSerializer); - return new CompletionMessage(invocationId, error, result: payload, hasResult: true); - } + + var returnType = binder.GetReturnType(invocationId); + var payload = resultProp.Value?.ToObject(returnType, _payloadSerializer); + return new CompletionMessage(invocationId, error, result: payload, hasResult: true); + } + + private StreamCompletionMessage BindStreamCompletionMessage(JObject json) + { + var invocationId = JsonUtils.GetRequiredProperty(json, InvocationIdPropertyName, JTokenType.String); + var error = JsonUtils.GetOptionalProperty(json, ErrorPropertyName, JTokenType.String); + return new StreamCompletionMessage(invocationId, error); } private CancelInvocationMessage BindCancelInvocationMessage(JObject json) diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs index d06ef8346b..be3ed3a8f7 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using Microsoft.AspNetCore.SignalR.Internal.Formatters; using MsgPack; @@ -15,6 +16,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private const int InvocationMessageType = 1; private const int StreamItemMessageType = 2; private const int CompletionMessageType = 3; + private const int StreamCompletionMessageType = 4; private const int CancelInvocationMessageType = 5; private const int ErrorResult = 1; @@ -54,7 +56,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private static HubMessage ParseMessage(Stream input, IInvocationBinder binder) { var unpacker = Unpacker.Create(input); - _ = ReadArrayLength(unpacker, "elementCount"); + var arraySize = ReadArrayLength(unpacker, "elementCount"); var messageType = ReadInt32(unpacker, "messageType"); switch (messageType) @@ -65,6 +67,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return CreateStreamItemMessage(unpacker, binder); case CompletionMessageType: return CreateCompletionMessage(unpacker, binder); + case StreamCompletionMessageType: + return CreateStreamCompletionMessage(unpacker, arraySize, binder); case CancelInvocationMessageType: return CreateCancelInvocationMessage(unpacker); default: @@ -132,6 +136,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return new CompletionMessage(invocationId, error, result, hasResult); } + private static StreamCompletionMessage CreateStreamCompletionMessage(Unpacker unpacker, long arraySize, IInvocationBinder binder) + { + Debug.Assert(arraySize == 2 || arraySize == 3, "Unexpected item count"); + + var invocationId = ReadInvocationId(unpacker); + // Error is optional so StreamCompletion without error has 2 items, StreamCompletion with error has 3 items + var error = arraySize == 3 ? ReadString(unpacker, "error") : null; + return new StreamCompletionMessage(invocationId, error); + } + private static CancelInvocationMessage CreateCancelInvocationMessage(Unpacker unpacker) { var invocationId = ReadInvocationId(unpacker); @@ -163,6 +177,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol case CompletionMessage completionMessage: WriteCompletionMessage(completionMessage, packer); break; + case StreamCompletionMessage streamCompletionMessage: + WriteStreamCompletionMessage(streamCompletionMessage, packer); + break; case CancelInvocationMessage cancelInvocationMessage: WriteCancelInvocationMessage(cancelInvocationMessage, packer); break; @@ -197,6 +214,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol VoidResult; packer.PackArrayHeader(3 + (resultKind != VoidResult ? 1 : 0)); + packer.Pack(CompletionMessageType); packer.PackString(completionMessage.InvocationId); packer.Pack(resultKind); @@ -211,6 +229,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol } } + private void WriteStreamCompletionMessage(StreamCompletionMessage streamCompletionMessage, Packer packer) + { + var hasError = !string.IsNullOrEmpty(streamCompletionMessage.Error); + packer.PackArrayHeader(2 + (hasError ? 1 : 0)); + + packer.Pack(StreamCompletionMessageType); + packer.PackString(streamCompletionMessage.InvocationId); + if (hasError) + { + packer.PackString(streamCompletionMessage.Error); + } + } + private void WriteCancelInvocationMessage(CancelInvocationMessage cancelInvocationMessage, Packer packer) { packer.PackArrayHeader(2); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/StreamCompletionMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/StreamCompletionMessage.cs new file mode 100644 index 0000000000..1e8a9ded59 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/StreamCompletionMessage.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public class StreamCompletionMessage : HubMessage + { + public string Error { get; } + public StreamCompletionMessage(string invocationId, string error) + : base(invocationId) + { + Error = error; + } + + public override string ToString() + { + var errorStr = Error == null ? "<>" : $"\"{Error}\""; + return $"StreamCompletion {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Error)}: {errorStr} }}"; + } + + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index c6e3aa1d95..e55d31dc57 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -411,18 +411,12 @@ namespace Microsoft.AspNetCore.SignalR catch (TargetInvocationException ex) { _logger.FailedInvokingHubMethod(invocationMessage.Target, ex); - if (!invocationMessage.NonBlocking) - { - await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message)); - } + await SendInvocationError(invocationMessage, connection, methodExecutor.MethodReturnType, ex.InnerException); } catch (Exception ex) { _logger.FailedInvokingHubMethod(invocationMessage.Target, ex); - if (!invocationMessage.NonBlocking) - { - await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message)); - } + await SendInvocationError(invocationMessage, connection, methodExecutor.MethodReturnType, ex); } finally { @@ -431,6 +425,21 @@ namespace Microsoft.AspNetCore.SignalR } } + private async Task SendInvocationError(InvocationMessage invocationMessage, HubConnectionContext connection, Type returnType, Exception ex) + { + if (!invocationMessage.NonBlocking) + { + if (IsIObservable(returnType) || IsChannel(returnType, out _)) + { + await SendMessageAsync(connection, new StreamCompletionMessage(invocationMessage.InvocationId, ex.Message)); + } + else + { + await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message)); + } + } + } + private void InitializeHub(THub hub, HubConnectionContext connection) { hub.Clients = _hubContext.Clients; @@ -463,11 +472,11 @@ namespace Microsoft.AspNetCore.SignalR await SendMessageAsync(connection, new StreamItemMessage(invocationId, enumerator.Current)); } - await SendMessageAsync(connection, CompletionMessage.Empty(invocationId)); + await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: null)); } catch (Exception ex) { - await SendMessageAsync(connection, CompletionMessage.WithError(invocationId, ex.Message)); + await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: ex.Message)); } finally { diff --git a/test/Common/TestClient.cs b/test/Common/TestClient.cs index b50a1e8fe8..1a4aa11b4a 100644 --- a/test/Common/TestClient.cs +++ b/test/Common/TestClient.cs @@ -17,18 +17,20 @@ using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Tests { - public class TestClient : IDisposable, IInvocationBinder + public class TestClient : IDisposable { private static int _id; private readonly HubProtocolReaderWriter _protocolReaderWriter; + private readonly IInvocationBinder _invocationBinder; private CancellationTokenSource _cts; private ChannelConnection _transport; + public DefaultConnectionContext Connection { get; } public Channel Application { get; } public Task Connected => ((TaskCompletionSource)Connection.Metadata["ConnectedTask"]).Task; - public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, bool addClaimId = false) + public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false) { var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks }; var transportToApplication = Channel.CreateUnbounded(options); @@ -51,6 +53,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests protocol = protocol ?? new JsonHubProtocol(); _protocolReaderWriter = new HubProtocolReaderWriter(protocol, new PassThroughEncoder()); + _invocationBinder = invocationBinder ?? new DefaultInvocationBinder(); _cts = new CancellationTokenSource(); @@ -86,6 +89,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests messages.Add(message); break; case CompletionMessage _: + case StreamCompletionMessage _: messages.Add(message); return messages; default: @@ -165,7 +169,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public HubMessage TryRead() { if (Application.In.TryRead(out var buffer) && - _protocolReaderWriter.ReadMessages(buffer, this, out var messages)) + _protocolReaderWriter.ReadMessages(buffer, _invocationBinder, out var messages)) { return messages[0]; } @@ -183,15 +187,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests return Guid.NewGuid().ToString("N"); } - Type[] IInvocationBinder.GetParameterTypes(string methodName) + private class DefaultInvocationBinder : IInvocationBinder { - // TODO: Possibly support actual client methods - return new[] { typeof(object) }; - } + public Type[] GetParameterTypes(string methodName) + { + // TODO: Possibly support actual client methods + return new[] { typeof(object) }; + } - Type IInvocationBinder.GetReturnType(string invocationId) - { - return typeof(object); + public Type GetReturnType(string invocationId) + { + return typeof(object); + } } } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index c995f36f3c..0c45e36243 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -293,6 +293,34 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task ExceptionFromStreamingSentToClient(IHubProtocol protocol, TransportType transportType, string path) + { + using (StartLog(out var loggerFactory)) + { + var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + path), transportType, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); + try + { + await connection.StartAsync().OrTimeout(); + var channel = await connection.StreamAsync("StreamException").OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => channel.ReadAllAsync().OrTimeout()); + Assert.Equal("Error occurred while streaming.", ex.Message); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] public async Task ServerClosesConnectionIfHubMethodCannotBeResolved(IHubProtocol hubProtocol, TransportType transportType, string hubPath) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 3246d9c825..0c95dd44d0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -2,22 +2,22 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Linq; using System.Reactive.Linq; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { public class TestHub : Hub { - public string HelloWorld() - { - return "Hello World!"; - } + public string HelloWorld() => TestHubMethodsImpl.HelloWorld(); - public string Echo(string message) - { - return message; - } + public string Echo(string message) => TestHubMethodsImpl.Echo(message); + + public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); + + public ReadableChannel StreamException() => TestHubMethodsImpl.StreamException(); public async Task CallEcho(string message) { @@ -28,83 +28,72 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { await Clients.Client(Context.ConnectionId).InvokeAsync("NoClientHandler"); } - - public IObservable Stream(int count) - { - return Observable.Interval(TimeSpan.FromMilliseconds(1)) - .Select((_, index) => index) - .Take(count); - } } public class DynamicTestHub : DynamicHub { - public string HelloWorld() - { - return "Hello World!"; - } + public string HelloWorld() => TestHubMethodsImpl.HelloWorld(); - public string Echo(string message) - { - return message; - } + public string Echo(string message) => TestHubMethodsImpl.Echo(message); + + public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); + + public ReadableChannel StreamException() => TestHubMethodsImpl.StreamException(); public async Task CallEcho(string message) { await Clients.Client(Context.ConnectionId).Echo(message); } - public IObservable Stream(int count) - { - return Observable.Interval(TimeSpan.FromMilliseconds(1)) - .Select((_, index) => index) - .Take(count); - } - - public Task SendMessage(string message) - { - return Clients.All.Send(message); - } - public async Task CallHandlerThatDoesntExist() { await Clients.Client(Context.ConnectionId).NoClientHandler(); } - } public class TestHubT : Hub { - public string HelloWorld() - { - return "Hello World!"; - } + public string HelloWorld() => TestHubMethodsImpl.HelloWorld(); - public string Echo(string message) - { - return message; - } + public string Echo(string message) => TestHubMethodsImpl.Echo(message); + + public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); + + public ReadableChannel StreamException() => TestHubMethodsImpl.StreamException(); public async Task CallEcho(string message) { await Clients.Client(Context.ConnectionId).Echo(message); } - public IObservable Stream(int count) + public async Task CallHandlerThatDoesntExist() + { + await Clients.Client(Context.ConnectionId).NoClientHandler(); + } + } + + internal static class TestHubMethodsImpl + { + public static string HelloWorld() + { + return "Hello World!"; + } + + public static string Echo(string message) + { + return message; + } + + public static IObservable Stream(int count) { return Observable.Interval(TimeSpan.FromMilliseconds(1)) .Select((_, index) => index) .Take(count); } - public Task SendMessage(string message) + public static ReadableChannel StreamException() { - return Clients.All.Send(message); - } - - public async Task CallHandlerThatDoesntExist() - { - await Clients.Client(Context.ConnectionId).NoClientHandler(); + throw new InvalidOperationException("Error occurred while streaming."); } } @@ -114,5 +103,4 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests Task Send(string message); Task NoClientHandler(); } - } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs index d51e309668..72e462dae9 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs @@ -107,7 +107,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal("{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]}\u001e", invokeMessage); // Complete the channel - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4 }).OrTimeout(); await channel.Completion; } finally @@ -150,7 +150,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channel = await hubConnection.StreamAsync("Foo"); - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4 }).OrTimeout(); Assert.Empty(await channel.ReadAllAsync()); } @@ -184,7 +184,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Fact] - public async Task StreamFailsIfCompletionMessageHasPayload() + public async Task StreamFailsIfCompletionMessageIsNotStreamCompletionMessage() { var connection = new TestConnection(); var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); @@ -194,10 +194,33 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channel = await hubConnection.StreamAsync("Foo"); - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = "Oops" }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); - Assert.Equal("Server provided a result in a completion response to a streamed invocation.", ex.Message); + Assert.Equal("Streaming hub methods must be invoked with the 'HubConnection.StreamAsync' method.", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StreamFailsIfErrorCompletionMessageIsNotStreamCompletionMessage() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var channel = await hubConnection.StreamAsync("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "error" }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); + Assert.Equal("Streaming hub methods must be invoked with the 'HubConnection.StreamAsync' method.", ex.Message); } finally { @@ -240,7 +263,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channel = await hubConnection.StreamAsync("Foo"); - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4, error = "An error occurred" }).OrTimeout(); var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); Assert.Equal("An error occurred", ex.Message); @@ -266,7 +289,53 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = 42 }).OrTimeout(); var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); - Assert.Equal("Streaming methods must be invoked using HubConnection.Stream", ex.Message); + Assert.Equal("Streaming hub methods must be invoked with the 'HubConnection.StreamAsync' method.", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeFailsWithErrorWhenStreamCompletionReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var invokeTask = hubConnection.InvokeAsync("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4 }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); + Assert.Equal("Non-streaming hub methods must be invoked with the 'HubConnection.InvokeAsync' method.", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeFailsWithErrorWhenErrorStreamCompletionReceived() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var invokeTask = hubConnection.InvokeAsync("Foo"); + + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4, error = "error" }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); + Assert.Equal("Non-streaming hub methods must be invoked with the 'HubConnection.InvokeAsync' method.", ex.Message); } finally { @@ -289,7 +358,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "1" }).OrTimeout(); await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "2" }).OrTimeout(); await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "3" }).OrTimeout(); - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); + await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4 }).OrTimeout(); var notifications = await channel.ReadAllAsync().OrTimeout(); diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/CompositeTestBinder.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/CompositeTestBinder.cs index e786540c13..695e5c1786 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/CompositeTestBinder.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/CompositeTestBinder.cs @@ -1,6 +1,7 @@ -using System; -using System.Collections.Generic; -using System.Text; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs index de6d95de3f..c65ee54f52 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -35,6 +35,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol new object[] { new[] { new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) } }, new object[] { new[] { new CompletionMessage("xyz", error: null, result: new[] { new CustomObject(), new CustomObject() }, hasResult: true) } }, + new object[] { new[] { new StreamCompletionMessage("xyz", error: null) } }, + new object[] { new[] { new StreamCompletionMessage("xyz", error: "Error not found!") } }, + new object[] { new[] { new StreamItemMessage("xyz", null) } }, new object[] { new[] { new StreamItemMessage("xyz", 42) } }, new object[] { new[] { new StreamItemMessage("xyz", 42.0f) } }, @@ -52,7 +55,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()), new CompletionMessage("xyz", error: null, result: 42, hasResult: true), new StreamItemMessage("xyz", null), - new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) + new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true), + new StreamCompletionMessage("xyz", error: null), } } }; diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs index e4dc69e5eb..057d4f1add 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs @@ -4,7 +4,6 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Linq; using Microsoft.AspNetCore.SignalR.Internal.Protocol; namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol @@ -22,7 +21,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol return false; } - return InvocationMessagesEqual(x, y) || StreamItemMessagesEqual(x, y) || CompletionMessagesEqual(x, y) || CancelInvocationMessagesEqual(x, y); + return InvocationMessagesEqual(x, y) || StreamItemMessagesEqual(x, y) || CompletionMessagesEqual(x, y) + || StreamCompletionMessagesEqual(x, y) || CancelInvocationMessagesEqual(x, y); } private bool CompletionMessagesEqual(HubMessage x, HubMessage y) @@ -33,6 +33,12 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol (Equals(left.Result, right.Result) || SequenceEqual(left.Result, right.Result)); } + private bool StreamCompletionMessagesEqual(HubMessage x, HubMessage y) + { + return x is StreamCompletionMessage left && y is StreamCompletionMessage right && + string.Equals(left.Error, right.Error, StringComparison.Ordinal); + } + private bool StreamItemMessagesEqual(HubMessage x, HubMessage y) { return x is StreamItemMessage left && y is StreamItemMessage right && diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index bd8f0f7020..709a7156a8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -190,7 +190,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests logger.LogInformation("Disposed Connection"); } } - } private bool IsBase64Encoded(TransferMode transferMode, IConnection connection) diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index e5c8e0cb6a..f120f60d1f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -10,8 +10,8 @@ using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; @@ -988,16 +988,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Theory] - [InlineData(nameof(StreamingHub.CounterChannel))] - [InlineData(nameof(StreamingHub.CounterObservable))] - public async Task HubsCanStreamResponses(string method) + [MemberData(nameof(StreamingMethodAndHubProtocols))] + public async Task HubsCanStreamResponses(string method, IHubProtocol protocol) { var serviceProvider = CreateServiceProvider(); var endPoint = serviceProvider.GetService>(); - using (var client = new TestClient()) + var invocationBinder = new Mock(); + invocationBinder.Setup(b => b.GetReturnType(It.IsAny())).Returns(typeof(string)); + + using (var client = new TestClient(synchronousCallbacks: false, protocol: protocol, invocationBinder: invocationBinder.Object)) { + var transportFeature = new Mock(); + transportFeature.SetupGet(f => f.TransportCapabilities) + .Returns(protocol.Type == ProtocolType.Binary ? TransferMode.Binary : TransferMode.Text); + client.Connection.Features.Set(transportFeature.Object); + var endPointLifetime = endPoint.OnConnectedAsync(client.Connection); await client.Connected.OrTimeout(); @@ -1009,7 +1016,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests AssertHubMessage(new StreamItemMessage(string.Empty, "1"), messages[1]); AssertHubMessage(new StreamItemMessage(string.Empty, "2"), messages[2]); AssertHubMessage(new StreamItemMessage(string.Empty, "3"), messages[3]); - AssertHubMessage(new CompletionMessage(string.Empty, error: null, result: null, hasResult: false), messages[4]); + AssertHubMessage(new StreamCompletionMessage(string.Empty, error: null), messages[4]); client.Dispose(); @@ -1017,6 +1024,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + public static IEnumerable StreamingMethodAndHubProtocols + { + get + { + foreach (var method in new[] { nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterObservable) }) + { + foreach (var protocol in new IHubProtocol[] { new JsonHubProtocol(), new MessagePackHubProtocol() }) + { + yield return new object[] { method, protocol }; + } + } + } + } + [Fact] public async Task UnauthorizedConnectionCannotInvokeHubMethodWithAuthorization() { @@ -1271,6 +1292,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Equal(expectedCompletion.HasResult, actualCompletion.HasResult); Assert.Equal(expectedCompletion.Result, actualCompletion.Result); break; + case StreamCompletionMessage expectedStreamCompletion: + var actualStreamCompletion = Assert.IsType(actual); + Assert.Equal(expectedStreamCompletion.Error, actualStreamCompletion.Error); + break; case StreamItemMessage expectedStreamItem: var actualStreamItem = Assert.IsType(actual); Assert.Equal(expectedStreamItem.Item, actualStreamItem.Item);