Ignore unexpected stream items and completions from client (#7511)

This commit is contained in:
BrennanConroy 2019-02-19 15:25:50 -08:00 committed by GitHub
parent 14b7184c09
commit f37d30833d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 87 additions and 81 deletions

View File

@ -348,12 +348,12 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
case HubProtocolConstants.StreamItemMessageType: case HubProtocolConstants.StreamItemMessageType:
if (itemToken != null) if (itemToken != null)
{ {
var returnType = binder.GetReturnType(invocationId);
try try
{ {
item = itemToken.ToObject(returnType, PayloadSerializer); var itemType = binder.GetStreamItemType(invocationId);
item = itemToken.ToObject(itemType, PayloadSerializer);
} }
catch (JsonSerializationException ex) catch (Exception ex)
{ {
message = new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex)); message = new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
break; break;

View File

@ -69,6 +69,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal
private static readonly Action<ILogger, string, string, Exception> _closingStreamWithBindingError = private static readonly Action<ILogger, string, string, Exception> _closingStreamWithBindingError =
LoggerMessage.Define<string, string>(LogLevel.Warning, new EventId(19, "ClosingStreamWithBindingError"), "Stream '{StreamId}' closed with error '{Error}'."); LoggerMessage.Define<string, string>(LogLevel.Warning, new EventId(19, "ClosingStreamWithBindingError"), "Stream '{StreamId}' closed with error '{Error}'.");
private static readonly Action<ILogger, Exception> _unexpectedStreamCompletion =
LoggerMessage.Define(LogLevel.Debug, new EventId(20, "UnexpectedStreamCompletion"), "StreamCompletionMessage received unexpectedly.");
private static readonly Action<ILogger, Exception> _unexpectedStreamItem =
LoggerMessage.Define(LogLevel.Debug, new EventId(21, "UnexpectedStreamItem"), "StreamItemMessage received unexpectedly.");
public static void ReceivedHubInvocation(ILogger logger, InvocationMessage invocationMessage) public static void ReceivedHubInvocation(ILogger logger, InvocationMessage invocationMessage)
{ {
_receivedHubInvocation(logger, invocationMessage, null); _receivedHubInvocation(logger, invocationMessage, null);
@ -168,6 +174,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{ {
_closingStreamWithBindingError(logger, message.InvocationId, message.Error, null); _closingStreamWithBindingError(logger, message.InvocationId, message.Error, null);
} }
public static void UnexpectedStreamCompletion(ILogger logger)
{
_unexpectedStreamCompletion(logger, null);
}
public static void UnexpectedStreamItem(ILogger logger)
{
_unexpectedStreamItem(logger, null);
}
} }
} }
} }

View File

@ -130,14 +130,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal
break; break;
case StreamItemMessage streamItem: case StreamItemMessage streamItem:
Log.ReceivedStreamItem(_logger, streamItem);
return ProcessStreamItem(connection, streamItem); return ProcessStreamItem(connection, streamItem);
case CompletionMessage streamCompleteMessage: case CompletionMessage streamCompleteMessage:
// closes channels, removes from Lookup dict // closes channels, removes from Lookup dict
// user's method can see the channel is complete and begin wrapping up // user's method can see the channel is complete and begin wrapping up
Log.CompletingStream(_logger, streamCompleteMessage); if (connection.StreamTracker.TryComplete(streamCompleteMessage))
connection.StreamTracker.Complete(streamCompleteMessage); {
Log.CompletingStream(_logger, streamCompleteMessage);
}
else
{
Log.UnexpectedStreamCompletion(_logger);
}
break; break;
// Other kind of message we weren't expecting // Other kind of message we weren't expecting
@ -153,7 +158,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{ {
Log.FailedInvokingHubMethod(_logger, bindingFailureMessage.Target, bindingFailureMessage.BindingFailure.SourceException); Log.FailedInvokingHubMethod(_logger, bindingFailureMessage.Target, bindingFailureMessage.BindingFailure.SourceException);
var errorMessage = ErrorMessageHelper.BuildErrorMessage($"Failed to invoke '{bindingFailureMessage.Target}' due to an error on the server.", var errorMessage = ErrorMessageHelper.BuildErrorMessage($"Failed to invoke '{bindingFailureMessage.Target}' due to an error on the server.",
bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors); bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors);
return SendInvocationError(bindingFailureMessage.InvocationId, connection, errorMessage); return SendInvocationError(bindingFailureMessage.InvocationId, connection, errorMessage);
@ -167,15 +171,25 @@ namespace Microsoft.AspNetCore.SignalR.Internal
var message = CompletionMessage.WithError(bindingFailureMessage.Id, errorString); var message = CompletionMessage.WithError(bindingFailureMessage.Id, errorString);
Log.ClosingStreamWithBindingError(_logger, message); Log.ClosingStreamWithBindingError(_logger, message);
connection.StreamTracker.Complete(message);
// ignore failure, it means the client already completed the stream or the stream never existed on the server
connection.StreamTracker.TryComplete(message);
// TODO: Send stream completion message to client when we add it
return Task.CompletedTask; return Task.CompletedTask;
} }
private Task ProcessStreamItem(HubConnectionContext connection, StreamItemMessage message) private Task ProcessStreamItem(HubConnectionContext connection, StreamItemMessage message)
{ {
if (!connection.StreamTracker.TryProcessItem(message, out var processTask))
{
Log.UnexpectedStreamItem(_logger);
return Task.CompletedTask;
}
Log.ReceivedStreamItem(_logger, message); Log.ReceivedStreamItem(_logger, message);
return connection.StreamTracker.ProcessItem(message); return processTask;
} }
private Task ProcessInvocation(HubConnectionContext connection, private Task ProcessInvocation(HubConnectionContext connection,
@ -370,12 +384,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{ {
foreach (var stream in hubMessage.StreamIds) foreach (var stream in hubMessage.StreamIds)
{ {
try connection.StreamTracker.TryComplete(CompletionMessage.Empty(stream));
{
connection.StreamTracker.Complete(CompletionMessage.Empty(stream));
}
// ignore failures, it means the client already completed the streams
catch (KeyNotFoundException) { }
} }
} }

View File

@ -28,36 +28,47 @@ namespace Microsoft.AspNetCore.SignalR
return newConverter.GetReaderAsObject(); return newConverter.GetReaderAsObject();
} }
private IStreamConverter TryGetConverter(string streamId) private bool TryGetConverter(string streamId, out IStreamConverter converter)
{ {
if (_lookup.TryGetValue(streamId, out var converter)) if (_lookup.TryGetValue(streamId, out converter))
{ {
return converter; return true;
}
else
{
throw new KeyNotFoundException($"No stream with id '{streamId}' could be found.");
} }
return false;
} }
public Task ProcessItem(StreamItemMessage message) public bool TryProcessItem(StreamItemMessage message, out Task task)
{ {
return TryGetConverter(message.InvocationId).WriteToStream(message.Item); if (TryGetConverter(message.InvocationId, out var converter))
{
task = converter.WriteToStream(message.Item);
return true;
}
task = default;
return false;
} }
public Type GetStreamItemType(string streamId) public Type GetStreamItemType(string streamId)
{ {
return TryGetConverter(streamId).GetItemType(); if (TryGetConverter(streamId, out var converter))
{
return converter.GetItemType();
}
throw new KeyNotFoundException($"No stream with id '{streamId}' could be found.");
} }
public void Complete(CompletionMessage message) public bool TryComplete(CompletionMessage message)
{ {
_lookup.TryRemove(message.InvocationId, out var converter); _lookup.TryRemove(message.InvocationId, out var converter);
if (converter == null) if (converter == null)
{ {
throw new KeyNotFoundException($"No stream with id '{message.InvocationId}' could be found."); return false;
} }
converter.TryComplete(message.HasResult || message.Error == null ? null : new Exception(message.Error)); converter.TryComplete(message.HasResult || message.Error == null ? null : new Exception(message.Error));
return true;
} }
private static IStreamConverter BuildStream<T>() private static IStreamConverter BuildStream<T>()

View File

