From 6053a34cf3c01b200b0c7253a99ed8bfa0614a7a Mon Sep 17 00:00:00 2001 From: David Fowler Date: Wed, 21 Mar 2018 10:07:41 -0700 Subject: [PATCH] Don't expose HubConnectionContext on the Hub (#1674) - Made HubCallerContext an abstract class - Made DefaultHubCallerContext that gets data from the HubConnectionContext. - Removed IP address - Removed Connection property --- benchmarks/BenchmarkServer/Hubs/EchoHub.cs | 2 +- client-ts/FunctionalTests/TestHub.cs | 4 +-- .../ChatSample/PresenceHubLifetimeManager.cs | 4 +-- .../DefaultHubCallerContext.cs | 36 +++++++++++++++++++ .../HubCallerContext.cs | 24 ++++++++----- .../HubConnectionContext.cs | 12 ++----- .../HubEndPoint.cs | 2 +- .../Internal/DefaultHubDispatcher.cs | 4 +-- ...sions.cs => HubCallerContextExtensions.cs} | 4 +-- .../Hubs.cs | 17 +++++---- .../HubEndpointTestUtils/Hubs.cs | 16 ++++----- 11 files changed, 82 insertions(+), 43 deletions(-) create mode 100644 src/Microsoft.AspNetCore.SignalR.Core/DefaultHubCallerContext.cs rename src/Microsoft.AspNetCore.SignalR/{HttpConnectionContextExtensions.cs => HubCallerContextExtensions.cs} (73%) diff --git a/benchmarks/BenchmarkServer/Hubs/EchoHub.cs b/benchmarks/BenchmarkServer/Hubs/EchoHub.cs index 739be3fc61..2e4144dc6d 100644 --- a/benchmarks/BenchmarkServer/Hubs/EchoHub.cs +++ b/benchmarks/BenchmarkServer/Hubs/EchoHub.cs @@ -16,7 +16,7 @@ namespace BenchmarkServer.Hubs { var t = new CancellationTokenSource(); t.CancelAfter(TimeSpan.FromSeconds(duration)); - while (!t.IsCancellationRequested && !Context.Connection.ConnectionAbortedToken.IsCancellationRequested) + while (!t.IsCancellationRequested && !Context.ConnectionAborted.IsCancellationRequested) { await Clients.All.SendAsync("echo", DateTime.UtcNow); } diff --git a/client-ts/FunctionalTests/TestHub.cs b/client-ts/FunctionalTests/TestHub.cs index e8a9d12b0c..a963f66c07 100644 --- a/client-ts/FunctionalTests/TestHub.cs +++ b/client-ts/FunctionalTests/TestHub.cs @@ -30,7 +30,7 @@ namespace FunctionalTests public Task InvokeWithString(string message) { - return Clients.Client(Context.Connection.ConnectionId).SendAsync("Message", message); + return Clients.Client(Context.ConnectionId).SendAsync("Message", message); } public Task SendCustomObject(CustomObject customObject) @@ -55,7 +55,7 @@ namespace FunctionalTests public string GetActiveTransportName() { - return Context.Connection.Items[ConnectionMetadataNames.Transport].ToString(); + return Context.Items[ConnectionMetadataNames.Transport].ToString(); } public ComplexObject EchoComplexObject(ComplexObject complexObject) diff --git a/samples/ChatSample/PresenceHubLifetimeManager.cs b/samples/ChatSample/PresenceHubLifetimeManager.cs index 485bac061c..a318fee4a0 100644 --- a/samples/ChatSample/PresenceHubLifetimeManager.cs +++ b/samples/ChatSample/PresenceHubLifetimeManager.cs @@ -85,7 +85,7 @@ namespace ChatSample else { return hub.OnUsersJoined( - users.Where(u => u.ConnectionId != hub.Context.Connection.ConnectionId).ToArray()); + users.Where(u => u.ConnectionId != hub.Context.ConnectionId).ToArray()); } return Task.CompletedTask; }); @@ -112,7 +112,7 @@ namespace ChatSample } hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId); - hub.Context = new HubCallerContext(connection); + hub.Context = new DefaultHubCallerContext(connection); hub.Groups = _hubContext.Groups; try diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubCallerContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubCallerContext.cs new file mode 100644 index 0000000000..9a22ef1bd6 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubCallerContext.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Security.Claims; +using System.Threading; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.SignalR +{ + public class DefaultHubCallerContext : HubCallerContext + { + private readonly HubConnectionContext _connection; + + public DefaultHubCallerContext(HubConnectionContext connection) + { + _connection = connection; + } + + public override string ConnectionId => _connection.ConnectionId; + + public override string UserIdentifier => _connection.UserIdentifier; + + public override ClaimsPrincipal User => _connection.User; + + public override IDictionary Items => _connection.Items; + + public override IFeatureCollection Features => _connection.Features; + + public override CancellationToken ConnectionAborted => _connection.ConnectionAborted; + + public override void Abort() => _connection.Abort(); + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubCallerContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubCallerContext.cs index 3f75195538..4f3e0a1572 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubCallerContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubCallerContext.cs @@ -1,21 +1,29 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; using System.Security.Claims; +using System.Threading; +using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.SignalR { - public class HubCallerContext + public abstract class HubCallerContext { - public HubCallerContext(HubConnectionContext connection) - { - Connection = connection; - } + public abstract string ConnectionId { get; } - public HubConnectionContext Connection { get; } + public abstract string UserIdentifier { get; } - public ClaimsPrincipal User => Connection.User; + public abstract ClaimsPrincipal User { get; } - public string ConnectionId => Connection.ConnectionId; + public abstract IDictionary Items { get; } + + public abstract IFeatureCollection Features { get; } + + public abstract CancellationToken ConnectionAborted { get; } + + public abstract void Abort(); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 8026a1ad00..e723d6adce 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -41,11 +41,11 @@ namespace Microsoft.AspNetCore.SignalR { _connectionContext = connectionContext; _logger = loggerFactory.CreateLogger(); - ConnectionAbortedToken = _connectionAbortedTokenSource.Token; + ConnectionAborted = _connectionAbortedTokenSource.Token; _keepAliveDuration = (int)keepAliveInterval.TotalMilliseconds * (Stopwatch.Frequency / 1000); } - public virtual CancellationToken ConnectionAbortedToken { get; } + public virtual CancellationToken ConnectionAborted { get; } public virtual string ConnectionId => _connectionContext.ConnectionId; @@ -66,14 +66,6 @@ namespace Microsoft.AspNetCore.SignalR // Currently used only for streaming methods internal ConcurrentDictionary ActiveRequestCancellationSources { get; } = new ConcurrentDictionary(); - public IPAddress RemoteIpAddress => Features.Get()?.RemoteIpAddress; - - public IPAddress LocalIpAddress => Features.Get()?.LocalIpAddress; - - public int? RemotePort => Features.Get()?.RemotePort; - - public int? LocalPort => Features.Get()?.LocalPort; - public virtual ValueTask WriteAsync(HubMessage message) { // We were unable to get the lock so take the slow async path of waiting for the semaphore diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index 85a707aba6..2ff9c0de1b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -162,7 +162,7 @@ namespace Microsoft.AspNetCore.SignalR { while (true) { - var result = await connection.Input.ReadAsync(connection.ConnectionAbortedToken); + var result = await connection.Input.ReadAsync(connection.ConnectionAborted); var buffer = result.Buffer; var consumed = buffer.End; var examined = buffer.End; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index b1092c0987..b4b29aebd6 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -300,7 +300,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal private void InitializeHub(THub hub, HubConnectionContext connection) { hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId); - hub.Context = new HubCallerContext(connection); + hub.Context = new DefaultHubCallerContext(connection); hub.Groups = _hubContext.Groups; } @@ -376,7 +376,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { var streamCts = new CancellationTokenSource(); connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts); - return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAbortedToken, streamCts.Token).Token; + return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, streamCts.Token).Token; } } diff --git a/src/Microsoft.AspNetCore.SignalR/HttpConnectionContextExtensions.cs b/src/Microsoft.AspNetCore.SignalR/HubCallerContextExtensions.cs similarity index 73% rename from src/Microsoft.AspNetCore.SignalR/HttpConnectionContextExtensions.cs rename to src/Microsoft.AspNetCore.SignalR/HubCallerContextExtensions.cs index 3842012f04..45ef53def2 100644 --- a/src/Microsoft.AspNetCore.SignalR/HttpConnectionContextExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubCallerContextExtensions.cs @@ -6,9 +6,9 @@ using Microsoft.AspNetCore.Sockets.Http.Features; namespace Microsoft.AspNetCore.SignalR { - public static class DefaultConnectionContextExtensions + public static class HubCallerContextExtensions { - public static HttpContext GetHttpContext(this HubConnectionContext connection) + public static HttpContext GetHttpContext(this HubCallerContext connection) { return connection.Features.Get()?.HttpContext; } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index da95142802..58f7358234 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -9,6 +9,7 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication.JwtBearer; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests @@ -37,7 +38,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public IEnumerable GetHeaderValues(string[] headerNames) { - var context = Context.Connection.GetHttpContext(); + var context = Context.GetHttpContext(); if (context == null) { @@ -56,17 +57,19 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public string GetCookieValue(string cookieName) { - return Context.Connection.GetHttpContext().Request.Cookies[cookieName]; + return Context.GetHttpContext().Request.Cookies[cookieName]; } public object[] GetIHttpConnectionFeatureProperties() { + var feature = Context.Features.Get(); + object[] result = { - Context.Connection.LocalPort, - Context.Connection.RemotePort, - Context.Connection.LocalIpAddress.ToString(), - Context.Connection.RemoteIpAddress.ToString() + feature.LocalPort, + feature.RemotePort, + feature.LocalIpAddress.ToString(), + feature.RemoteIpAddress.ToString() }; return result; @@ -74,7 +77,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public string GetActiveTransportName() { - return Context.Connection.Items[ConnectionMetadataNames.Transport].ToString(); + return Context.Items[ConnectionMetadataNames.Transport].ToString(); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs index 8b9a5b8d4d..b78324bdc6 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs @@ -144,7 +144,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils public bool HasHttpContext() { - return Context.Connection.GetHttpContext() != null; + return Context.GetHttpContext() != null; } public Task SendToOthers(string message) @@ -162,7 +162,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils { public override Task OnConnectedAsync() { - var tcs = (TaskCompletionSource)Context.Connection.Items["ConnectedTask"]; + var tcs = (TaskCompletionSource)Context.Items["ConnectedTask"]; tcs?.TrySetResult(true); return base.OnConnectedAsync(); } @@ -172,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils { public override Task OnConnectedAsync() { - var tcs = (TaskCompletionSource)Context.Connection.Items["ConnectedTask"]; + var tcs = (TaskCompletionSource)Context.Items["ConnectedTask"]; tcs?.TrySetResult(true); return base.OnConnectedAsync(); } @@ -252,7 +252,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils { public override Task OnConnectedAsync() { - var tcs = (TaskCompletionSource)Context.Connection.Items["ConnectedTask"]; + var tcs = (TaskCompletionSource)Context.Items["ConnectedTask"]; tcs?.TrySetResult(true); return base.OnConnectedAsync(); } @@ -437,7 +437,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils { public void Kill() { - Context.Connection.Abort(); + Context.Abort(); } } @@ -623,9 +623,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils public override Task OnConnectedAsync() { - _state.TokenStateInConnected = Context.Connection.ConnectionAbortedToken.IsCancellationRequested; + _state.TokenStateInConnected = Context.ConnectionAborted.IsCancellationRequested; - Context.Connection.ConnectionAbortedToken.Register(() => + Context.ConnectionAborted.Register(() => { _state.TokenCallbackTriggered = true; }); @@ -635,7 +635,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils public override Task OnDisconnectedAsync(Exception exception) { - _state.TokenStateInDisconnected = Context.Connection.ConnectionAbortedToken.IsCancellationRequested; + _state.TokenStateInDisconnected = Context.ConnectionAborted.IsCancellationRequested; return base.OnDisconnectedAsync(exception); }