Better circuit ids (#10688)

* Better circuit ids.
* Generates cryptographically strong circuit ids.
  * 32 bits of entropy from a PRNG.
  * DataProtected for authenticated encryption.
This commit is contained in:
Javier Calvarro Nelson 2019-06-01 00:26:46 +02:00 committed by GitHub
parent bd4b843678
commit 2d2806b083
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 236 additions and 33 deletions

View File

@ -6,6 +6,7 @@
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp3.0'">
<Compile Include="Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs" />
<Reference Include="Microsoft.AspNetCore.Components.Browser" />
<Reference Include="Microsoft.AspNetCore.DataProtection" />
<Reference Include="Microsoft.Extensions.Logging" />
<Reference Include="Microsoft.AspNetCore.SignalR" />
<Reference Include="Microsoft.AspNetCore.StaticFiles" />

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.DataProtection;
namespace Microsoft.AspNetCore.Components.Server
{

View File

@ -50,6 +50,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public event UnhandledExceptionEventHandler UnhandledException;
public CircuitHost(
string circuitId,
IServiceScope scope,
CircuitClientProxy client,
RendererRegistry rendererRegistry,
@ -60,6 +61,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
CircuitHandler[] circuitHandlers,
ILogger logger)
{
CircuitId = circuitId;
_scope = scope ?? throw new ArgumentNullException(nameof(scope));
Dispatcher = dispatcher;
Client = client;
@ -78,7 +80,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
Renderer.UnhandledSynchronizationException += SynchronizationContext_UnhandledException;
}
public string CircuitId { get; } = Guid.NewGuid().ToString();
public string CircuitId { get; }
public Circuit Circuit { get; }

View File

@ -0,0 +1,57 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Security.Cryptography;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.Components.Server.Circuits
{
// This is a singleton instance
// Generates strong cryptographic ids for circuits that are protected with authenticated encryption.
internal class CircuitIdFactory
{
private const string CircuitIdProtectorPurpose = "Microsoft.AspNetCore.Components.Server";
private readonly RandomNumberGenerator _generator = RandomNumberGenerator.Create();
private readonly IDataProtector _protector;
public CircuitIdFactory(IDataProtectionProvider provider)
{
_protector = provider.CreateProtector(CircuitIdProtectorPurpose);
}
// Generates a circuit id that is produced from a strong cryptographic random number generator
// we don't care about the underlying payload, other than its uniqueness and the fact that we
// authenticate encrypt it using data protection.
// For validation, the fact that we can unprotect the payload is guarantee enough.
public string CreateCircuitId()
{
var buffer = new byte[32];
_generator.GetBytes(buffer);
var payload = _protector.Protect(buffer);
return Base64UrlTextEncoder.Encode(payload);
}
public bool ValidateCircuitId(string circuitId)
{
try
{
var protectedBytes = Base64UrlTextEncoder.Decode(circuitId);
_protector.Unprotect(protectedBytes);
// Its enough that we prove that we can unprotect the payload to validate the circuit id,
// as this demonstrates that it the id wasn't tampered with.
return true;
}
catch (Exception)
{
// The payload format is not correct (either not base64urlencoded or not data protected)
return false;
}
}
}
}

View File

@ -6,7 +6,6 @@ using System.Collections.Concurrent;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Components.Rendering;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
@ -41,15 +40,17 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
private readonly object CircuitRegistryLock = new object();
private readonly CircuitOptions _options;
private readonly ILogger _logger;
private readonly CircuitIdFactory _circuitIdFactory;
private readonly PostEvictionCallbackRegistration _postEvictionCallback;
public CircuitRegistry(
IOptions<CircuitOptions> options,
ILogger<CircuitRegistry> logger)
ILogger<CircuitRegistry> logger,
CircuitIdFactory circuitIdFactory)
{
_options = options.Value;
_logger = logger;
_circuitIdFactory = circuitIdFactory;
ConnectedCircuits = new ConcurrentDictionary<string, CircuitHost>(StringComparer.Ordinal);
DisconnectedCircuits = new MemoryCache(new MemoryCacheOptions
@ -139,6 +140,11 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public virtual async Task<CircuitHost> ConnectAsync(string circuitId, IClientProxy clientProxy, string connectionId, CancellationToken cancellationToken)
{
if (!_circuitIdFactory.ValidateCircuitId(circuitId))
{
return null;
}
CircuitHost circuitHost;
bool previouslyConnected;

View File

@ -21,13 +21,16 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
{
private readonly IServiceScopeFactory _scopeFactory;
private readonly ILoggerFactory _loggerFactory;
private readonly CircuitIdFactory _circuitIdFactory;
public DefaultCircuitFactory(
IServiceScopeFactory scopeFactory,
ILoggerFactory loggerFactory)
ILoggerFactory loggerFactory,
CircuitIdFactory circuitIdFactory)
{
_scopeFactory = scopeFactory ?? throw new ArgumentNullException(nameof(scopeFactory));
_loggerFactory = loggerFactory;
_circuitIdFactory = circuitIdFactory ?? throw new ArgumentNullException(nameof(circuitIdFactory));
}
public override CircuitHost CreateCircuitHost(
@ -81,6 +84,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
.ToArray();
var circuitHost = new CircuitHost(
_circuitIdFactory.CreateCircuitId(),
scope,
client,
rendererRegistry,

View File

@ -28,6 +28,8 @@ namespace Microsoft.Extensions.DependencyInjection
{
var builder = new DefaultServerSideBlazorBuilder(services);
services.AddDataProtection();
// This call INTENTIONALLY uses the AddHubOptions on the SignalR builder, because it will merge
// the global HubOptions before running the configure callback. We want to ensure that happens
// once. Our AddHubOptions method doesn't do this.
@ -51,6 +53,9 @@ namespace Microsoft.Extensions.DependencyInjection
// Components entrypoints, this lot is the same and repeated registrations are a no-op.
services.TryAddEnumerable(ServiceDescriptor.Singleton<IPostConfigureOptions<StaticFileOptions>, ConfigureStaticFilesOptions>());
services.TryAddSingleton<CircuitFactory, DefaultCircuitFactory>();
services.TryAddSingleton<CircuitIdFactory>();
services.TryAddScoped(s => s.GetRequiredService<ICircuitAccessor>().Circuit);
services.TryAddScoped<ICircuitAccessor, DefaultCircuitAccessor>();

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netcoreapp3.0</TargetFramework>
@ -13,6 +13,7 @@
<ItemGroup>
<Reference Include="Microsoft.AspNetCore.Components.Browser" />
<Reference Include="Microsoft.AspNetCore.DataProtection" />
<Reference Include="Microsoft.Extensions.Logging" />
<Reference Include="Microsoft.AspNetCore.SignalR" />
<Reference Include="Microsoft.AspNetCore.StaticFiles" />
@ -52,10 +53,7 @@
<ItemGroup>
<EmbeddedResource Include="..\..\Browser.JS\dist\$(Configuration)\blazor.server.js" LogicalName="_framework/%(Filename)%(Extension)" />
<EmbeddedResource
Include="..\..\Browser.JS\dist\$(Configuration)\blazor.server.js.map"
LogicalName="_framework/%(Filename)%(Extension)"
Condition="Exists('..\..\Browser.JS\dist\$(Configuration)\blazor.server.js.map')" />
<EmbeddedResource Include="..\..\Browser.JS\dist\$(Configuration)\blazor.server.js.map" LogicalName="_framework/%(Filename)%(Extension)" Condition="Exists('..\..\Browser.JS\dist\$(Configuration)\blazor.server.js.map')" />
</ItemGroup>
</Project>

View File

@ -28,6 +28,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var serviceScope = new Mock<IServiceScope>();
var remoteRenderer = GetRemoteRenderer(Renderer.CreateDefaultDispatcher());
var circuitHost = TestCircuitHost.Create(
Guid.NewGuid().ToString(),
serviceScope.Object,
remoteRenderer);
@ -46,6 +47,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var serviceScope = new Mock<IServiceScope>();
var remoteRenderer = GetRemoteRenderer(Renderer.CreateDefaultDispatcher());
var circuitHost = TestCircuitHost.Create(
Guid.NewGuid().ToString(),
serviceScope.Object,
remoteRenderer);

View File

@ -0,0 +1,89 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.AspNetCore.WebUtilities;
using Xunit;
namespace Microsoft.AspNetCore.Components.Server.Circuits
{
public class CircuitIdFactoryTest
{
[Fact]
public void CreateCircuitId_Generates_NewRandomId()
{
var factory = TestCircuitIdFactory.CreateTestFactory();
// Act
var id = factory.CreateCircuitId();
// Assert
Assert.NotNull(id);
// This is the magic data protection header that validates its protected
Assert.StartsWith("CfDJ", id);
}
[Fact]
public void CreateCircuitId_Generates_GeneratesDifferentIds_ForSuccesiveCalls()
{
// Arrange
var factory = TestCircuitIdFactory.CreateTestFactory();
// Act
var ids = Enumerable.Range(0, 100).Select(i => factory.CreateCircuitId()).ToArray();
// Assert
Assert.All(ids, id => Assert.NotNull(id));
Assert.Equal(100, ids.Distinct(StringComparer.Ordinal).Count());
}
[Fact]
public void CircuitIds_Roundtrip()
{
// Arrange
var factory = TestCircuitIdFactory.CreateTestFactory();
var id = factory.CreateCircuitId();
// Act
var isValid = factory.ValidateCircuitId(id);
// Assert
Assert.True(isValid, "Failed to validate id");
}
[Fact]
public void ValidateCircuitId_ReturnsFalseForMalformedPayloads()
{
// Arrange
var factory = TestCircuitIdFactory.CreateTestFactory();
// Act
var isValid = factory.ValidateCircuitId("$%@&==");
// Assert
Assert.False(isValid, "Accepted an invalid payload");
}
[Fact]
public void ValidateCircuitId_ReturnsFalseForPotentiallyTamperedPayloads()
{
// Arrange
var factory = TestCircuitIdFactory.CreateTestFactory();
var id = factory.CreateCircuitId();
var protectedBytes = Base64UrlTextEncoder.Decode(id);
for (int i = protectedBytes.Length - 10; i < protectedBytes.Length; i++)
{
protectedBytes[i] = 0;
}
var tamperedId = Base64UrlTextEncoder.Encode(protectedBytes);
// Act
var isValid = factory.ValidateCircuitId(tamperedId);
// Assert
Assert.False(isValid, "Accepted a tampered payload");
}
}
}

View File

@ -35,7 +35,10 @@ namespace Microsoft.AspNetCore.Components.Server.Tests.Circuits
{
// Arrange
var circuitFactory = new TestCircuitFactory();
var circuitRegistry = new CircuitRegistry(Options.Create(new CircuitOptions()), Mock.Of<ILogger<CircuitRegistry>>());
var circuitRegistry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
Mock.Of<ILogger<CircuitRegistry>>(),
TestCircuitIdFactory.CreateTestFactory());
var circuitPrerenderer = new CircuitPrerenderer(circuitFactory, circuitRegistry);
var httpContext = new DefaultHttpContext();
var httpRequest = httpContext.Request;
@ -76,7 +79,10 @@ namespace Microsoft.AspNetCore.Components.Server.Tests.Circuits
{
// Arrange
var circuitFactory = new TestCircuitFactory();
var circuitRegistry = new CircuitRegistry(Options.Create(new CircuitOptions()), Mock.Of<ILogger<CircuitRegistry>>());
var circuitRegistry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
Mock.Of<ILogger<CircuitRegistry>>(),
TestCircuitIdFactory.CreateTestFactory());
var circuitPrerenderer = new CircuitPrerenderer(circuitFactory, circuitRegistry);
var httpContext = new DefaultHttpContext();
var httpRequest = httpContext.Request;
@ -117,7 +123,7 @@ namespace Microsoft.AspNetCore.Components.Server.Tests.Circuits
return uriHelper;
});
var serviceScope = serviceCollection.BuildServiceProvider().CreateScope();
return TestCircuitHost.Create(serviceScope);
return TestCircuitHost.Create(Guid.NewGuid().ToString(), serviceScope);
}
}

View File

@ -4,6 +4,7 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging.Abstractions;
@ -34,8 +35,10 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task ConnectAsync_TransfersClientOnActiveCircuit()
{
// Arrange
var registry = CreateRegistry();
var circuitHost = TestCircuitHost.Create();
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = CreateRegistry(circuitIdFactory);
var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId());
registry.Register(circuitHost);
var newClient = Mock.Of<IClientProxy>();
@ -57,8 +60,10 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task ConnectAsync_MakesInactiveCircuitActive()
{
// Arrange
var registry = CreateRegistry();
var circuitHost = TestCircuitHost.Create();
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = CreateRegistry(circuitIdFactory);
var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId());
registry.DisconnectedCircuits.Set(circuitHost.CircuitId, circuitHost, new MemoryCacheEntryOptions { Size = 1 });
var newClient = Mock.Of<IClientProxy>();
@ -81,9 +86,10 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task ConnectAsync_InvokesCircuitHandlers_WhenCircuitWasPreviouslyDisconnected()
{
// Arrange
var registry = CreateRegistry();
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = CreateRegistry(circuitIdFactory);
var handler = new Mock<CircuitHandler> { CallBase = true };
var circuitHost = TestCircuitHost.Create(handlers: new[] { handler.Object });
var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId(), handlers: new[] { handler.Object });
registry.DisconnectedCircuits.Set(circuitHost.CircuitId, circuitHost, new MemoryCacheEntryOptions { Size = 1 });
var newClient = Mock.Of<IClientProxy>();
@ -104,9 +110,10 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task ConnectAsync_InvokesCircuitHandlers_WhenCircuitWasConsideredConnected()
{
// Arrange
var registry = CreateRegistry();
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = CreateRegistry(circuitIdFactory);
var handler = new Mock<CircuitHandler> { CallBase = true };
var circuitHost = TestCircuitHost.Create(handlers: new[] { handler.Object });
var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId(), handlers: new[] { handler.Object });
registry.Register(circuitHost);
var newClient = Mock.Of<IClientProxy>();
@ -199,11 +206,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task Connect_WhileDisconnectIsInProgress()
{
// Arrange
var registry = new TestCircuitRegistry();
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = new TestCircuitRegistry(circuitIdFactory);
registry.BeforeDisconnect = new ManualResetEventSlim();
var tcs = new TaskCompletionSource<int>();
var circuitHost = TestCircuitHost.Create();
var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId());
registry.Register(circuitHost);
var client = Mock.Of<IClientProxy>();
var newId = "new-connection";
@ -238,13 +247,15 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task Connect_WhileDisconnectIsInProgress_SeriallyExecutesCircuitHandlers()
{
// Arrange
var registry = new TestCircuitRegistry();
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = new TestCircuitRegistry(circuitIdFactory);
registry.BeforeDisconnect = new ManualResetEventSlim();
// This verifies that connection up \ down events on a circuit handler are always invoked serially.
var circuitHandler = new SerialCircuitHandler();
var tcs = new TaskCompletionSource<int>();
var circuitHost = TestCircuitHost.Create(handlers: new[] { circuitHandler });
var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId(), handlers: new[] { circuitHandler });
registry.Register(circuitHost);
var client = Mock.Of<IClientProxy>();
var newId = "new-connection";
@ -276,9 +287,11 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task DisconnectWhenAConnectIsInProgress()
{
// Arrange
var registry = new TestCircuitRegistry();
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = new TestCircuitRegistry(circuitIdFactory);
registry.BeforeConnect = new ManualResetEventSlim();
var circuitHost = TestCircuitHost.Create();
var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId());
registry.Register(circuitHost);
var client = Mock.Of<IClientProxy>();
var oldId = circuitHost.Client.ConnectionId;
@ -302,8 +315,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
private class TestCircuitRegistry : CircuitRegistry
{
public TestCircuitRegistry()
: base(Options.Create(new CircuitOptions()), NullLogger<CircuitRegistry>.Instance)
public TestCircuitRegistry(CircuitIdFactory factory)
: base(Options.Create(new CircuitOptions()), NullLogger<CircuitRegistry>.Instance, factory)
{
}
@ -331,11 +344,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
}
}
private static CircuitRegistry CreateRegistry()
private static CircuitRegistry CreateRegistry(CircuitIdFactory factory = null)
{
return new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance);
NullLogger<CircuitRegistry>.Instance,
factory ?? TestCircuitIdFactory.CreateTestFactory());
}
private class SerialCircuitHandler : CircuitHandler

