From 595f783857598c56056305d627b969b856367454 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 6 Jul 2017 11:27:16 -0700 Subject: [PATCH] Features everywhere (#639) * Features everywhere - The goal here is to move things closer to the final design where ConnectionContext represents a very low level primitive that represents any connection like transport. As part of that change, we remove unnecessary properties like User and move those into features. They temporarily live in the same assembly but they are not required by ConnectionContext. - Used features for Hubs instead of Metadata - Metadata is no longer thread safe --- .../PersistentConnectionLifeTimeManager.cs | 7 +- .../SocialWeather/SocialWeatherEndPoint.cs | 2 +- .../RedisHubLifetimeManager.cs | 32 +++-- .../DefaultHubLifetimeManager.cs | 36 ++++-- .../Features/IHubFeature.cs | 20 +++ .../HubConnectionContext.cs | 18 ++- .../HubConnectionMetadataNames.cs | 11 -- .../HubEndPoint.cs | 11 +- .../ConnectionContext.cs | 11 +- .../ConnectionMetadata.cs | 37 ------ .../Features/IConnectionIdFeature.cs | 14 +++ .../Features/IConnectionMetadataFeature.cs | 13 ++ .../Features/IConnectionTransportFeature.cs | 15 +++ .../Features/IConnectionUserFeature.cs | 15 +++ .../ConnectionMetadataNames.cs | 1 - .../Features/IHttpContextFeature.cs | 20 +++ .../HttpConnectionContextExtensions.cs | 14 ++- .../HttpConnectionDispatcher.cs | 6 +- .../ConnectionMetadata.cs | 118 ++++++++++++++++++ .../DefaultConnectionContext.cs | 24 +++- .../HubEndpointTests.cs | 3 +- .../TestClient.cs | 2 +- 22 files changed, 335 insertions(+), 95 deletions(-) create mode 100644 src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs delete mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionIdFeature.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionMetadataFeature.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs create mode 100644 src/Microsoft.AspNetCore.Sockets.Http/Features/IHttpContextFeature.cs create mode 100644 src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs index 75bed615b3..87a6fbb9c3 100644 --- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs +++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs @@ -21,6 +21,7 @@ namespace SocialWeather public void OnConnectedAsync(ConnectionContext connection) { + connection.Metadata["groups"] = new HashSet(); connection.Metadata["format"] = "json"; _connectionList.Add(connection); } @@ -35,7 +36,7 @@ namespace SocialWeather foreach (var connection in _connectionList) { var context = connection.GetHttpContext(); - var formatter = _formatterResolver.GetFormatter(connection.Metadata.Get("format")); + var formatter = _formatterResolver.GetFormatter((string)connection.Metadata["format"]); var ms = new MemoryStream(); await formatter.WriteAsync(data, ms); @@ -60,7 +61,7 @@ namespace SocialWeather public void AddGroupAsync(ConnectionContext connection, string groupName) { - var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet()); + var groups = (HashSet)connection.Metadata["groups"]; lock (groups) { groups.Add(groupName); @@ -69,7 +70,7 @@ namespace SocialWeather public void RemoveGroupAsync(ConnectionContext connection, string groupName) { - var groups = connection.Metadata.Get>("groups"); + var groups = (HashSet)connection.Metadata["groups"]; if (groups != null) { lock (groups) diff --git a/samples/SocialWeather/SocialWeatherEndPoint.cs b/samples/SocialWeather/SocialWeatherEndPoint.cs index 9aa3fb8e58..e412cfafeb 100644 --- a/samples/SocialWeather/SocialWeatherEndPoint.cs +++ b/samples/SocialWeather/SocialWeatherEndPoint.cs @@ -32,7 +32,7 @@ namespace SocialWeather public async Task ProcessRequests(ConnectionContext connection) { var formatter = _formatterResolver.GetFormatter( - connection.Metadata.Get("formatType")); + (string)connection.Metadata["formatType"]); while (await connection.Transport.In.WaitToReadAsync()) { diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 085648ba76..25bdd602de 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -19,8 +19,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis { public class RedisHubLifetimeManager : HubLifetimeManager, IDisposable { - private const string RedisSubscriptionsMetadataName = "redis_subscriptions"; - private readonly HubConnectionList _connections = new HubConnectionList(); // TODO: Investigate "memory leak" entries never get removed private readonly ConcurrentDictionary _groups = new ConcurrentDictionary(); @@ -129,7 +127,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override Task OnConnectedAsync(HubConnectionContext connection) { - var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet()); + var feature = new RedisFeature(); + connection.Features.Set(feature); + + var redisSubscriptions = feature.Subscriptions; var connectionTask = Task.CompletedTask; var userTask = Task.CompletedTask; @@ -178,7 +179,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis var tasks = new List(); - var redisSubscriptions = connection.Metadata.Get>(RedisSubscriptionsMetadataName); + var feature = connection.Features.Get(); + + var redisSubscriptions = feature.Subscriptions; if (redisSubscriptions != null) { foreach (var subscription in redisSubscriptions) @@ -188,7 +191,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - var groupNames = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); + var groupNames = feature.Groups; if (groupNames != null) { @@ -211,7 +214,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis { return; } - var groupNames = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet()); + + var feature = connection.Features.Get(); + var groupNames = feature.Groups; lock (groupNames) { @@ -274,7 +279,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis return; } - var groupNames = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); + var feature = connection.Features.Get(); + var groupNames = feature.Groups; if (groupNames != null) { lock (groupNames) @@ -363,5 +369,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis public SemaphoreSlim Lock = new SemaphoreSlim(1, 1); public HubConnectionList Connections = new HubConnectionList(); } + + private interface IRedisFeature + { + HashSet Subscriptions { get; } + HashSet Groups { get; } + } + + private class RedisFeature : IRedisFeature + { + public HashSet Subscriptions { get; } = new HashSet(); + public HashSet Groups { get; } = new HashSet(StringComparer.OrdinalIgnoreCase); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index 12ec7f3248..eca117dc30 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal.Protocol; namespace Microsoft.AspNetCore.SignalR @@ -22,7 +23,9 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet()); + var feature = connection.Features.Get(); + var groups = feature.Groups; + lock (groups) { groups.Add(groupName); @@ -39,12 +42,8 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); - - if (groups == null) - { - return Task.CompletedTask; - } + var feature = connection.Features.Get(); + var groups = feature.Groups; lock (groups) { @@ -91,8 +90,14 @@ namespace Microsoft.AspNetCore.SignalR { return InvokeAllWhere(methodName, args, connection => { - var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); - return groups?.Contains(groupName) == true; + var feature = connection.Features.Get(); + var groups = feature.Groups; + + // PERF: ... + lock (groups) + { + return groups.Contains(groupName) == true; + } }); } @@ -106,6 +111,9 @@ namespace Microsoft.AspNetCore.SignalR public override Task OnConnectedAsync(HubConnectionContext connection) { + // Set the hub groups feature + connection.Features.Set(new HubGroupsFeature()); + _connections.Add(connection); return Task.CompletedTask; } @@ -134,5 +142,15 @@ namespace Microsoft.AspNetCore.SignalR var invocationId = Interlocked.Increment(ref _nextInvocationId); return invocationId.ToString(); } + + private interface IHubGroupsFeature + { + HashSet Groups { get; } + } + + private class HubGroupsFeature : IHubGroupsFeature + { + public HashSet Groups { get; } = new HashSet(StringComparer.OrdinalIgnoreCase); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs b/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs new file mode 100644 index 0000000000..0269c4458a --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Features +{ + public interface IHubFeature + { + IHubProtocol Protocol { get; set; } + } + + public class HubFeature : IHubFeature + { + public IHubProtocol Protocol { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs index 6bd4c5e101..9159d25c49 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs @@ -2,10 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Security.Claims; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Features; namespace Microsoft.AspNetCore.SignalR { @@ -20,16 +24,24 @@ namespace Microsoft.AspNetCore.SignalR _connectionContext = connectionContext; } + private IHubFeature HubFeature => Features.Get(); + // Used by the HubEndPoint only internal ReadableChannel Input => _connectionContext.Transport; public virtual string ConnectionId => _connectionContext.ConnectionId; - public virtual ClaimsPrincipal User => _connectionContext.User; + public virtual ClaimsPrincipal User => Features.Get()?.User; - public virtual ConnectionMetadata Metadata => _connectionContext.Metadata; + public virtual IFeatureCollection Features => _connectionContext.Features; - public virtual IHubProtocol Protocol => _connectionContext.Metadata.Get(HubConnectionMetadataNames.HubProtocol); + public virtual IDictionary Metadata => _connectionContext.Metadata; + + public virtual IHubProtocol Protocol + { + get => HubFeature.Protocol; + set => HubFeature.Protocol = value; + } public virtual WritableChannel Output => _output; } diff --git a/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs deleted file mode 100644 index 702a762c01..0000000000 --- a/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -namespace Microsoft.AspNetCore.SignalR -{ - public static class HubConnectionMetadataNames - { - public static readonly string HubProtocol = nameof(HubProtocol); - public static readonly string Groups = nameof(Groups); - } -} diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 4ade79b872..2cb809a3f7 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -11,6 +11,7 @@ using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; @@ -48,6 +49,11 @@ namespace Microsoft.AspNetCore.SignalR public async Task OnConnectedAsync(ConnectionContext connection) { var output = Channel.CreateUnbounded(); + + // Set the hub feature before doing anything else. This stores + // all the relevant state for a SignalR Hub connection + connection.Features.Set(new HubFeature()); + var connectionContext = new HubConnectionContext(output, connection); await ProcessNegotiate(connectionContext); @@ -101,8 +107,7 @@ namespace Microsoft.AspNetCore.SignalR // Resolve the Hub Protocol for the connection and store it in metadata // Other components, outside the Hub, may need to know what protocol is in use // for a particular connection, so we store it here. - connection.Metadata[HubConnectionMetadataNames.HubProtocol] = - _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); + connection.Protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); return; } @@ -188,7 +193,7 @@ namespace Microsoft.AspNetCore.SignalR // is used to get the exception so we can bubble it up the stack var cts = new CancellationTokenSource(); var completion = new TaskCompletionSource(); - var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); + var protocol = connection.Protocol; try { diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs index 56e233b8c7..7fa7278130 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs @@ -1,8 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Security.Claims; +using System.Collections.Generic; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http.Features; @@ -10,16 +9,12 @@ namespace Microsoft.AspNetCore.Sockets { public abstract class ConnectionContext { - public abstract string ConnectionId { get; } + public abstract string ConnectionId { get; set; } public abstract IFeatureCollection Features { get; } - public abstract ClaimsPrincipal User { get; set; } + public abstract IDictionary Metadata { get; set; } - // REVIEW: Should this be changed to items - public abstract ConnectionMetadata Metadata { get; } - - // TEMPORARY public abstract Channel Transport { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs deleted file mode 100644 index 6ca39c25d7..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Concurrent; - -namespace Microsoft.AspNetCore.Sockets -{ - public class ConnectionMetadata - { - private ConcurrentDictionary _metadata = new ConcurrentDictionary(); - - public object this[object key] - { - get - { - object value; - _metadata.TryGetValue(key, out value); - return value; - } - set - { - _metadata[key] = value; - } - } - - public T GetOrAdd(object key, Func factory) - { - return (T)_metadata.GetOrAdd(key, k => factory(k)); - } - - public T Get(object key) - { - return (T)this[key]; - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionIdFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionIdFeature.cs new file mode 100644 index 0000000000..5dccb32388 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionIdFeature.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionIdFeature + { + string ConnectionId { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionMetadataFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionMetadataFeature.cs new file mode 100644 index 0000000000..78b59cd668 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionMetadataFeature.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionMetadataFeature + { + IDictionary Metadata { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs new file mode 100644 index 0000000000..8eeb8879e0 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks.Channels; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionTransportFeature + { + Channel Transport { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs new file mode 100644 index 0000000000..7637468399 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Security.Claims; +using System.Text; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionUserFeature + { + ClaimsPrincipal User { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Http/ConnectionMetadataNames.cs b/src/Microsoft.AspNetCore.Sockets.Http/ConnectionMetadataNames.cs index 33ea27cd9c..5074b739c9 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/ConnectionMetadataNames.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/ConnectionMetadataNames.cs @@ -6,6 +6,5 @@ namespace Microsoft.AspNetCore.Sockets public static class ConnectionMetadataNames { public static readonly string Transport = nameof(Transport); - public static readonly string HttpContext = nameof(HttpContext); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Features/IHttpContextFeature.cs b/src/Microsoft.AspNetCore.Sockets.Http/Features/IHttpContextFeature.cs new file mode 100644 index 0000000000..d18ebb62d6 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Http/Features/IHttpContextFeature.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Sockets.Http.Features +{ + public interface IHttpContextFeature + { + HttpContext HttpContext { get; set; } + } + + public class HttpContextFeature : IHttpContextFeature + { + public HttpContext HttpContext { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs index 8a5b946d8c..74ce92ccd1 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Text; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Sockets.Http.Features; namespace Microsoft.AspNetCore.Sockets { @@ -12,7 +13,18 @@ namespace Microsoft.AspNetCore.Sockets { public static HttpContext GetHttpContext(this ConnectionContext connection) { - return connection.Metadata.Get(ConnectionMetadataNames.HttpContext); + return connection.Features.Get().HttpContext; + } + + public static void SetHttpContext(this ConnectionContext connection, HttpContext httpContext) + { + var feature = connection.Features.Get(); + if (feature == null) + { + feature = new HttpContextFeature(); + connection.Features.Set(feature); + } + feature.HttpContext = httpContext; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index e7db22909c..1a02aac7dd 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -241,7 +241,7 @@ namespace Microsoft.AspNetCore.Sockets connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive; - connection.Metadata[ConnectionMetadataNames.HttpContext] = null; + connection.SetHttpContext(null); // Dispose the cancellation token connection.Cancellation.Dispose(); @@ -403,7 +403,7 @@ namespace Microsoft.AspNetCore.Sockets return false; } - var transport = connection.Metadata.Get(ConnectionMetadataNames.Transport); + var transport = (TransportType?)connection.Metadata[ConnectionMetadataNames.Transport]; if (transport == null) { @@ -419,7 +419,7 @@ namespace Microsoft.AspNetCore.Sockets // Setup the connection state from the http context connection.User = context.User; - connection.Metadata[ConnectionMetadataNames.HttpContext] = context; + connection.SetHttpContext(context); return true; } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs new file mode 100644 index 0000000000..51a97e2b05 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs @@ -0,0 +1,118 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Sockets +{ + internal class ConnectionMetadata : IDictionary + { + public ConnectionMetadata() + : this(new Dictionary()) + { + } + + public ConnectionMetadata(IDictionary items) + { + Items = items; + } + + public IDictionary Items { get; } + + // Replace the indexer with one that returns null for missing values + object IDictionary.this[object key] + { + get + { + if (Items.TryGetValue(key, out var value)) + { + return value; + } + return null; + } + set { Items[key] = value; } + } + + void IDictionary.Add(object key, object value) + { + Items.Add(key, value); + } + + bool IDictionary.ContainsKey(object key) + { + return Items.ContainsKey(key); + } + + ICollection IDictionary.Keys + { + get { return Items.Keys; } + } + + bool IDictionary.Remove(object key) + { + return Items.Remove(key); + } + + bool IDictionary.TryGetValue(object key, out object value) + { + return Items.TryGetValue(key, out value); + } + + ICollection IDictionary.Values + { + get { return Items.Values; } + } + + void ICollection>.Add(KeyValuePair item) + { + Items.Add(item); + } + + void ICollection>.Clear() + { + Items.Clear(); + } + + bool ICollection>.Contains(KeyValuePair item) + { + return Items.Contains(item); + } + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + Items.CopyTo(array, arrayIndex); + } + + int ICollection>.Count + { + get { return Items.Count; } + } + + bool ICollection>.IsReadOnly + { + get { return Items.IsReadOnly; } + } + + bool ICollection>.Remove(KeyValuePair item) + { + object value; + if (Items.TryGetValue(item.Key, out value) && Equals(item.Value, value)) + { + return Items.Remove(item.Key); + } + return false; + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return Items.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return Items.GetEnumerator(); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs index a23b1224ac..4072a9e9ec 100644 --- a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs @@ -2,15 +2,21 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Sockets.Features; namespace Microsoft.AspNetCore.Sockets { - public class DefaultConnectionContext : ConnectionContext + public class DefaultConnectionContext : ConnectionContext, + IConnectionIdFeature, + IConnectionMetadataFeature, + IConnectionTransportFeature, + IConnectionUserFeature { // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task @@ -22,6 +28,13 @@ namespace Microsoft.AspNetCore.Sockets Application = application; ConnectionId = id; LastSeenUtc = DateTime.UtcNow; + + // PERF: This type could just implement IFeatureCollection + Features = new FeatureCollection(); + Features.Set(this); + Features.Set(this); + Features.Set(this); + Features.Set(this); } public CancellationTokenSource Cancellation { get; set; } @@ -36,13 +49,13 @@ namespace Microsoft.AspNetCore.Sockets public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive; - public override string ConnectionId { get; } + public override string ConnectionId { get; set; } - public override IFeatureCollection Features { get; } = new FeatureCollection(); + public override IFeatureCollection Features { get; } - public override ClaimsPrincipal User { get; set; } + public ClaimsPrincipal User { get; set; } - public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); + public override IDictionary Metadata { get; set; } = new ConnectionMetadata(); public Channel Application { get; } @@ -121,7 +134,6 @@ namespace Microsoft.AspNetCore.Sockets } } - public enum ConnectionStatus { Inactive, diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index a47c7ad13c..f88907b019 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -954,7 +954,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public override Task OnConnectedAsync() { - Context.Connection.Metadata.Get>("ConnectedTask")?.TrySetResult(true); + var tcs = (TaskCompletionSource)Context.Connection.Metadata["ConnectedTask"]; + tcs?.TrySetResult(true); return base.OnConnectedAsync(); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 774884cbcb..a69df361ae 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public DefaultConnectionContext Connection { get; } public Channel Application { get; } - public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; + public Task Connected => ((TaskCompletionSource)Connection.Metadata["ConnectedTask"]).Task; public TestClient(bool synchronousCallbacks = false) {