Call OnStarting and OnCompleted callbacks in LIFO order (#1042).

This commit is contained in:
Cesar Blum Silveira 2016-08-10 11:09:24 -07:00
parent 5181e4196c
commit 08a91f17bf
2 changed files with 103 additions and 12 deletions

View File

@ -47,9 +47,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
protected bool _requestRejected;
private Streams _frameStreams;
protected List<KeyValuePair<Func<object, Task>, object>> _onStarting;
protected List<KeyValuePair<Func<object, Task>, object>> _onCompleted;
protected Stack<KeyValuePair<Func<object, Task>, object>> _onStarting;
protected Stack<KeyValuePair<Func<object, Task>, object>> _onCompleted;
private Task _requestProcessingTask;
protected volatile bool _requestProcessingStopping; // volatile, see: https://msdn.microsoft.com/en-us/library/x13ttww7.aspx
@ -395,9 +394,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
if (_onStarting == null)
{
_onStarting = new List<KeyValuePair<Func<object, Task>, object>>();
_onStarting = new Stack<KeyValuePair<Func<object, Task>, object>>();
}
_onStarting.Add(new KeyValuePair<Func<object, Task>, object>(callback, state));
_onStarting.Push(new KeyValuePair<Func<object, Task>, object>(callback, state));
}
}
@ -407,15 +406,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
{
if (_onCompleted == null)
{
_onCompleted = new List<KeyValuePair<Func<object, Task>, object>>();
_onCompleted = new Stack<KeyValuePair<Func<object, Task>, object>>();
}
_onCompleted.Add(new KeyValuePair<Func<object, Task>, object>(callback, state));
_onCompleted.Push(new KeyValuePair<Func<object, Task>, object>(callback, state));
}
}
protected async Task FireOnStarting()
{
List<KeyValuePair<Func<object, Task>, object>> onStarting = null;
Stack<KeyValuePair<Func<object, Task>, object>> onStarting = null;
lock (_onStartingSync)
{
onStarting = _onStarting;
@ -439,7 +438,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
protected async Task FireOnCompleted()
{
List<KeyValuePair<Func<object, Task>, object>> onCompleted = null;
Stack<KeyValuePair<Func<object, Task>, object>> onCompleted = null;
lock (_onCompletedSync)
{
onCompleted = _onCompleted;

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Sockets;
@ -833,9 +834,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
"",
"");
Assert.Equal(2, onStartingCallCount1);
// The second OnStarting callback should not be called since the first failed.
Assert.Equal(0, onStartingCallCount2);
Assert.Equal(2, onStartingCallCount2);
// The first registered OnStarting callback should not be called,
// since they are called LIFO and the other one failed.
Assert.Equal(0, onStartingCallCount1);
Assert.Equal(2, testLogger.ApplicationErrorsLogged);
}
}
@ -1199,5 +1203,93 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task OnStartingCallbacksAreCalledInLastInFirstOutOrder(TestServiceContext testContext)
{
const string response = "hello, world";
var callOrder = new Stack<int>();
using (var server = new TestServer(async context =>
{
context.Response.OnStarting(_ =>
{
callOrder.Push(1);
return TaskUtilities.CompletedTask;
}, null);
context.Response.OnStarting(_ =>
{
callOrder.Push(2);
return TaskUtilities.CompletedTask;
}, null);
context.Response.ContentLength = response.Length;
await context.Response.WriteAsync(response);
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.SendEnd(
"GET / HTTP/1.1",
"",
"");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
$"Date: {testContext.DateHeaderValue}",
$"Content-Length: {response.Length}",
"",
"hello, world");
Assert.Equal(1, callOrder.Pop());
Assert.Equal(2, callOrder.Pop());
}
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task OnCompletedCallbacksAreCalledInLastInFirstOutOrder(TestServiceContext testContext)
{
const string response = "hello, world";
var callOrder = new Stack<int>();
using (var server = new TestServer(async context =>
{
context.Response.OnCompleted(_ =>
{
callOrder.Push(1);
return TaskUtilities.CompletedTask;
}, null);
context.Response.OnCompleted(_ =>
{
callOrder.Push(2);
return TaskUtilities.CompletedTask;
}, null);
context.Response.ContentLength = response.Length;
await context.Response.WriteAsync(response);
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.SendEnd(
"GET / HTTP/1.1",
"",
"");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
$"Date: {testContext.DateHeaderValue}",
$"Content-Length: {response.Length}",
"",
"hello, world");
Assert.Equal(1, callOrder.Pop());
Assert.Equal(2, callOrder.Pop());
}
}
}
}
}