@ -2918,13 +2918,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
[Fact] [Fact]
public async Task UploadStreamItemInvalidId() public async Task UploadStreamItemInvalidId()
{ {
bool ExpectedErrors(WriteContext writeContext) using (StartVerifiableLog())
{
return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" &&
writeContext.EventId.Name == "ErrorProcessingRequest";
}
using (StartVerifiableLog(ExpectedErrors))
{ {
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{ {
@ -2937,24 +2931,19 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.SendHubMessageAsync(new StreamItemMessage("fake_id", "not a number")).OrTimeout(); await client.SendHubMessageAsync(new StreamItemMessage("fake_id", "not a number")).OrTimeout();
// Client is breaking protocol by sending an invalid id, and should be closed.
var message = client.TryRead(); var message = client.TryRead();
Assert.IsType<CloseMessage>(message); Assert.Null(message);
Assert.Equal("Connection closed with an error. KeyNotFoundException: No stream with id 'fake_id' could be found.", ((CloseMessage)message).Error);
} }
} }
Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" &&
w.EventId.Name == "ClosingStreamWithBindingError"));
} }
[Fact] [Fact]
public async Task UploadStreamCompleteInvalidId() public async Task UploadStreamCompleteInvalidId()
{ {
bool ExpectedErrors(WriteContext writeContext) using (StartVerifiableLog())
{
return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" &&
writeContext.EventId.Name == "ErrorProcessingRequest";
}
using (StartVerifiableLog(ExpectedErrors))
{ {
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{ {
@ -2967,12 +2956,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.SendHubMessageAsync(CompletionMessage.Empty("fake_id")).OrTimeout(); await client.SendHubMessageAsync(CompletionMessage.Empty("fake_id")).OrTimeout();
// Client is breaking protocol by sending an invalid id, and should be closed.
var message = client.TryRead(); var message = client.TryRead();
Assert.IsType<CloseMessage>(message); Assert.Null(message);
Assert.Equal("Connection closed with an error. KeyNotFoundException: No stream with id 'fake_id' could be found.", ((CloseMessage)message).Error);
} }
} }
Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" &&
w.EventId.Name == "UnexpectedStreamCompletion"));
} }
public static string CustomErrorMessage = "custom error for testing ::::)"; public static string CustomErrorMessage = "custom error for testing ::::)";
@ -3088,20 +3078,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
[Fact] [Fact]
public async Task UploadStreamClosesStreamsOnServerWhenMethodCompletes() public async Task UploadStreamClosesStreamsOnServerWhenMethodCompletes()
{ {
bool errorLogged = false; using (StartVerifiableLog())
bool ExpectedErrors(WriteContext writeContext)
{
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" &&
writeContext.EventId.Name == "ErrorProcessingRequest")
{
errorLogged = true;
return true;
}
return false;
}
using (StartVerifiableLog(ExpectedErrors))
{ {
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>(); var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
@ -3118,9 +3095,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var simpleCompletion = Assert.IsType<CompletionMessage>(result); var simpleCompletion = Assert.IsType<CompletionMessage>(result);
Assert.Null(simpleCompletion.Result); Assert.Null(simpleCompletion.Result);
// This will log an error on the server as the hub method has completed and will complete all associated streams // This will log a warning on the server as the hub method has completed and will complete all associated streams
await client.SendHubMessageAsync(new StreamItemMessage("id", "error!")).OrTimeout(); await client.SendHubMessageAsync(new StreamItemMessage("id", "error!")).OrTimeout();
// Check that the connection hasn't been closed
await client.SendInvocationAsync("VoidMethod").OrTimeout();
// Shut down // Shut down
client.Dispose(); client.Dispose();
@ -3128,27 +3108,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
} }
} }
// Check that the stream has been completed by noting the existance of an error Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" &&
Assert.True(errorLogged); w.EventId.Name == "ClosingStreamWithBindingError"));
} }
[Fact] [Fact]
public async Task UploadStreamAndStreamingMethodClosesStreamsOnServerWhenMethodCompletes() public async Task UploadStreamAndStreamingMethodClosesStreamsOnServerWhenMethodCompletes()
{ {
bool errorLogged = false; using (StartVerifiableLog())
bool ExpectedErrors(WriteContext writeContext)
{
if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" &&
writeContext.EventId.Name == "ErrorProcessingRequest")
{
errorLogged = true;
return true;
}
return false;
}
using (StartVerifiableLog(ExpectedErrors))
{ {
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>(); var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
@ -3165,9 +3132,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var simpleCompletion = Assert.IsType<CompletionMessage>(result); var simpleCompletion = Assert.IsType<CompletionMessage>(result);
Assert.Null(simpleCompletion.Result); Assert.Null(simpleCompletion.Result);
// This will log an error on the server as the hub method has completed and will complete all associated streams // This will log a warning on the server as the hub method has completed and will complete all associated streams
await client.SendHubMessageAsync(new StreamItemMessage("id", "error!")).OrTimeout(); await client.SendHubMessageAsync(new StreamItemMessage("id", "error!")).OrTimeout();
// Check that the connection hasn't been closed
await client.SendInvocationAsync("VoidMethod").OrTimeout();
// Shut down // Shut down
client.Dispose(); client.Dispose();
@ -3175,8 +3145,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
} }
} }
// Check that the stream has been completed by noting the existance of an error Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" &&
Assert.True(errorLogged); w.EventId.Name == "ClosingStreamWithBindingError"));
} }
[Theory] [Theory]