Add logging to HstsMiddleware (#327)

This commit is contained in:
Nate Barbettini 2018-05-07 08:18:04 -07:00 committed by Chris Ross
parent 597ed938ea
commit 77c9bc38f9
3 changed files with 179 additions and 3 deletions

View File

@ -6,6 +6,8 @@ using System.Collections.Generic;
using System.Globalization;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.HttpsPolicy.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
@ -24,13 +26,15 @@ namespace Microsoft.AspNetCore.HttpsPolicy
private readonly RequestDelegate _next;
private readonly StringValues _strictTransportSecurityValue;
private readonly IList<string> _excludedHosts;
private readonly ILogger _logger;
/// <summary>
/// Initialize the HSTS middleware.
/// </summary>
/// <param name="next"></param>
/// <param name="options"></param>
public HstsMiddleware(RequestDelegate next, IOptions<HstsOptions> options)
/// <param name="loggerFactory"></param>
public HstsMiddleware(RequestDelegate next, IOptions<HstsOptions> options, ILoggerFactory loggerFactory)
{
if (options == null)
{
@ -46,6 +50,7 @@ namespace Microsoft.AspNetCore.HttpsPolicy
var preload = hstsOptions.Preload ? Preload : StringSegment.Empty;
_strictTransportSecurityValue = new StringValues($"max-age={maxAge}{includeSubdomains}{preload}");
_excludedHosts = hstsOptions.ExcludedHosts;
_logger = loggerFactory.CreateLogger<HstsMiddleware>();
}
/// <summary>
@ -55,11 +60,21 @@ namespace Microsoft.AspNetCore.HttpsPolicy
/// <returns></returns>
public Task Invoke(HttpContext context)
{
if (context.Request.IsHttps && !IsHostExcluded(context.Request.Host.Host))
if (!context.Request.IsHttps)
{
context.Response.Headers[HeaderNames.StrictTransportSecurity] = _strictTransportSecurityValue;
_logger.SkippingInsecure();
return _next(context);
}
if (IsHostExcluded(context.Request.Host.Host))
{
_logger.SkippingExcludedHost(context.Request.Host.Host);
return _next(context);
}
context.Response.Headers[HeaderNames.StrictTransportSecurity] = _strictTransportSecurityValue;
_logger.AddingHstsHeader();
return _next(context);
}

View File

@ -0,0 +1,37 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.HttpsPolicy.Internal
{
internal static class HstsLoggingExtensions
{
private static readonly Action<ILogger, Exception> _notSecure;
private static readonly Action<ILogger, string, Exception> _excludedHost;
private static readonly Action<ILogger, Exception> _addingHstsHeader;
static HstsLoggingExtensions()
{
_notSecure = LoggerMessage.Define(LogLevel.Debug, 1, "The request is insecure. Skipping HSTS header.");
_excludedHost = LoggerMessage.Define<string>(LogLevel.Debug, 2, "The host '{host}' is excluded. Skipping HSTS header.");
_addingHstsHeader = LoggerMessage.Define(LogLevel.Trace, 3, "Adding HSTS header to response.");
}
public static void SkippingInsecure(this ILogger logger)
{
_notSecure(logger, null);
}
public static void SkippingExcludedHost(this ILogger logger, string host)
{
_excludedHost(logger, host, null);
}
public static void AddingHstsHeader(this ILogger logger)
{
_addingHstsHeader(logger, null);
}
}
}

View File

@ -12,6 +12,8 @@ using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Microsoft.Net.Http.Headers;
using Xunit;
@ -131,7 +133,16 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests
[InlineData("[::1]")]
public async Task DefaultExcludesCommonLocalhostDomains_DoesNotSetHstsHeader(string host)
{
var sink = new TestSink(
TestSink.EnableWithTypeName<HstsMiddleware>,
TestSink.EnableWithTypeName<HstsMiddleware>);
var loggerFactory = new TestLoggerFactory(sink, enabled: true);
var builder = new WebHostBuilder()
.ConfigureServices(services =>
{
services.AddSingleton<ILoggerFactory>(loggerFactory);
})
.Configure(app =>
{
app.UseHsts();
@ -149,6 +160,13 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Empty(response.Headers);
var logMessages = sink.Writes.ToList();
Assert.Single(logMessages);
var message = logMessages.Single();
Assert.Equal(LogLevel.Debug, message.LogLevel);
Assert.Equal($"The host '{host}' is excluded. Skipping HSTS header.", message.State.ToString(), ignoreCase: true);
}
[Theory]
@ -157,9 +175,16 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests
[InlineData("[::1]")]
public async Task AllowLocalhostDomainsIfListIsReset_SetHstsHeader(string host)
{
var sink = new TestSink(
TestSink.EnableWithTypeName<HstsMiddleware>,
TestSink.EnableWithTypeName<HstsMiddleware>);
var loggerFactory = new TestLoggerFactory(sink, enabled: true);
var builder = new WebHostBuilder()
.ConfigureServices(services =>
{
services.AddSingleton<ILoggerFactory>(loggerFactory);
services.AddHsts(options =>
{
options.ExcludedHosts.Clear();
@ -182,6 +207,13 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Single(response.Headers);
var logMessages = sink.Writes.ToList();
Assert.Single(logMessages);
var message = logMessages.Single();
Assert.Equal(LogLevel.Trace, message.LogLevel);
Assert.Equal("Adding HSTS header to response.", message.State.ToString());
}
[Theory]
@ -190,9 +222,16 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests
[InlineData("EXAMPLE.COM")]
public async Task AddExcludedDomains_DoesNotAddHstsHeader(string host)
{
var sink = new TestSink(
TestSink.EnableWithTypeName<HstsMiddleware>,
TestSink.EnableWithTypeName<HstsMiddleware>);
var loggerFactory = new TestLoggerFactory(sink, enabled: true);
var builder = new WebHostBuilder()
.ConfigureServices(services =>
{
services.AddSingleton<ILoggerFactory>(loggerFactory);
services.AddHsts(options => {
options.ExcludedHosts.Add(host);
});
@ -214,6 +253,91 @@ namespace Microsoft.AspNetCore.HttpsPolicy.Tests
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Empty(response.Headers);
var logMessages = sink.Writes.ToList();
Assert.Single(logMessages);
var message = logMessages.Single();
Assert.Equal(LogLevel.Debug, message.LogLevel);
Assert.Equal($"The host '{host}' is excluded. Skipping HSTS header.", message.State.ToString(), ignoreCase: true);
}
[Fact]
public async Task WhenRequestIsInsecure_DoesNotAddHstsHeader()
{
var sink = new TestSink(
TestSink.EnableWithTypeName<HstsMiddleware>,
TestSink.EnableWithTypeName<HstsMiddleware>);
var loggerFactory = new TestLoggerFactory(sink, enabled: true);
var builder = new WebHostBuilder()
.ConfigureServices(services =>
{
services.AddSingleton<ILoggerFactory>(loggerFactory);
})
.Configure(app =>
{
app.UseHsts();
app.Run(context =>
{
return context.Response.WriteAsync("Hello world");
});
});
var server = new TestServer(builder);
var client = server.CreateClient();
client.BaseAddress = new Uri("http://example.com:5050");
var request = new HttpRequestMessage(HttpMethod.Get, "");
var response = await client.SendAsync(request);
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Empty(response.Headers);
var logMessages = sink.Writes.ToList();
Assert.Single(logMessages);
var message = logMessages.Single();
Assert.Equal(LogLevel.Debug, message.LogLevel);
Assert.Equal("The request is insecure. Skipping HSTS header.", message.State.ToString());
}
[Fact]
public async Task WhenRequestIsSecure_AddsHstsHeader()
{
var sink = new TestSink(
TestSink.EnableWithTypeName<HstsMiddleware>,
TestSink.EnableWithTypeName<HstsMiddleware>);
var loggerFactory = new TestLoggerFactory(sink, enabled: true);
var builder = new WebHostBuilder()
.ConfigureServices(services =>
{
services.AddSingleton<ILoggerFactory>(loggerFactory);
})
.Configure(app =>
{
app.UseHsts();
app.Run(context =>
{
return context.Response.WriteAsync("Hello world");
});
});
var server = new TestServer(builder);
var client = server.CreateClient();
client.BaseAddress = new Uri("https://example.com:5050");
var request = new HttpRequestMessage(HttpMethod.Get, "");
var response = await client.SendAsync(request);
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Contains(response.Headers, x => x.Key == HeaderNames.StrictTransportSecurity);
var logMessages = sink.Writes.ToList();
Assert.Single(logMessages);
var message = logMessages.Single();
Assert.Equal(LogLevel.Trace, message.LogLevel);
Assert.Equal("Adding HSTS header to response.", message.State.ToString());
}
}
}