diff --git a/src/Microsoft.AspNetCore.HttpsPolicy/HstsMiddleware.cs b/src/Microsoft.AspNetCore.HttpsPolicy/HstsMiddleware.cs index 252ae44c1e..da5aa3af4b 100644 --- a/src/Microsoft.AspNetCore.HttpsPolicy/HstsMiddleware.cs +++ b/src/Microsoft.AspNetCore.HttpsPolicy/HstsMiddleware.cs @@ -6,6 +6,8 @@ using System.Collections.Generic; using System.Globalization; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.HttpsPolicy.Internal; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; @@ -24,13 +26,15 @@ namespace Microsoft.AspNetCore.HttpsPolicy private readonly RequestDelegate _next; private readonly StringValues _strictTransportSecurityValue; private readonly IList _excludedHosts; + private readonly ILogger _logger; /// /// Initialize the HSTS middleware. /// /// /// - public HstsMiddleware(RequestDelegate next, IOptions options) + /// + public HstsMiddleware(RequestDelegate next, IOptions options, ILoggerFactory loggerFactory) { if (options == null) { @@ -46,6 +50,7 @@ namespace Microsoft.AspNetCore.HttpsPolicy var preload = hstsOptions.Preload ? Preload : StringSegment.Empty; _strictTransportSecurityValue = new StringValues($"max-age={maxAge}{includeSubdomains}{preload}"); _excludedHosts = hstsOptions.ExcludedHosts; + _logger = loggerFactory.CreateLogger(); } /// @@ -55,11 +60,21 @@ namespace Microsoft.AspNetCore.HttpsPolicy /// public Task Invoke(HttpContext context) { - if (context.Request.IsHttps && !IsHostExcluded(context.Request.Host.Host)) + if (!context.Request.IsHttps) { - context.Response.Headers[HeaderNames.StrictTransportSecurity] = _strictTransportSecurityValue; + _logger.SkippingInsecure(); + return _next(context); } + if (IsHostExcluded(context.Request.Host.Host)) + { + _logger.SkippingExcludedHost(context.Request.Host.Host); + return _next(context); + } + + context.Response.Headers[HeaderNames.StrictTransportSecurity] = _strictTransportSecurityValue; + _logger.AddingHstsHeader(); + return _next(context); } diff --git a/src/Microsoft.AspNetCore.HttpsPolicy/internal/HstsLoggingExtensions.cs b/src/Microsoft.AspNetCore.HttpsPolicy/internal/HstsLoggingExtensions.cs new file mode 100644 index 0000000000..5162ccb9f5 --- /dev/null +++ b/src/Microsoft.AspNetCore.HttpsPolicy/internal/HstsLoggingExtensions.cs @@ -0,0 +1,37 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.HttpsPolicy.Internal +{ + internal static class HstsLoggingExtensions + { + private static readonly Action _notSecure; + private static readonly Action _excludedHost; + private static readonly Action _addingHstsHeader; + + static HstsLoggingExtensions() + { + _notSecure = LoggerMessage.Define(LogLevel.Debug, 1, "The request is insecure. Skipping HSTS header."); + _excludedHost = LoggerMessage.Define(LogLevel.Debug, 2, "The host '{host}' is excluded. Skipping HSTS header."); + _addingHstsHeader = LoggerMessage.Define(LogLevel.Trace, 3, "Adding HSTS header to response."); + } + + public static void SkippingInsecure(this ILogger logger) + { + _notSecure(logger, null); + } + + public static void SkippingExcludedHost(this ILogger logger, string host) + { + _excludedHost(logger, host, null); + } + + public static void AddingHstsHeader(this ILogger logger) + { + _addingHstsHeader(logger, null); + } + } +} diff --git a/test/Microsoft.AspNetCore.HttpsPolicy.Tests/HstsMiddlewareTests.cs b/test/Microsoft.AspNetCore.HttpsPolicy.Tests/HstsMiddlewareTests.cs index 08df78f7c2..0cb5f5755c 100644 --- a/test/Microsoft.AspNetCore.HttpsPolicy.Tests/HstsMiddlewareTests.cs +++ b/test/Microsoft.AspNetCore.HttpsPolicy.Tests/HstsMiddlewareTests.cs @@ -12,6 +12,8 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Microsoft.Net.Http.Headers; using Xunit; @@ -131,7 +133,16 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests [InlineData("[::1]")] public async Task DefaultExcludesCommonLocalhostDomains_DoesNotSetHstsHeader(string host) { + var sink = new TestSink( + TestSink.EnableWithTypeName, + TestSink.EnableWithTypeName); + var loggerFactory = new TestLoggerFactory(sink, enabled: true); + var builder = new WebHostBuilder() + .ConfigureServices(services => + { + services.AddSingleton(loggerFactory); + }) .Configure(app => { app.UseHsts(); @@ -149,6 +160,13 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests Assert.Equal(HttpStatusCode.OK, response.StatusCode); Assert.Empty(response.Headers); + + var logMessages = sink.Writes.ToList(); + + Assert.Single(logMessages); + var message = logMessages.Single(); + Assert.Equal(LogLevel.Debug, message.LogLevel); + Assert.Equal($"The host '{host}' is excluded. Skipping HSTS header.", message.State.ToString(), ignoreCase: true); } [Theory] @@ -157,9 +175,16 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests [InlineData("[::1]")] public async Task AllowLocalhostDomainsIfListIsReset_SetHstsHeader(string host) { + var sink = new TestSink( + TestSink.EnableWithTypeName, + TestSink.EnableWithTypeName); + var loggerFactory = new TestLoggerFactory(sink, enabled: true); + var builder = new WebHostBuilder() .ConfigureServices(services => { + services.AddSingleton(loggerFactory); + services.AddHsts(options => { options.ExcludedHosts.Clear(); @@ -182,6 +207,13 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests Assert.Equal(HttpStatusCode.OK, response.StatusCode); Assert.Single(response.Headers); + + var logMessages = sink.Writes.ToList(); + + Assert.Single(logMessages); + var message = logMessages.Single(); + Assert.Equal(LogLevel.Trace, message.LogLevel); + Assert.Equal("Adding HSTS header to response.", message.State.ToString()); } [Theory] @@ -190,9 +222,16 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests [InlineData("EXAMPLE.COM")] public async Task AddExcludedDomains_DoesNotAddHstsHeader(string host) { + var sink = new TestSink( + TestSink.EnableWithTypeName, + TestSink.EnableWithTypeName); + var loggerFactory = new TestLoggerFactory(sink, enabled: true); + var builder = new WebHostBuilder() .ConfigureServices(services => { + services.AddSingleton(loggerFactory); + services.AddHsts(options => { options.ExcludedHosts.Add(host); }); @@ -214,6 +253,91 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests Assert.Equal(HttpStatusCode.OK, response.StatusCode); Assert.Empty(response.Headers); + + var logMessages = sink.Writes.ToList(); + + Assert.Single(logMessages); + var message = logMessages.Single(); + Assert.Equal(LogLevel.Debug, message.LogLevel); + Assert.Equal($"The host '{host}' is excluded. Skipping HSTS header.", message.State.ToString(), ignoreCase: true); + } + + [Fact] + public async Task WhenRequestIsInsecure_DoesNotAddHstsHeader() + { + var sink = new TestSink( + TestSink.EnableWithTypeName, + TestSink.EnableWithTypeName); + var loggerFactory = new TestLoggerFactory(sink, enabled: true); + + var builder = new WebHostBuilder() + .ConfigureServices(services => + { + services.AddSingleton(loggerFactory); + }) + .Configure(app => + { + app.UseHsts(); + app.Run(context => + { + return context.Response.WriteAsync("Hello world"); + }); + }); + var server = new TestServer(builder); + var client = server.CreateClient(); + client.BaseAddress = new Uri("http://example.com:5050"); + var request = new HttpRequestMessage(HttpMethod.Get, ""); + + var response = await client.SendAsync(request); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Empty(response.Headers); + + var logMessages = sink.Writes.ToList(); + + Assert.Single(logMessages); + var message = logMessages.Single(); + Assert.Equal(LogLevel.Debug, message.LogLevel); + Assert.Equal("The request is insecure. Skipping HSTS header.", message.State.ToString()); + } + + [Fact] + public async Task WhenRequestIsSecure_AddsHstsHeader() + { + var sink = new TestSink( + TestSink.EnableWithTypeName, + TestSink.EnableWithTypeName); + var loggerFactory = new TestLoggerFactory(sink, enabled: true); + + var builder = new WebHostBuilder() + .ConfigureServices(services => + { + services.AddSingleton(loggerFactory); + }) + .Configure(app => + { + app.UseHsts(); + app.Run(context => + { + return context.Response.WriteAsync("Hello world"); + }); + }); + var server = new TestServer(builder); + var client = server.CreateClient(); + client.BaseAddress = new Uri("https://example.com:5050"); + var request = new HttpRequestMessage(HttpMethod.Get, ""); + + var response = await client.SendAsync(request); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Contains(response.Headers, x => x.Key == HeaderNames.StrictTransportSecurity); + + var logMessages = sink.Writes.ToList(); + + Assert.Single(logMessages); + var message = logMessages.Single(); + Assert.Equal(LogLevel.Trace, message.LogLevel); + Assert.Equal("Adding HSTS header to response.", message.State.ToString()); } } }