From 1ea10f68fc6d6fa5713002ad38bf2b2ee1908a49 Mon Sep 17 00:00:00 2001 From: Brennan Date: Wed, 26 Feb 2020 09:54:24 -0800 Subject: [PATCH] Create WindowsPrincipal when cloning WindowsIdentity for SignalR (#19337) --- .../src/Internal/HttpConnectionDispatcher.cs | 23 +++++++- .../test/HttpConnectionDispatcherTests.cs | 57 ++++++++++++++++++- 2 files changed, 77 insertions(+), 3 deletions(-) diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 73b45bb9b6..969d4a5935 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -572,12 +572,31 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private static void CloneUser(HttpContext newContext, HttpContext oldContext) { - if (oldContext.User.Identity is WindowsIdentity) + // If the identity is a WindowsIdentity we need to clone the User. + // This is because the WindowsIdentity uses SafeHandle's which are disposed at the end of the request + // and accessing the identity can happen outside of the request scope. + if (oldContext.User.Identity is WindowsIdentity windowsIdentity) { - newContext.User = new ClaimsPrincipal(); + var skipFirstIdentity = false; + if (oldContext.User is WindowsPrincipal) + { + // We want to explicitly create a WindowsPrincipal instead of a ClaimsPrincipal + // so methods that WindowsPrincipal overrides like 'IsInRole', work as expected. + newContext.User = new WindowsPrincipal((WindowsIdentity)(windowsIdentity.Clone())); + skipFirstIdentity = true; + } + else + { + newContext.User = new ClaimsPrincipal(); + } foreach (var identity in oldContext.User.Identities) { + if (skipFirstIdentity) + { + skipFirstIdentity = false; + continue; + } newContext.User.AddIdentity(identity.Clone()); } } diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 2a083c5982..664a794324 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -1639,7 +1639,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests [ConditionalFact] [OSSkipCondition(OperatingSystems.Linux | OperatingSystems.MacOSX)] - public async Task LongPollingKeepsWindowsIdentityBetweenRequests() + public async Task LongPollingKeepsWindowsPrincipalAndIdentityBetweenRequests() { using (StartVerifiableLog()) { @@ -1668,6 +1668,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var windowsIdentity = WindowsIdentity.GetAnonymous(); context.User = new WindowsPrincipal(windowsIdentity); + context.User.AddIdentity(new ClaimsIdentity()); // would get stuck if EndPoint was running await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); @@ -1681,6 +1682,60 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // This is the important check Assert.Same(currentUser, connection.User); + Assert.IsType(currentUser); + Assert.Equal(2, connection.User.Identities.Count()); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + } + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.Linux | OperatingSystems.MacOSX)] + public async Task LongPollingKeepsWindowsIdentityWithoutWindowsPrincipalBetweenRequests() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddOptions(); + services.AddSingleton(); + services.AddLogging(); + var sp = services.BuildServiceProvider(); + context.Request.Path = "/foo"; + context.Request.Method = "GET"; + context.RequestServices = sp; + var values = new Dictionary(); + values["id"] = connection.ConnectionToken; + values["negotiateVersion"] = "1"; + var qs = new QueryCollection(values); + context.Request.Query = qs; + var builder = new ConnectionBuilder(sp); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + var windowsIdentity = WindowsIdentity.GetAnonymous(); + context.User = new ClaimsPrincipal(windowsIdentity); + context.User.AddIdentity(new ClaimsIdentity()); + + // would get stuck if EndPoint was running + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + var currentUser = connection.User; + + var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Unblock")).AsTask().OrTimeout(); + await connectionHandlerTask.OrTimeout(); + + // This is the important check + Assert.Same(currentUser, connection.User); + Assert.IsNotType(currentUser); + Assert.Equal(2, connection.User.Identities.Count()); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); }