diff --git a/src/Microsoft.AspNetCore.Http/HttpContextAccessor.cs b/src/Microsoft.AspNetCore.Http/HttpContextAccessor.cs index 5a4676234c..897c27f734 100644 --- a/src/Microsoft.AspNetCore.Http/HttpContextAccessor.cs +++ b/src/Microsoft.AspNetCore.Http/HttpContextAccessor.cs @@ -7,17 +7,20 @@ namespace Microsoft.AspNetCore.Http { public class HttpContextAccessor : IHttpContextAccessor { - private static AsyncLocal _httpContextCurrent = new AsyncLocal(); + private static AsyncLocal<(string traceIdentifier, HttpContext context)> _httpContextCurrent = new AsyncLocal<(string traceIdentifier, HttpContext context)>(); public HttpContext HttpContext { get { - return _httpContextCurrent.Value; + var value = _httpContextCurrent.Value; + // Only return the context if the stored request id matches the stored trace identifier + // context.TraceIdentifier is cleared by HttpContextFactory.Dispose. + return value.traceIdentifier == value.context?.TraceIdentifier ? value.context : null; } set { - _httpContextCurrent.Value = value; + _httpContextCurrent.Value = (value?.TraceIdentifier, value); } } } diff --git a/src/Microsoft.AspNetCore.Http/HttpContextFactory.cs b/src/Microsoft.AspNetCore.Http/HttpContextFactory.cs index 8236a388a5..c793ba402e 100644 --- a/src/Microsoft.AspNetCore.Http/HttpContextFactory.cs +++ b/src/Microsoft.AspNetCore.Http/HttpContextFactory.cs @@ -53,6 +53,10 @@ namespace Microsoft.AspNetCore.Http { _httpContextAccessor.HttpContext = null; } + + // Null out the TraceIdentifier here as a sign that this request is done, + // the HttpContextAcessor implementation relies on this to detect that the request is over + httpContext.TraceIdentifier = null; } } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Http.Tests/HttpContextAccessorTests.cs b/test/Microsoft.AspNetCore.Http.Tests/HttpContextAccessorTests.cs new file mode 100644 index 0000000000..c1521b1bc3 --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Tests/HttpContextAccessorTests.cs @@ -0,0 +1,197 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.WebSockets; +using System.Reflection; +using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Xunit; + +namespace Microsoft.AspNetCore.Http +{ + public class HttpContextAccessorTests + { + [Fact] + public async Task HttpContextAccessor_GettingHttpContextReturnsHttpContext() + { + var accessor = new HttpContextAccessor(); + + var context = new DefaultHttpContext(); + context.TraceIdentifier = "1"; + accessor.HttpContext = context; + + await Task.Delay(100); + + Assert.Same(context, accessor.HttpContext); + } + + [Fact] + public void HttpContextAccessor_GettingHttpContextWithOutSettingReturnsNull() + { + var accessor = new HttpContextAccessor(); + + Assert.Null(accessor.HttpContext); + } + + [Fact] + public async Task HttpContextAccessor_GettingHttpContextReturnsNullHttpContextIfSetToNull() + { + var accessor = new HttpContextAccessor(); + + var context = new DefaultHttpContext(); + context.TraceIdentifier = "1"; + accessor.HttpContext = context; + + var checkAsyncFlowTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var waitForNullTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var afterNullCheckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + ThreadPool.QueueUserWorkItem(async _ => + { + // The HttpContext flows with the execution context + Assert.Same(context, accessor.HttpContext); + + checkAsyncFlowTcs.SetResult(null); + + await waitForNullTcs.Task; + + try + { + Assert.Null(accessor.HttpContext); + + afterNullCheckTcs.SetResult(null); + } + catch (Exception ex) + { + afterNullCheckTcs.SetException(ex); + } + }); + + await checkAsyncFlowTcs.Task; + + // Null out the accessor + accessor.HttpContext = null; + context.TraceIdentifier = null; + + waitForNullTcs.SetResult(null); + + Assert.Null(accessor.HttpContext); + + await afterNullCheckTcs.Task; + } + + [Fact] + public async Task HttpContextAccessor_GettingHttpContextReturnsNullHttpContextIfDifferentTraceIdentifier() + { + var accessor = new HttpContextAccessor(); + + var context = new DefaultHttpContext(); + context.TraceIdentifier = "1"; + accessor.HttpContext = context; + + var checkAsyncFlowTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var waitForNullTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var afterNullCheckTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + ThreadPool.QueueUserWorkItem(async _ => + { + // The HttpContext flows with the execution context + Assert.Same(context, accessor.HttpContext); + + checkAsyncFlowTcs.SetResult(null); + + await waitForNullTcs.Task; + + try + { + Assert.Null(accessor.HttpContext); + + afterNullCheckTcs.SetResult(null); + } + catch (Exception ex) + { + afterNullCheckTcs.SetException(ex); + } + }); + + await checkAsyncFlowTcs.Task; + + // Reset the trace identifier on the first request + context.TraceIdentifier = null; + + // Set a new http context + var context2 = new DefaultHttpContext(); + context2.TraceIdentifier = "2"; + accessor.HttpContext = context2; + + waitForNullTcs.SetResult(null); + + Assert.Same(context2, accessor.HttpContext); + + await afterNullCheckTcs.Task; + } + + [Fact] + public async Task HttpContextAccessor_GettingHttpContextDoesNotFlowIfAccessorSetToNull() + { + var accessor = new HttpContextAccessor(); + + var context = new DefaultHttpContext(); + context.TraceIdentifier = "1"; + accessor.HttpContext = context; + + var checkAsyncFlowTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + accessor.HttpContext = null; + + ThreadPool.QueueUserWorkItem(_ => + { + try + { + // The HttpContext flows with the execution context + Assert.Null(accessor.HttpContext); + checkAsyncFlowTcs.SetResult(null); + } + catch (Exception ex) + { + checkAsyncFlowTcs.SetException(ex); + } + }); + + await checkAsyncFlowTcs.Task; + } + + [Fact] + public async Task HttpContextAccessor_GettingHttpContextDoesNotFlowIfExecutionContextDoesNotFlow() + { + var accessor = new HttpContextAccessor(); + + var context = new DefaultHttpContext(); + context.TraceIdentifier = "1"; + accessor.HttpContext = context; + + var checkAsyncFlowTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + ThreadPool.UnsafeQueueUserWorkItem(_ => + { + try + { + // The HttpContext flows with the execution context + Assert.Null(accessor.HttpContext); + checkAsyncFlowTcs.SetResult(null); + } + catch (Exception ex) + { + checkAsyncFlowTcs.SetException(ex); + } + }, null); + + await checkAsyncFlowTcs.Task; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Http.Tests/HttpContextFactoryTests.cs b/test/Microsoft.AspNetCore.Http.Tests/HttpContextFactoryTests.cs index ba983198e7..80e421273a 100644 --- a/test/Microsoft.AspNetCore.Http.Tests/HttpContextFactoryTests.cs +++ b/test/Microsoft.AspNetCore.Http.Tests/HttpContextFactoryTests.cs @@ -22,7 +22,27 @@ namespace Microsoft.AspNetCore.Http var context = contextFactory.Create(new FeatureCollection()); // Assert - Assert.True(ReferenceEquals(context, accessor.HttpContext)); + Assert.Same(context, accessor.HttpContext); + } + + [Fact] + public void DisposeHttpContextSetsHttpContextAccessorToNull() + { + // Arrange + var accessor = new HttpContextAccessor(); + var contextFactory = new HttpContextFactory(Options.Create(new FormOptions()), accessor); + + // Act + var context = contextFactory.Create(new FeatureCollection()); + var traceIdentifier = context.TraceIdentifier; + + // Assert + Assert.Same(context, accessor.HttpContext); + + contextFactory.Dispose(context); + + Assert.Null(accessor.HttpContext); + Assert.NotEqual(traceIdentifier, context.TraceIdentifier); } [Fact]