diff --git a/build/dependencies.props b/build/dependencies.props
index 36b1977681..c1feca987e 100644
--- a/build/dependencies.props
+++ b/build/dependencies.props
@@ -17,6 +17,7 @@
2.2.0-preview1-34823
2.2.0-preview1-34823
2.2.0-preview1-34823
+ 2.2.0-preview1-34823
2.2.0-preview1-34823
2.2.0-preview1-34823
2.0.9
diff --git a/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj b/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj
index c5c5c8fbfb..277c0cc9fe 100644
--- a/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj
+++ b/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj
@@ -11,6 +11,7 @@
+
diff --git a/src/Microsoft.AspNetCore.WebSockets/WebSocketMiddleware.cs b/src/Microsoft.AspNetCore.WebSockets/WebSocketMiddleware.cs
index eba9597cd5..74f1e4f7d8 100644
--- a/src/Microsoft.AspNetCore.WebSockets/WebSocketMiddleware.cs
+++ b/src/Microsoft.AspNetCore.WebSockets/WebSocketMiddleware.cs
@@ -4,13 +4,18 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
using System.Net.WebSockets;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.WebSockets.Internal;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
+using Microsoft.Extensions.Primitives;
+using Microsoft.Net.Http.Headers;
namespace Microsoft.AspNetCore.WebSockets
{
@@ -18,8 +23,11 @@ namespace Microsoft.AspNetCore.WebSockets
{
private readonly RequestDelegate _next;
private readonly WebSocketOptions _options;
+ private readonly ILogger _logger;
+ private readonly bool _anyOriginAllowed;
+ private readonly List _allowedOrigins;
- public WebSocketMiddleware(RequestDelegate next, IOptions options)
+ public WebSocketMiddleware(RequestDelegate next, IOptions options, ILoggerFactory loggerFactory)
{
if (next == null)
{
@@ -32,17 +40,45 @@ namespace Microsoft.AspNetCore.WebSockets
_next = next;
_options = options.Value;
+ _allowedOrigins = _options.AllowedOrigins.Select(o => o.ToLowerInvariant()).ToList();
+ _anyOriginAllowed = _options.AllowedOrigins.Count == 0 || _options.AllowedOrigins.Contains("*", StringComparer.Ordinal);
+
+ _logger = loggerFactory.CreateLogger();
// TODO: validate options.
}
+ [Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory.")]
+ public WebSocketMiddleware(RequestDelegate next, IOptions options)
+ : this(next, options, NullLoggerFactory.Instance)
+ {
+ }
+
public Task Invoke(HttpContext context)
{
// Detect if an opaque upgrade is available. If so, add a websocket upgrade.
var upgradeFeature = context.Features.Get();
if (upgradeFeature != null && context.Features.Get() == null)
{
- context.Features.Set(new UpgradeHandshake(context, upgradeFeature, _options));
+ var webSocketFeature = new UpgradeHandshake(context, upgradeFeature, _options);
+ context.Features.Set(webSocketFeature);
+
+ if (!_anyOriginAllowed)
+ {
+ // Check for Origin header
+ var originHeader = context.Request.Headers[HeaderNames.Origin];
+
+ if (!StringValues.IsNullOrEmpty(originHeader) && webSocketFeature.IsWebSocketRequest)
+ {
+ // Check allowed origins to see if request is allowed
+ if (!_allowedOrigins.Contains(originHeader.ToString(), StringComparer.Ordinal))
+ {
+ _logger.LogDebug("Request origin {Origin} is not in the list of allowed origins.", originHeader);
+ context.Response.StatusCode = StatusCodes.Status403Forbidden;
+ return Task.CompletedTask;
+ }
+ }
+ }
}
return _next(context);
@@ -53,6 +89,7 @@ namespace Microsoft.AspNetCore.WebSockets
private readonly HttpContext _context;
private readonly IHttpUpgradeFeature _upgradeFeature;
private readonly WebSocketOptions _options;
+ private bool? _isWebSocketRequest;
public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options)
{
@@ -65,19 +102,26 @@ namespace Microsoft.AspNetCore.WebSockets
{
get
{
- if (!_upgradeFeature.IsUpgradableRequest)
+ if (_isWebSocketRequest == null)
{
- return false;
- }
- var headers = new List>();
- foreach (string headerName in HandshakeHelpers.NeededHeaders)
- {
- foreach (var value in _context.Request.Headers.GetCommaSeparatedValues(headerName))
+ if (!_upgradeFeature.IsUpgradableRequest)
{
- headers.Add(new KeyValuePair(headerName, value));
+ _isWebSocketRequest = false;
+ }
+ else
+ {
+ var headers = new List>();
+ foreach (string headerName in HandshakeHelpers.NeededHeaders)
+ {
+ foreach (var value in _context.Request.Headers.GetCommaSeparatedValues(headerName))
+ {
+ headers.Add(new KeyValuePair(headerName, value));
+ }
+ }
+ _isWebSocketRequest = HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, headers);
}
}
- return HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, headers);
+ return _isWebSocketRequest.Value;
}
}
diff --git a/src/Microsoft.AspNetCore.WebSockets/WebSocketOptions.cs b/src/Microsoft.AspNetCore.WebSockets/WebSocketOptions.cs
index 808cc86251..da5f630d62 100644
--- a/src/Microsoft.AspNetCore.WebSockets/WebSocketOptions.cs
+++ b/src/Microsoft.AspNetCore.WebSockets/WebSocketOptions.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Collections.Generic;
namespace Microsoft.AspNetCore.Builder
{
@@ -14,6 +15,7 @@ namespace Microsoft.AspNetCore.Builder
{
KeepAliveInterval = TimeSpan.FromMinutes(2);
ReceiveBufferSize = 4 * 1024;
+ AllowedOrigins = new List();
}
///
@@ -27,5 +29,11 @@ namespace Microsoft.AspNetCore.Builder
/// The default is 4kb.
///
public int ReceiveBufferSize { get; set; }
+
+ ///
+ /// Set the Origin header values allowed for WebSocket requests to prevent Cross-Site WebSocket Hijacking.
+ /// By default all Origins are allowed.
+ ///
+ public IList AllowedOrigins { get; }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.AspNetCore.WebSockets/WebSocketsDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.WebSockets/WebSocketsDependencyInjectionExtensions.cs
new file mode 100644
index 0000000000..69b2da7eb0
--- /dev/null
+++ b/src/Microsoft.AspNetCore.WebSockets/WebSocketsDependencyInjectionExtensions.cs
@@ -0,0 +1,17 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Microsoft.AspNetCore.WebSockets
+{
+ public static class WebSocketsDependencyInjectionExtensions
+ {
+ public static IServiceCollection AddWebSockets(this IServiceCollection services, Action configure)
+ {
+ return services.Configure(configure);
+ }
+ }
+}
diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/AddWebSocketsTests.cs b/test/Microsoft.AspNetCore.WebSockets.Test/AddWebSocketsTests.cs
new file mode 100644
index 0000000000..255d17f24f
--- /dev/null
+++ b/test/Microsoft.AspNetCore.WebSockets.Test/AddWebSocketsTests.cs
@@ -0,0 +1,33 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Options;
+using Xunit;
+
+namespace Microsoft.AspNetCore.WebSockets.Test
+{
+ public class AddWebSocketsTests
+ {
+ [Fact]
+ public void AddWebSocketsConfiguresOptions()
+ {
+ var serviceCollection = new ServiceCollection();
+
+ serviceCollection.AddWebSockets(o =>
+ {
+ o.KeepAliveInterval = TimeSpan.FromSeconds(1000);
+ o.AllowedOrigins.Add("someString");
+ });
+
+ var services = serviceCollection.BuildServiceProvider();
+ var socketOptions = services.GetRequiredService>().Value;
+
+ Assert.Equal(TimeSpan.FromSeconds(1000), socketOptions.KeepAliveInterval);
+ Assert.Single(socketOptions.AllowedOrigins);
+ Assert.Equal("someString", socketOptions.AllowedOrigins[0]);
+ }
+ }
+}
diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/KestrelWebSocketHelpers.cs b/test/Microsoft.AspNetCore.WebSockets.Test/KestrelWebSocketHelpers.cs
index 77ec387c10..3e0b184f13 100644
--- a/test/Microsoft.AspNetCore.WebSockets.Test/KestrelWebSocketHelpers.cs
+++ b/test/Microsoft.AspNetCore.WebSockets.Test/KestrelWebSocketHelpers.cs
@@ -14,8 +14,9 @@ namespace Microsoft.AspNetCore.WebSockets.Test
{
public class KestrelWebSocketHelpers
{
- public static IDisposable CreateServer(ILoggerFactory loggerFactory, Func app)
+ public static IDisposable CreateServer(ILoggerFactory loggerFactory, Func app, Action configure = null)
{
+ configure = configure ?? (o => { });
Action startup = builder =>
{
builder.Use(async (ct, next) =>
@@ -48,7 +49,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test
config["server.urls"] = "http://localhost:54321";
var host = new WebHostBuilder()
- .ConfigureServices(s => s.AddSingleton(loggerFactory))
+ .ConfigureServices(s =>
+ {
+ s.AddWebSockets(configure);
+ s.AddSingleton(loggerFactory);
+ })
.UseConfiguration(config)
.UseKestrel()
.Configure(startup)
diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketMiddlewareTests.cs b/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketMiddlewareTests.cs
index fd2fbc8d51..dc7ca9a79b 100644
--- a/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketMiddlewareTests.cs
+++ b/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketMiddlewareTests.cs
@@ -2,11 +2,15 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Net;
+using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Testing.xunit;
+using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
using Xunit.Abstractions;
@@ -558,5 +562,85 @@ namespace Microsoft.AspNetCore.WebSockets.Test
}
}
}
+
+ [Theory]
+ [InlineData(HttpStatusCode.OK, null)]
+ [InlineData(HttpStatusCode.Forbidden, "")]
+ [InlineData(HttpStatusCode.Forbidden, "http://e.com")]
+ [InlineData(HttpStatusCode.OK, "http://e.com", "http://example.com")]
+ [InlineData(HttpStatusCode.OK, "*")]
+ [InlineData(HttpStatusCode.OK, "http://e.com", "*")]
+ [InlineData(HttpStatusCode.OK, "http://ExAmPLE.cOm")]
+ public async Task OriginIsValidatedForWebSocketRequests(HttpStatusCode expectedCode, params string[] origins)
+ {
+ using (StartLog(out var loggerFactory))
+ {
+ using (var server = KestrelWebSocketHelpers.CreateServer(loggerFactory, context =>
+ {
+ Assert.True(context.WebSockets.IsWebSocketRequest);
+ return Task.CompletedTask;
+ }, o =>
+ {
+ if (origins != null)
+ {
+ foreach (var origin in origins)
+ {
+ o.AllowedOrigins.Add(origin);
+ }
+ }
+ }))
+ {
+ using (var client = new HttpClient())
+ {
+ var uri = new UriBuilder(ClientAddress);
+ uri.Scheme = "http";
+
+ // Craft a valid WebSocket Upgrade request
+ using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
+ {
+ request.Headers.Connection.Clear();
+ request.Headers.Connection.Add("Upgrade");
+ request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket"));
+ request.Headers.Add(Constants.Headers.SecWebSocketVersion, Constants.Headers.SupportedVersion);
+ // SecWebSocketKey required to be 16 bytes
+ request.Headers.Add(Constants.Headers.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None));
+
+ request.Headers.Add("Origin", "http://example.com");
+
+ var response = await client.SendAsync(request);
+ Assert.Equal(expectedCode, response.StatusCode);
+ }
+ }
+ }
+ }
+ }
+
+ [Fact]
+ public async Task OriginIsNotValidatedForNonWebSocketRequests()
+ {
+ using (StartLog(out var loggerFactory))
+ {
+ using (var server = KestrelWebSocketHelpers.CreateServer(loggerFactory, context =>
+ {
+ Assert.False(context.WebSockets.IsWebSocketRequest);
+ return Task.CompletedTask;
+ }, o => o.AllowedOrigins.Add("http://example.com")))
+ {
+ using (var client = new HttpClient())
+ {
+ var uri = new UriBuilder(ClientAddress);
+ uri.Scheme = "http";
+
+ using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString()))
+ {
+ request.Headers.Add("Origin", "http://notexample.com");
+
+ var response = await client.SendAsync(request);
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+ }
+ }
+ }
+ }
+ }
}
}