diff --git a/build/dependencies.props b/build/dependencies.props index 36b1977681..c1feca987e 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -17,6 +17,7 @@ 2.2.0-preview1-34823 2.2.0-preview1-34823 2.2.0-preview1-34823 + 2.2.0-preview1-34823 2.2.0-preview1-34823 2.2.0-preview1-34823 2.0.9 diff --git a/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj b/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj index c5c5c8fbfb..277c0cc9fe 100644 --- a/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj +++ b/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj @@ -11,6 +11,7 @@ + diff --git a/src/Microsoft.AspNetCore.WebSockets/WebSocketMiddleware.cs b/src/Microsoft.AspNetCore.WebSockets/WebSocketMiddleware.cs index eba9597cd5..74f1e4f7d8 100644 --- a/src/Microsoft.AspNetCore.WebSockets/WebSocketMiddleware.cs +++ b/src/Microsoft.AspNetCore.WebSockets/WebSocketMiddleware.cs @@ -4,13 +4,18 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net.WebSockets; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.WebSockets.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; namespace Microsoft.AspNetCore.WebSockets { @@ -18,8 +23,11 @@ namespace Microsoft.AspNetCore.WebSockets { private readonly RequestDelegate _next; private readonly WebSocketOptions _options; + private readonly ILogger _logger; + private readonly bool _anyOriginAllowed; + private readonly List _allowedOrigins; - public WebSocketMiddleware(RequestDelegate next, IOptions options) + public WebSocketMiddleware(RequestDelegate next, IOptions options, ILoggerFactory loggerFactory) { if (next == null) { @@ -32,17 +40,45 @@ namespace Microsoft.AspNetCore.WebSockets _next = next; _options = options.Value; + _allowedOrigins = _options.AllowedOrigins.Select(o => o.ToLowerInvariant()).ToList(); + _anyOriginAllowed = _options.AllowedOrigins.Count == 0 || _options.AllowedOrigins.Contains("*", StringComparer.Ordinal); + + _logger = loggerFactory.CreateLogger(); // TODO: validate options. } + [Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory.")] + public WebSocketMiddleware(RequestDelegate next, IOptions options) + : this(next, options, NullLoggerFactory.Instance) + { + } + public Task Invoke(HttpContext context) { // Detect if an opaque upgrade is available. If so, add a websocket upgrade. var upgradeFeature = context.Features.Get(); if (upgradeFeature != null && context.Features.Get() == null) { - context.Features.Set(new UpgradeHandshake(context, upgradeFeature, _options)); + var webSocketFeature = new UpgradeHandshake(context, upgradeFeature, _options); + context.Features.Set(webSocketFeature); + + if (!_anyOriginAllowed) + { + // Check for Origin header + var originHeader = context.Request.Headers[HeaderNames.Origin]; + + if (!StringValues.IsNullOrEmpty(originHeader) && webSocketFeature.IsWebSocketRequest) + { + // Check allowed origins to see if request is allowed + if (!_allowedOrigins.Contains(originHeader.ToString(), StringComparer.Ordinal)) + { + _logger.LogDebug("Request origin {Origin} is not in the list of allowed origins.", originHeader); + context.Response.StatusCode = StatusCodes.Status403Forbidden; + return Task.CompletedTask; + } + } + } } return _next(context); @@ -53,6 +89,7 @@ namespace Microsoft.AspNetCore.WebSockets private readonly HttpContext _context; private readonly IHttpUpgradeFeature _upgradeFeature; private readonly WebSocketOptions _options; + private bool? _isWebSocketRequest; public UpgradeHandshake(HttpContext context, IHttpUpgradeFeature upgradeFeature, WebSocketOptions options) { @@ -65,19 +102,26 @@ namespace Microsoft.AspNetCore.WebSockets { get { - if (!_upgradeFeature.IsUpgradableRequest) + if (_isWebSocketRequest == null) { - return false; - } - var headers = new List>(); - foreach (string headerName in HandshakeHelpers.NeededHeaders) - { - foreach (var value in _context.Request.Headers.GetCommaSeparatedValues(headerName)) + if (!_upgradeFeature.IsUpgradableRequest) { - headers.Add(new KeyValuePair(headerName, value)); + _isWebSocketRequest = false; + } + else + { + var headers = new List>(); + foreach (string headerName in HandshakeHelpers.NeededHeaders) + { + foreach (var value in _context.Request.Headers.GetCommaSeparatedValues(headerName)) + { + headers.Add(new KeyValuePair(headerName, value)); + } + } + _isWebSocketRequest = HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, headers); } } - return HandshakeHelpers.CheckSupportedWebSocketRequest(_context.Request.Method, headers); + return _isWebSocketRequest.Value; } } diff --git a/src/Microsoft.AspNetCore.WebSockets/WebSocketOptions.cs b/src/Microsoft.AspNetCore.WebSockets/WebSocketOptions.cs index 808cc86251..da5f630d62 100644 --- a/src/Microsoft.AspNetCore.WebSockets/WebSocketOptions.cs +++ b/src/Microsoft.AspNetCore.WebSockets/WebSocketOptions.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; namespace Microsoft.AspNetCore.Builder { @@ -14,6 +15,7 @@ namespace Microsoft.AspNetCore.Builder { KeepAliveInterval = TimeSpan.FromMinutes(2); ReceiveBufferSize = 4 * 1024; + AllowedOrigins = new List(); } /// @@ -27,5 +29,11 @@ namespace Microsoft.AspNetCore.Builder /// The default is 4kb. /// public int ReceiveBufferSize { get; set; } + + /// + /// Set the Origin header values allowed for WebSocket requests to prevent Cross-Site WebSocket Hijacking. + /// By default all Origins are allowed. + /// + public IList AllowedOrigins { get; } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.WebSockets/WebSocketsDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.WebSockets/WebSocketsDependencyInjectionExtensions.cs new file mode 100644 index 0000000000..69b2da7eb0 --- /dev/null +++ b/src/Microsoft.AspNetCore.WebSockets/WebSocketsDependencyInjectionExtensions.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.WebSockets +{ + public static class WebSocketsDependencyInjectionExtensions + { + public static IServiceCollection AddWebSockets(this IServiceCollection services, Action configure) + { + return services.Configure(configure); + } + } +} diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/AddWebSocketsTests.cs b/test/Microsoft.AspNetCore.WebSockets.Test/AddWebSocketsTests.cs new file mode 100644 index 0000000000..255d17f24f --- /dev/null +++ b/test/Microsoft.AspNetCore.WebSockets.Test/AddWebSocketsTests.cs @@ -0,0 +1,33 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.AspNetCore.WebSockets.Test +{ + public class AddWebSocketsTests + { + [Fact] + public void AddWebSocketsConfiguresOptions() + { + var serviceCollection = new ServiceCollection(); + + serviceCollection.AddWebSockets(o => + { + o.KeepAliveInterval = TimeSpan.FromSeconds(1000); + o.AllowedOrigins.Add("someString"); + }); + + var services = serviceCollection.BuildServiceProvider(); + var socketOptions = services.GetRequiredService>().Value; + + Assert.Equal(TimeSpan.FromSeconds(1000), socketOptions.KeepAliveInterval); + Assert.Single(socketOptions.AllowedOrigins); + Assert.Equal("someString", socketOptions.AllowedOrigins[0]); + } + } +} diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/KestrelWebSocketHelpers.cs b/test/Microsoft.AspNetCore.WebSockets.Test/KestrelWebSocketHelpers.cs index 77ec387c10..3e0b184f13 100644 --- a/test/Microsoft.AspNetCore.WebSockets.Test/KestrelWebSocketHelpers.cs +++ b/test/Microsoft.AspNetCore.WebSockets.Test/KestrelWebSocketHelpers.cs @@ -14,8 +14,9 @@ namespace Microsoft.AspNetCore.WebSockets.Test { public class KestrelWebSocketHelpers { - public static IDisposable CreateServer(ILoggerFactory loggerFactory, Func app) + public static IDisposable CreateServer(ILoggerFactory loggerFactory, Func app, Action configure = null) { + configure = configure ?? (o => { }); Action startup = builder => { builder.Use(async (ct, next) => @@ -48,7 +49,11 @@ namespace Microsoft.AspNetCore.WebSockets.Test config["server.urls"] = "http://localhost:54321"; var host = new WebHostBuilder() - .ConfigureServices(s => s.AddSingleton(loggerFactory)) + .ConfigureServices(s => + { + s.AddWebSockets(configure); + s.AddSingleton(loggerFactory); + }) .UseConfiguration(config) .UseKestrel() .Configure(startup) diff --git a/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketMiddlewareTests.cs b/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketMiddlewareTests.cs index fd2fbc8d51..dc7ca9a79b 100644 --- a/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketMiddlewareTests.cs +++ b/test/Microsoft.AspNetCore.WebSockets.Test/WebSocketMiddlewareTests.cs @@ -2,11 +2,15 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Net; +using System.Net.Http; using System.Net.WebSockets; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.AspNetCore.WebSockets.Internal; using Microsoft.Extensions.Logging.Testing; using Xunit; using Xunit.Abstractions; @@ -558,5 +562,85 @@ namespace Microsoft.AspNetCore.WebSockets.Test } } } + + [Theory] + [InlineData(HttpStatusCode.OK, null)] + [InlineData(HttpStatusCode.Forbidden, "")] + [InlineData(HttpStatusCode.Forbidden, "http://e.com")] + [InlineData(HttpStatusCode.OK, "http://e.com", "http://example.com")] + [InlineData(HttpStatusCode.OK, "*")] + [InlineData(HttpStatusCode.OK, "http://e.com", "*")] + [InlineData(HttpStatusCode.OK, "http://ExAmPLE.cOm")] + public async Task OriginIsValidatedForWebSocketRequests(HttpStatusCode expectedCode, params string[] origins) + { + using (StartLog(out var loggerFactory)) + { + using (var server = KestrelWebSocketHelpers.CreateServer(loggerFactory, context => + { + Assert.True(context.WebSockets.IsWebSocketRequest); + return Task.CompletedTask; + }, o => + { + if (origins != null) + { + foreach (var origin in origins) + { + o.AllowedOrigins.Add(origin); + } + } + })) + { + using (var client = new HttpClient()) + { + var uri = new UriBuilder(ClientAddress); + uri.Scheme = "http"; + + // Craft a valid WebSocket Upgrade request + using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString())) + { + request.Headers.Connection.Clear(); + request.Headers.Connection.Add("Upgrade"); + request.Headers.Upgrade.Add(new System.Net.Http.Headers.ProductHeaderValue("websocket")); + request.Headers.Add(Constants.Headers.SecWebSocketVersion, Constants.Headers.SupportedVersion); + // SecWebSocketKey required to be 16 bytes + request.Headers.Add(Constants.Headers.SecWebSocketKey, Convert.ToBase64String(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }, Base64FormattingOptions.None)); + + request.Headers.Add("Origin", "http://example.com"); + + var response = await client.SendAsync(request); + Assert.Equal(expectedCode, response.StatusCode); + } + } + } + } + } + + [Fact] + public async Task OriginIsNotValidatedForNonWebSocketRequests() + { + using (StartLog(out var loggerFactory)) + { + using (var server = KestrelWebSocketHelpers.CreateServer(loggerFactory, context => + { + Assert.False(context.WebSockets.IsWebSocketRequest); + return Task.CompletedTask; + }, o => o.AllowedOrigins.Add("http://example.com"))) + { + using (var client = new HttpClient()) + { + var uri = new UriBuilder(ClientAddress); + uri.Scheme = "http"; + + using (var request = new HttpRequestMessage(HttpMethod.Get, uri.ToString())) + { + request.Headers.Add("Origin", "http://notexample.com"); + + var response = await client.SendAsync(request); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + } + } + } + } } }