diff --git a/src/Microsoft.AspNetCore.Http.Abstractions/Routing/EndpointHttpContextExtensions.cs b/src/Microsoft.AspNetCore.Http.Abstractions/Routing/EndpointHttpContextExtensions.cs new file mode 100644 index 0000000000..bf4a5ed046 --- /dev/null +++ b/src/Microsoft.AspNetCore.Http.Abstractions/Routing/EndpointHttpContextExtensions.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.Http +{ + /// + /// Extension methods to expose Endpoint on HttpContext. + /// + public static class EndpointHttpContextExtensions + { + /// + /// Extension method for getting the for the current request. + /// + /// The context. + /// The . + public static Endpoint GetEndpoint(this HttpContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + return context.Features.Get()?.Endpoint; + } + + /// + /// Extension method for setting the for the current request. + /// + /// The context. + /// The . + public static void SetEndpoint(this HttpContext context, Endpoint endpoint) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var feature = context.Features.Get(); + + if (endpoint != null) + { + if (feature == null) + { + feature = new EndpointFeature(); + context.Features.Set(feature); + } + + feature.Endpoint = endpoint; + } + else + { + if (feature == null) + { + // No endpoint to set and no feature on context. Do nothing + return; + } + + feature.Endpoint = null; + } + } + + private class EndpointFeature : IEndpointFeature + { + public Endpoint Endpoint { get; set; } + } + } +} diff --git a/test/Microsoft.AspNetCore.Http.Abstractions.Tests/EndpointHttpContextExtensionsTests.cs b/test/Microsoft.AspNetCore.Http.Abstractions.Tests/EndpointHttpContextExtensionsTests.cs new file mode 100644 index 0000000000..c34f06f380 --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Abstractions.Tests/EndpointHttpContextExtensionsTests.cs @@ -0,0 +1,155 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; +using Xunit; + +namespace Microsoft.AspNetCore.Http.Abstractions.Tests +{ + public class EndpointHttpContextExtensionsTests + { + [Fact] + public void GetEndpoint_ContextWithoutFeature_ReturnsNull() + { + // Arrange + var context = new DefaultHttpContext(); + + // Act + var endpoint = context.GetEndpoint(); + + // Assert + Assert.Null(endpoint); + } + + [Fact] + public void GetEndpoint_ContextWithFeatureAndNullEndpoint_ReturnsNull() + { + // Arrange + var context = new DefaultHttpContext(); + context.Features.Set(new EndpointFeature + { + Endpoint = null + }); + + // Act + var endpoint = context.GetEndpoint(); + + // Assert + Assert.Null(endpoint); + } + + [Fact] + public void GetEndpoint_ContextWithFeatureAndEndpoint_ReturnsNull() + { + // Arrange + var context = new DefaultHttpContext(); + var initial = new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint"); + context.Features.Set(new EndpointFeature + { + Endpoint = initial + }); + + // Act + var endpoint = context.GetEndpoint(); + + // Assert + Assert.Equal(initial, endpoint); + } + + [Fact] + public void SetEndpoint_NullOnContextWithoutFeature_NoFeatureSet() + { + // Arrange + var context = new DefaultHttpContext(); + + // Act + context.SetEndpoint(null); + + // Assert + Assert.Null(context.Features.Get()); + } + + [Fact] + public void SetEndpoint_EndpointOnContextWithoutFeature_FeatureWithEndpointSet() + { + // Arrange + var context = new DefaultHttpContext(); + + // Act + var endpoint = new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint"); + context.SetEndpoint(endpoint); + + // Assert + var feature = context.Features.Get(); + Assert.NotNull(feature); + Assert.Equal(endpoint, feature.Endpoint); + } + + [Fact] + public void SetEndpoint_EndpointOnContextWithFeature_EndpointSetOnExistingFeature() + { + // Arrange + var context = new DefaultHttpContext(); + var initialEndpoint = new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint"); + var initialFeature = new EndpointFeature + { + Endpoint = initialEndpoint + }; + context.Features.Set(initialFeature); + + // Act + var endpoint = new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint"); + context.SetEndpoint(endpoint); + + // Assert + var feature = context.Features.Get(); + Assert.Equal(initialFeature, feature); + Assert.Equal(endpoint, feature.Endpoint); + } + + [Fact] + public void SetEndpoint_NullOnContextWithFeature_NullSetOnExistingFeature() + { + // Arrange + var context = new DefaultHttpContext(); + var initialEndpoint = new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint"); + var initialFeature = new EndpointFeature + { + Endpoint = initialEndpoint + }; + context.Features.Set(initialFeature); + + // Act + context.SetEndpoint(null); + + // Assert + var feature = context.Features.Get(); + Assert.Equal(initialFeature, feature); + Assert.Null(feature.Endpoint); + } + + [Fact] + public void SetAndGetEndpoint_Roundtrip_EndpointIsRoundtrip() + { + // Arrange + var context = new DefaultHttpContext(); + var initialEndpoint = new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint"); + + // Act + context.SetEndpoint(initialEndpoint); + var endpoint = context.GetEndpoint(); + + // Assert + Assert.Equal(initialEndpoint, endpoint); + } + + private class EndpointFeature : IEndpointFeature + { + public Endpoint Endpoint { get; set; } + } + } +}