diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index 0266660792..62b995e855 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -8,7 +8,6 @@ using System.IO; using System.IO.Pipelines; using System.Linq; using System.Net; -using System.Net.Http.Headers; using System.Runtime.CompilerServices; using System.Text; using System.Threading; @@ -609,9 +608,21 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private async Task ProcessRequests(IHttpApplication application) { - var cleanContext = ExecutionContext.Capture(); while (_keepAlive) { + if (_context.InitialExecutionContext is null) + { + // If this is a first request on a non-Http2Connection, capture a clean ExecutionContext. + _context.InitialExecutionContext = ExecutionContext.Capture(); + } + else + { + // Clear any AsyncLocals set during the request; back to a clean state ready for next request + // And/or reset to Http2Connection's ExecutionContext giving access to the connection logging scope + // and any other AsyncLocals set by connection middleware. + ExecutionContext.Restore(_context.InitialExecutionContext); + } + BeginRequestProcessing(); var result = default(ReadResult); @@ -737,9 +748,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { await messageBody.StopAsync(); } - - // Clear any AsyncLocals set during the request; back to a clean state ready for next request - ExecutionContext.Restore(cleanContext); } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs index b63f4a4e7c..86b13a09c4 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Connection.cs @@ -36,7 +36,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 private readonly HttpConnectionContext _context; private readonly Http2FrameWriter _frameWriter; private readonly Pipe _input; - private Task _inputTask; + private readonly Task _inputTask; private readonly int _minAllocBufferSize; private readonly HPackDecoder _hpackDecoder; private readonly InputFlowControl _inputFlowControl; @@ -85,6 +85,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 _context = context; + // Capture the ExecutionContext before dispatching HTTP/2 middleware. Will be restored by streams when processing request + _context.InitialExecutionContext = ExecutionContext.Capture(); + _frameWriter = new Http2FrameWriter( context.Transport.Output, context.ConnectionContext, @@ -647,6 +650,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 ConnectionInputFlowControl = _inputFlowControl, ConnectionOutputFlowControl = _outputFlowControl, TimeoutControl = TimeoutControl, + InitialExecutionContext = _context.InitialExecutionContext, }; } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamOfT.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamOfT.cs index 544013c7cd..8437722cc7 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamOfT.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2StreamOfT.cs @@ -1,7 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Hosting.Server.Abstractions; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; diff --git a/src/Servers/Kestrel/Core/src/Internal/HttpConnectionContext.cs b/src/Servers/Kestrel/Core/src/Internal/HttpConnectionContext.cs index 69c46ec79c..a1bc6f3f21 100644 --- a/src/Servers/Kestrel/Core/src/Internal/HttpConnectionContext.cs +++ b/src/Servers/Kestrel/Core/src/Internal/HttpConnectionContext.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.IO.Pipelines; using System.Net; +using System.Threading; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; @@ -22,5 +23,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal public IPEndPoint RemoteEndPoint { get; set; } public ITimeoutControl TimeoutControl { get; set; } public IDuplexPipe Transport { get; set; } + public ExecutionContext InitialExecutionContext { get; set; } } } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2EndToEndTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2EndToEndTests.cs new file mode 100644 index 0000000000..87d4b3dd61 --- /dev/null +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2EndToEndTests.cs @@ -0,0 +1,143 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; +using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; +using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.Http2 +{ + public class Http2EndToEndTests : TestApplicationErrorLoggerLoggedTest + { + [Fact] + public async Task MiddlewareIsRunWithConnectionLoggingScopeForHttp2Requests() + { + var expectedLogMessage = "Log from connection scope!"; + string connectionIdFromFeature = null; + + var mockScopeLoggerProvider = new MockScopeLoggerProvider(expectedLogMessage); + LoggerFactory.AddProvider(mockScopeLoggerProvider); + + await using var server = new TestServer(async context => + { + connectionIdFromFeature = context.Features.Get().ConnectionId; + + var logger = context.RequestServices.GetRequiredService>(); + logger.LogInformation(expectedLogMessage); + + await context.Response.WriteAsync("hello, world"); + }, + new TestServiceContext(LoggerFactory), + listenOptions => + { + listenOptions.Protocols = HttpProtocols.Http2; + }); + + var connectionCount = 0; + using var connection = server.CreateConnection(); + + using var socketsHandler = new SocketsHttpHandler() + { + ConnectCallback = (_, _) => + { + if (connectionCount != 0) + { + throw new InvalidOperationException(); + } + + connectionCount++; + return new ValueTask(connection.Stream); + }, + }; + + using var httpClient = new HttpClient(socketsHandler); + + using var httpRequsetMessage = new HttpRequestMessage() + { + RequestUri = new Uri("http://localhost/"), + Version = new Version(2, 0), + VersionPolicy = HttpVersionPolicy.RequestVersionExact, + }; + + using var responseMessage = await httpClient.SendAsync(httpRequsetMessage); + + Assert.Equal("hello, world", await responseMessage.Content.ReadAsStringAsync()); + + Assert.NotNull(connectionIdFromFeature); + Assert.NotNull(mockScopeLoggerProvider.ConnectionLogScope); + Assert.Equal(connectionIdFromFeature, mockScopeLoggerProvider.ConnectionLogScope[0].Value); + } + + private class MockScopeLoggerProvider : ILoggerProvider, ISupportExternalScope + { + private readonly string _expectedLogMessage; + private IExternalScopeProvider _scopeProvider; + + public MockScopeLoggerProvider(string expectedLogMessage) + { + _expectedLogMessage = expectedLogMessage; + } + + public ConnectionLogScope ConnectionLogScope { get; private set; } + + public ILogger CreateLogger(string categoryName) + { + return new MockScopeLogger(this); + } + + public void SetScopeProvider(IExternalScopeProvider scopeProvider) + { + _scopeProvider = scopeProvider; + } + + public void Dispose() + { + } + + private class MockScopeLogger : ILogger + { + private readonly MockScopeLoggerProvider _loggerProvider; + + public MockScopeLogger(MockScopeLoggerProvider parent) + { + _loggerProvider = parent; + } + + public IDisposable BeginScope(TState state) + { + return _loggerProvider._scopeProvider?.Push(state); + } + + public bool IsEnabled(LogLevel logLevel) + { + return true; + } + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) + { + if (formatter(state, exception) != _loggerProvider._expectedLogMessage) + { + return; + } + + _loggerProvider._scopeProvider?.ForEachScope( + (scopeObject, loggerPovider) => + { + loggerPovider.ConnectionLogScope ??= scopeObject as ConnectionLogScope; + }, + _loggerProvider); + } + } + } + } +}