Add Origin validation to WebSockets middleware (#252)
This commit is contained in:
parent
5e074fec4e
commit
768d2a023e
|
|
@ -17,6 +17,7 @@
|
||||||
<MicrosoftExtensionsConfigurationCommandLinePackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsConfigurationCommandLinePackageVersion>
|
<MicrosoftExtensionsConfigurationCommandLinePackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsConfigurationCommandLinePackageVersion>
|
||||||
<MicrosoftExtensionsLoggingConsolePackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsLoggingConsolePackageVersion>
|
<MicrosoftExtensionsLoggingConsolePackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsLoggingConsolePackageVersion>
|
||||||
<MicrosoftExtensionsLoggingPackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsLoggingPackageVersion>
|
<MicrosoftExtensionsLoggingPackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsLoggingPackageVersion>
|
||||||
|
<MicrosoftExtensionsLoggingAbstractionsPackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsLoggingAbstractionsPackageVersion>
|
||||||
<MicrosoftExtensionsLoggingTestingPackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsLoggingTestingPackageVersion>
|
<MicrosoftExtensionsLoggingTestingPackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsLoggingTestingPackageVersion>
|
||||||
<MicrosoftExtensionsOptionsPackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsOptionsPackageVersion>
|
<MicrosoftExtensionsOptionsPackageVersion>2.2.0-preview1-34823</MicrosoftExtensionsOptionsPackageVersion>
|
||||||
<MicrosoftNETCoreApp20PackageVersion>2.0.9</MicrosoftNETCoreApp20PackageVersion>
|
<MicrosoftNETCoreApp20PackageVersion>2.0.9</MicrosoftNETCoreApp20PackageVersion>
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<PackageReference Include="Microsoft.AspNetCore.Http.Extensions" Version="$(MicrosoftAspNetCoreHttpExtensionsPackageVersion)" />
|
<PackageReference Include="Microsoft.AspNetCore.Http.Extensions" Version="$(MicrosoftAspNetCoreHttpExtensionsPackageVersion)" />
|
||||||
|
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="$(MicrosoftExtensionsLoggingAbstractionsPackageVersion)" />
|
||||||
<PackageReference Include="Microsoft.Extensions.Options" Version="$(MicrosoftExtensionsOptionsPackageVersion)" />
|
<PackageReference Include="Microsoft.Extensions.Options" Version="$(MicrosoftExtensionsOptionsPackageVersion)" />
|
||||||
<PackageReference Include="System.Net.WebSockets.WebSocketProtocol" Version="$(SystemNetWebSocketsWebSocketProtocolPackageVersion)" />
|
<PackageReference Include="System.Net.WebSockets.WebSocketProtocol" Version="$(SystemNetWebSocketsWebSocketProtocolPackageVersion)" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,18 @@
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using System.IO;
|
using System.IO;
|
||||||
|
using System.Linq;
|
||||||
using System.Net.WebSockets;
|
using System.Net.WebSockets;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.AspNetCore.Builder;
|
using Microsoft.AspNetCore.Builder;
|
||||||
using Microsoft.AspNetCore.Http;
|
using Microsoft.AspNetCore.Http;
|
||||||
using Microsoft.AspNetCore.Http.Features;
|
using Microsoft.AspNetCore.Http.Features;
|
||||||
using Microsoft.AspNetCore.WebSockets.Internal;
|
using Microsoft.AspNetCore.WebSockets.Internal;
|
||||||
|
using Microsoft.Extensions.Logging;
|
||||||
|
using Microsoft.Extensions.Logging.Abstractions;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
|
using Microsoft.Extensions.Primitives;
|
||||||
|
using Microsoft.Net.Http.Headers;
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.WebSockets
|
namespace Microsoft.AspNetCore.WebSockets
|
||||||
{
|
{
|
||||||
|
|
@ -18,8 +23,11 @@ namespace Microsoft.AspNetCore.WebSockets
|
||||||
{
|
{
|
||||||
private readonly RequestDelegate _next;
|
private readonly RequestDelegate _next;
|
||||||
private readonly WebSocketOptions _options;
|
private readonly WebSocketOptions _options;
|
||||||
|
private readonly ILogger _logger;
|
||||||
|
private readonly bool _anyOriginAllowed;
|
||||||
|
private readonly List<string> _allowedOrigins;
|
||||||
|
|
||||||
public WebSocketMiddleware(RequestDelegate next, IOptions<WebSocketOptions> options)
|
public WebSocketMiddleware(RequestDelegate next, IOptions<WebSocketOptions> options, ILoggerFactory loggerFactory)
|
||||||
{
|
{
|
||||||
if (next == null)
|
if (next == null)
|
||||||
{
|
{
|
||||||
|
|
@ -32,17 +40,45 @@ namespace Microsoft.AspNetCore.WebSockets
|
||||||
|
|
||||||
_next = next;
|
_next = next;
|
||||||
_options = options.Value;
|
_options = options.Value;
|
||||||
|
_allowedOrigins = _options.AllowedOrigins.Select(o => o.ToLowerInvariant()).ToList();
|
||||||
|
_anyOriginAllowed = _options.AllowedOrigins.Count == 0 || _options.AllowedOrigins.Contains("*", StringComparer.Ordinal);
|
||||||
|
|
||||||
|
_logger = loggerFactory.CreateLogger<WebSocketMiddleware>();
|
||||||
|
|
||||||
// TODO: validate options.
|
// TODO: validate options.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory.")]
|
||||||
|
public WebSocketMiddleware(RequestDelegate next, IOptions<WebSocketOptions> options)
|
||||||
|
: this(next, options, NullLoggerFactory.Instance)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
public Task Invoke(HttpContext context)
|
public Task Invoke(HttpContext context)
|
||||||
{
|
{
|
||||||
// Detect if an opaque upgrade is available. If so, add a websocket upgrade.
|
// Detect if an opaque upgrade is available. If so, add a websocket upgrade.
|
||||||
var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>();
|
var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>();
|
||||||
if (upgradeFeature != null && context.Features.Get<IHttpWebSocketFeature>() == null)
|
if (upgradeFeature != null && context.Features.Get<IHttpWebSocketFeature>() == null)
|
||||||
{
|
{
|
||||||
context.Features.Set<IHttpWebSocketFeature>(new UpgradeHandshake(context, upgradeFeature, _options));
|
var webSocketFeature = new UpgradeHandshake(context, upgradeFeature, _options);
|
||||||
|
context.Features.Set<IHttpWebSocketFeature>(webSocketFeature);
|
||||||
|
|
||||||
|
if (!_anyOriginAllowed)
|
||||||
|
{
|
||||||
|
// Check for Origin header
|
||||||
|
var originHeader = context.Request.Headers[HeaderNames.Origin];
|
||||||
|
|
||||||
|
if (!StringValues.IsNullOrEmpty(originHeader) && webSocketFeature.IsWebSocketRequest)
|
||||||
|
{
|
||||||
|
// Check allowed origins to see if request is allowed
|
||||||
|
if (!_allowedOrigins.Contains(originHeader.ToString(), StringComparer.Ordinal))
|
||||||
|
{
|
||||||
|
_logger.LogDebug("Request origin {Origin} is not in the list of allowed origins.", originHeader);
|
||||||
|
context.Response.StatusCode = StatusCodes.Status403Forbidden;
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return _next(context);
|
return _next(context);
|
||||||
|
|
@ -53,6 +89,7 @@ namespace Microsoft.AspNetCore.WebSockets
|
||||||
private readonly HttpContext _context;
|
private readonly HttpContext _context;
|
||||||
private readonly IHttpUpgradeFeature _upgradeFeature;
|
private readonly IHttpUpgradeFeature _upgradeFeature;
|
||||||
private readonly WebSocketOptions _options;
|
private readonly WebSocketOptions _options;
|
||||||
|
private bool? _isWebSocketRequest;
|
||||||
|
|
||||||
public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options)
|
public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options)
|
||||||
{
|
{
|
||||||
|
|
@ -65,19 +102,26 @@ namespace Microsoft.AspNetCore.WebSockets
|
||||||
{
|
{
|
||||||
get
|
get
|
||||||
{
|
{
|
||||||
if (!_upgradeFeature.IsUpgradableRequest)
|
if (_isWebSocketRequest == null)
|
||||||
{
|
{
|
||||||
return false;
|
if (!_upgradeFeature.IsUpgradableRequest)
|
||||||
}
|
|
||||||
var headers = new List<KeyValuePair<string, string>>();
|
|
||||||
foreach (string headerName in HandshakeHelpers.NeededHeaders)
|
|
||||||
{
|
|
||||||
foreach (var value in _context.Request.Headers.GetCommaSeparatedValues(headerName))
|
|
||||||
{
|
{
|
||||||
headers.Add(new KeyValuePair<string, string>(headerName, value));
|
_isWebSocketRequest = false;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
var headers = new List<KeyValuePair<string, string>>();
|
||||||
|
foreach (string headerName in HandshakeHelpers.NeededHeaders)
|
||||||
|
{
|
||||||
|
foreach (var value in _context.Request.Headers.GetCommaSeparatedValues(headerName))
|
||||||
|
{
|
||||||
|
headers.Add(new KeyValuePair<string, string>(headerName, value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_isWebSocketRequest = HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, headers);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, headers);
|
return _isWebSocketRequest.Value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.Builder
|
namespace Microsoft.AspNetCore.Builder
|
||||||
{
|
{
|
||||||
|
|
@ -14,6 +15,7 @@ namespace Microsoft.AspNetCore.Builder
|
||||||
{
|
{
|
||||||
KeepAliveInterval = TimeSpan.FromMinutes(2);
|
KeepAliveInterval = TimeSpan.FromMinutes(2);
|
||||||
ReceiveBufferSize = 4 * 1024;
|
ReceiveBufferSize = 4 * 1024;
|
||||||
|
AllowedOrigins = new List<string>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
@ -27,5 +29,11 @@ namespace Microsoft.AspNetCore.Builder
|
||||||
/// The default is 4kb.
|
/// The default is 4kb.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public int ReceiveBufferSize { get; set; }
|
public int ReceiveBufferSize { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Set the Origin header values allowed for WebSocket requests to prevent Cross-Site WebSocket Hijacking.
|
||||||
|
/// By default all Origins are allowed.
|
||||||
|
/// </summary>
|
||||||
|
public IList<string> AllowedOrigins { get; }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using Microsoft.AspNetCore.Builder;
|
||||||
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.WebSockets
|
||||||
|
{
|
||||||
|
public static class WebSocketsDependencyInjectionExtensions
|
||||||
|
{
|
||||||
|
public static IServiceCollection AddWebSockets(this IServiceCollection services, Action<WebSocketOptions> configure)
|
||||||
|
{
|
||||||
|
return services.Configure(configure);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using Microsoft.AspNetCore.Builder;
|
||||||
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
|
using Microsoft.Extensions.Options;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.WebSockets.Test
|
||||||
|
{
|
||||||
|
public class AddWebSocketsTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public void AddWebSocketsConfiguresOptions()
|
||||||
|
{
|
||||||
|
var serviceCollection = new ServiceCollection();
|
||||||
|
|
||||||
|
serviceCollection.AddWebSockets(o =>
|
||||||
|
{
|
||||||
|
o.KeepAliveInterval = TimeSpan.FromSeconds(1000);
|
||||||
|
o.AllowedOrigins.Add("someString");
|
||||||
|
});
|
||||||
|
|
||||||
|
var services = serviceCollection.BuildServiceProvider();
|
||||||
|
var socketOptions = services.GetRequiredService<IOptions<WebSocketOptions>>().Value;
|
||||||
|
|
||||||
|
Assert.Equal(TimeSpan.FromSeconds(1000), socketOptions.KeepAliveInterval);
|
||||||
|
Assert.Single(socketOptions.AllowedOrigins);
|
||||||
|
Assert.Equal("someString", socketOptions.AllowedOrigins[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -14,8 +14,9 @@ namespace Microsoft.AspNetCore.WebSockets.Test
|
||||||
{
|
{
|
||||||
public class KestrelWebSocketHelpers
|
public class KestrelWebSocketHelpers
|
||||||
{
|
{
|
||||||
public static IDisposable CreateServer(ILoggerFactory loggerFactory, Func<HttpContext, Task> app)
|
public static IDisposable CreateServer(ILoggerFactory loggerFactory, Func<HttpContext, Task> app, Action<WebSocketOptions> configure = null)
|
||||||
{
|
{
|
||||||
|
configure = configure ?? (o => { });
|
||||||
Action<IApplicationBuilder> startup = builder =>
|
Action<IApplicationBuilder> startup = builder =>
|
||||||
{
|
{
|
||||||
builder.Use(async (ct, next) =>
|
builder.Use(async (ct, next) =>
|
||||||
|
|
@ -48,7 +49,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test
|
||||||
config["server.urls"] = "http://localhost:54321";
|
config["server.urls"] = "http://localhost:54321";
|
||||||
|
|
||||||
var host = new WebHostBuilder()
|
var host = new WebHostBuilder()
|
||||||
.ConfigureServices(s => s.AddSingleton(loggerFactory))
|
.ConfigureServices(s =>
|
||||||
|
{
|
||||||
|
s.AddWebSockets(configure);
|
||||||
|
s.AddSingleton(loggerFactory);
|
||||||
|
})
|
||||||
.UseConfiguration(config)
|
.UseConfiguration(config)
|
||||||
.UseKestrel()
|
.UseKestrel()
|
||||||
.Configure(startup)
|
.Configure(startup)
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,15 @@
|
||||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
|
using System.Net;
|
||||||
|
using System.Net.Http;
|
||||||
using System.Net.WebSockets;
|
using System.Net.WebSockets;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using System.Threading;
|
using System.Threading;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.AspNetCore.Builder;
|
||||||
using Microsoft.AspNetCore.Testing.xunit;
|
using Microsoft.AspNetCore.Testing.xunit;
|
||||||
|
using Microsoft.AspNetCore.WebSockets.Internal;
|
||||||
using Microsoft.Extensions.Logging.Testing;
|
using Microsoft.Extensions.Logging.Testing;
|
||||||
using Xunit;
|
using Xunit;
|
||||||
using Xunit.Abstractions;
|
using Xunit.Abstractions;
|
||||||
|
|
@ -558,5 +562,85 @@ namespace Microsoft.AspNetCore.WebSockets.Test
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Theory]
|
||||||
|
[InlineData(HttpStatusCode.OK, null)]
|
||||||
|
[InlineData(HttpStatusCode.Forbidden, "")]
|
||||||
|
[InlineData(HttpStatusCode.Forbidden, "http://e.com")]
|
||||||
|
[InlineData(HttpStatusCode.OK, "http://e.com", "http://example.com")]
|
||||||
|
[InlineData(HttpStatusCode.OK, "*")]
|
||||||
|
[InlineData(HttpStatusCode.OK, "http://e.com", "*")]
|
||||||
|
[InlineData(HttpStatusCode.OK, "http://ExAmPLE.cOm")]
|
||||||
|
public async Task OriginIsValidatedForWebSocketRequests(HttpStatusCode expectedCode, params string[] origins)
|
||||||
|
{
|
||||||
|
using (StartLog(out var loggerFactory))
|
||||||
|
{
|
||||||
|
using (var server = KestrelWebSocketHelpers.CreateServer(loggerFactory, context =>
|
||||||
|
{
|
||||||
|
Assert.True(context.WebSockets.IsWebSocketRequest);
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}, o =>
|
||||||
|
{
|
||||||
|
if (origins != null)
|
||||||
|
{
|
||||||
|
foreach (var origin in origins)
|
||||||
|
{
|
||||||
|
o.AllowedOrigins.Add(origin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
{
|
||||||
|
using (var client = new HttpClient())
|
||||||
|
{
|
||||||
|
var uri = new UriBuilder(ClientAddress);
|
||||||
|
uri.Scheme = "http";
|
||||||
|
|
||||||
|
// Craft a valid WebSocket Upgrade request
|
||||||
|
using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
|
||||||
|
{
|
||||||
|
request.Headers.Connection.Clear();
|
||||||
|
request.Headers.Connection.Add("Upgrade");
|
||||||
|
request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket"));
|
||||||
|
request.Headers.Add(Constants.Headers.SecWebSocketVersion, Constants.Headers.SupportedVersion);
|
||||||
|
// SecWebSocketKey required to be 16 bytes
|
||||||
|
request.Headers.Add(Constants.Headers.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None));
|
||||||
|
|
||||||
|
request.Headers.Add("Origin", "http://example.com");
|
||||||
|
|
||||||
|
var response = await client.SendAsync(request);
|
||||||
|
Assert.Equal(expectedCode, response.StatusCode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task OriginIsNotValidatedForNonWebSocketRequests()
|
||||||
|
{
|
||||||
|
using (StartLog(out var loggerFactory))
|
||||||
|
{
|
||||||
|
using (var server = KestrelWebSocketHelpers.CreateServer(loggerFactory, context =>
|
||||||
|
{
|
||||||
|
Assert.False(context.WebSockets.IsWebSocketRequest);
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}, o => o.AllowedOrigins.Add("http://example.com")))
|
||||||
|
{
|
||||||
|
using (var client = new HttpClient())
|
||||||
|
{
|
||||||
|
var uri = new UriBuilder(ClientAddress);
|
||||||
|
uri.Scheme = "http";
|
||||||
|
|
||||||
|
using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
|
||||||
|
{
|
||||||
|
request.Headers.Add("Origin", "http://notexample.com");
|
||||||
|
|
||||||
|
var response = await client.SendAsync(request);
|
||||||
|
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue