Fix null reference exception for Streaming null object (#14004)

This commit is contained in:
Kahbazi 2019-09-27 01:39:42 +03:30 committed by Brennan
parent 2359634909
commit b44e9c6a24
4 changed files with 92 additions and 3 deletions

View File

@ -265,7 +265,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
{
if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer &&
descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType()))
(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer] == null ||
descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType())))
{
// The types match so it isn't a synthetic argument, just copy it into the arguments array
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];

View File

@ -933,6 +933,36 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return channel.Reader;
}
public ChannelReader<int> CancelableStreamNullableParameter(int x, string y, CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);
Task.Run(async () =>
{
_tcsService.StartedMethod.SetResult(x);
await token.WaitForCancellationAsync();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(y);
});
return channel.Reader;
}
public ChannelReader<int> StreamNullableParameter(int x, int? input)
{
var channel = Channel.CreateBounded<int>(10);
Task.Run(() =>
{
_tcsService.StartedMethod.SetResult(x);
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(input);
return Task.CompletedTask;
});
return channel.Reader;
}
public ChannelReader<int> CancelableStreamMiddleParameter(int ignore, CancellationToken token, int ignore2)
{
var channel = Channel.CreateBounded<int>(10);

View File

@ -3591,6 +3591,64 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task StreamHubMethodCanAcceptNullableParameter()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.StreamNullableParameter), 5, null).OrTimeout();
// Wait for the stream method to start
var firstArgument = await tcsService.StartedMethod.Task.OrTimeout();
Assert.Equal(5, firstArgument);
var secondArgument = await tcsService.EndMethod.Task.OrTimeout();
Assert.Null(secondArgument);
}
}
}
[Fact]
public async Task StreamHubMethodCanAcceptNullableParameterWithCancellationToken()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStreamNullableParameter), 5, null).OrTimeout();
// Wait for the stream method to start
var firstArgument = await tcsService.StartedMethod.Task.OrTimeout();
Assert.Equal(5, firstArgument);
// Cancel the stream which should trigger the CancellationToken in the hub method
await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout();
var secondArgument = await tcsService.EndMethod.Task.OrTimeout();
Assert.Null(secondArgument);
}
}
}
[Fact]
public async Task InvokeHubMethodCannotAcceptCancellationTokenAsArgument()
{

View File

@ -92,7 +92,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Internal
send1 =>
{
Assert.Equal("Method", send1.Method);
Assert.Equal(1, send1.Arguments.Length);
Assert.Single(send1.Arguments);
Assert.Collection(send1.Arguments,
arg1 => Assert.Equal("foo", arg1));
Assert.Equal(cts1.Token, send1.CancellationToken);
@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Internal
send2 =>
{
Assert.Equal("NoArgumentMethod", send2.Method);
Assert.Equal(0, send2.Arguments.Length);
Assert.Empty(send2.Arguments);
Assert.Equal(cts2.Token, send2.CancellationToken);
send2.Complete();
});