diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index 61db155e30..fd2255c308 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -400,6 +400,11 @@ namespace Microsoft.AspNetCore.SignalR await SendMessageAsync(connection, new StreamItemMessage(invocationId, enumerator.Current)); } } + catch (ChannelClosedException ex) + { + // If the channel closes from an exception in the streaming method, grab the innerException for the error from the streaming method + error = ex.InnerException == null ? ex.Message : ex.InnerException.Message; + } catch (Exception ex) { // If the streaming method was canceled we don't want to send a HubException message - this is not an error case diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs index c7c88a05d9..4dd0378bc3 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Reactive.Linq; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -533,6 +534,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils return Channel.CreateUnbounded().Reader; } + public IObservable ThrowStream() + { + return Observable.Throw(new Exception("Exception from observable")); + } + private class CountingObservable : IObservable { private int _count; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 6d308bd92c..7114dbcc6c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -1386,6 +1386,31 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task ReceiveCorrectErrorFromStreamThrowing() + { + var serviceProvider = HubEndPointTestUtils.CreateServiceProvider(); + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointLifetime = endPoint.OnConnectedAsync(client.Connection); + + await client.Connected.OrTimeout(); + + var messages = await client.StreamAsync(nameof(StreamingHub.ThrowStream)); + + Assert.Equal(1, messages.Count); + var completion = messages[0] as CompletionMessage; + Assert.NotNull(completion); + Assert.Equal("Exception from observable", completion.Error); + + client.Dispose(); + + await endPointLifetime.OrTimeout(); + } + } + public static IEnumerable StreamingMethodAndHubProtocols { get diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj index 49e135404e..d6da59a690 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Microsoft.AspNetCore.SignalR.Tests.csproj @@ -36,6 +36,7 @@ +