View File

@ -18,8 +18,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
{
internal class TestCircuitHost : CircuitHost
{
private TestCircuitHost(IServiceScope scope, CircuitClientProxy client, RendererRegistry rendererRegistry, RemoteRenderer renderer, IList<ComponentDescriptor> descriptors, IDispatcher dispatcher, RemoteJSRuntime jsRuntime, CircuitHandler[] circuitHandlers, ILogger logger)
: base(scope, client, rendererRegistry, renderer, descriptors, dispatcher, jsRuntime, circuitHandlers, logger)
private TestCircuitHost(string circuitId, IServiceScope scope, CircuitClientProxy client, RendererRegistry rendererRegistry, RemoteRenderer renderer, IList<ComponentDescriptor> descriptors, IDispatcher dispatcher, RemoteJSRuntime jsRuntime, CircuitHandler[] circuitHandlers, ILogger logger)
: base(circuitId, scope, client, rendererRegistry, renderer, descriptors, dispatcher, jsRuntime, circuitHandlers, logger)
{
}
@ -29,6 +29,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
}
public static CircuitHost Create(
string circuitId = null,
IServiceScope serviceScope = null,
RemoteRenderer remoteRenderer = null,
CircuitHandler[] handlers = null,
@ -54,6 +55,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
handlers = handlers ?? Array.Empty<CircuitHandler>();
return new TestCircuitHost(
circuitId ?? Guid.NewGuid().ToString(),
serviceScope,
clientProxy,
renderRegistry,

View File

@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.DataProtection;
namespace Microsoft.AspNetCore.Components.Server.Circuits
{
internal class TestCircuitIdFactory
{
public static CircuitIdFactory CreateTestFactory()
{
return new CircuitIdFactory(new EphemeralDataProtectionProvider());
}
}
}

View File

@ -432,6 +432,7 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
{
var services = new ServiceCollection();
services.AddLogging();
services.AddDataProtection();
services.AddSingleton(HtmlEncoder.Default);
configureServices = configureServices ?? (s => s.AddServerSideBlazor());
configureServices?.Invoke(services);