diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs index 75bed615b3..87a6fbb9c3 100644 --- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs +++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs @@ -21,6 +21,7 @@ namespace SocialWeather public void OnConnectedAsync(ConnectionContext connection) { + connection.Metadata["groups"] = new HashSet(); connection.Metadata["format"] = "json"; _connectionList.Add(connection); } @@ -35,7 +36,7 @@ namespace SocialWeather foreach (var connection in _connectionList) { var context = connection.GetHttpContext(); - var formatter = _formatterResolver.GetFormatter(connection.Metadata.Get("format")); + var formatter = _formatterResolver.GetFormatter((string)connection.Metadata["format"]); var ms = new MemoryStream(); await formatter.WriteAsync(data, ms); @@ -60,7 +61,7 @@ namespace SocialWeather public void AddGroupAsync(ConnectionContext connection, string groupName) { - var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet()); + var groups = (HashSet)connection.Metadata["groups"]; lock (groups) { groups.Add(groupName); @@ -69,7 +70,7 @@ namespace SocialWeather public void RemoveGroupAsync(ConnectionContext connection, string groupName) { - var groups = connection.Metadata.Get>("groups"); + var groups = (HashSet)connection.Metadata["groups"]; if (groups != null) { lock (groups) diff --git a/samples/SocialWeather/SocialWeatherEndPoint.cs b/samples/SocialWeather/SocialWeatherEndPoint.cs index 9aa3fb8e58..e412cfafeb 100644 --- a/samples/SocialWeather/SocialWeatherEndPoint.cs +++ b/samples/SocialWeather/SocialWeatherEndPoint.cs @@ -32,7 +32,7 @@ namespace SocialWeather public async Task ProcessRequests(ConnectionContext connection) { var formatter = _formatterResolver.GetFormatter( - connection.Metadata.Get("formatType")); + (string)connection.Metadata["formatType"]); while (await connection.Transport.In.WaitToReadAsync()) { diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 085648ba76..25bdd602de 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -19,8 +19,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis { public class RedisHubLifetimeManager : HubLifetimeManager, IDisposable { - private const string RedisSubscriptionsMetadataName = "redis_subscriptions"; - private readonly HubConnectionList _connections = new HubConnectionList(); // TODO: Investigate "memory leak" entries never get removed private readonly ConcurrentDictionary _groups = new ConcurrentDictionary(); @@ -129,7 +127,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override Task OnConnectedAsync(HubConnectionContext connection) { - var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet()); + var feature = new RedisFeature(); + connection.Features.Set(feature); + + var redisSubscriptions = feature.Subscriptions; var connectionTask = Task.CompletedTask; var userTask = Task.CompletedTask; @@ -178,7 +179,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis var tasks = new List(); - var redisSubscriptions = connection.Metadata.Get>(RedisSubscriptionsMetadataName); + var feature = connection.Features.Get(); + + var redisSubscriptions = feature.Subscriptions; if (redisSubscriptions != null) { foreach (var subscription in redisSubscriptions) @@ -188,7 +191,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - var groupNames = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); + var groupNames = feature.Groups; if (groupNames != null) { @@ -211,7 +214,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis { return; } - var groupNames = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet()); + + var feature = connection.Features.Get(); + var groupNames = feature.Groups; lock (groupNames) { @@ -274,7 +279,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis return; } - var groupNames = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); + var feature = connection.Features.Get(); + var groupNames = feature.Groups; if (groupNames != null) { lock (groupNames) @@ -363,5 +369,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis public SemaphoreSlim Lock = new SemaphoreSlim(1, 1); public HubConnectionList Connections = new HubConnectionList(); } + + private interface IRedisFeature + { + HashSet Subscriptions { get; } + HashSet Groups { get; } + } + + private class RedisFeature : IRedisFeature + { + public HashSet Subscriptions { get; } = new HashSet(); + public HashSet Groups { get; } = new HashSet(StringComparer.OrdinalIgnoreCase); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index 12ec7f3248..eca117dc30 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal.Protocol; namespace Microsoft.AspNetCore.SignalR @@ -22,7 +23,9 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet()); + var feature = connection.Features.Get(); + var groups = feature.Groups; + lock (groups) { groups.Add(groupName); @@ -39,12 +42,8 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); - - if (groups == null) - { - return Task.CompletedTask; - } + var feature = connection.Features.Get(); + var groups = feature.Groups; lock (groups) { @@ -91,8 +90,14 @@ namespace Microsoft.AspNetCore.SignalR { return InvokeAllWhere(methodName, args, connection => { - var groups = connection.Metadata.Get>(HubConnectionMetadataNames.Groups); - return groups?.Contains(groupName) == true; + var feature = connection.Features.Get(); + var groups = feature.Groups; + + // PERF: ... + lock (groups) + { + return groups.Contains(groupName) == true; + } }); } @@ -106,6 +111,9 @@ namespace Microsoft.AspNetCore.SignalR public override Task OnConnectedAsync(HubConnectionContext connection) { + // Set the hub groups feature + connection.Features.Set(new HubGroupsFeature()); + _connections.Add(connection); return Task.CompletedTask; } @@ -134,5 +142,15 @@ namespace Microsoft.AspNetCore.SignalR var invocationId = Interlocked.Increment(ref _nextInvocationId); return invocationId.ToString(); } + + private interface IHubGroupsFeature + { + HashSet Groups { get; } + } + + private class HubGroupsFeature : IHubGroupsFeature + { + public HashSet Groups { get; } = new HashSet(StringComparer.OrdinalIgnoreCase); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs b/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs new file mode 100644 index 0000000000..0269c4458a --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Features +{ + public interface IHubFeature + { + IHubProtocol Protocol { get; set; } + } + + public class HubFeature : IHubFeature + { + public IHubProtocol Protocol { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs index 6bd4c5e101..9159d25c49 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs @@ -2,10 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Security.Claims; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Features; namespace Microsoft.AspNetCore.SignalR { @@ -20,16 +24,24 @@ namespace Microsoft.AspNetCore.SignalR _connectionContext = connectionContext; } + private IHubFeature HubFeature => Features.Get(); + // Used by the HubEndPoint only internal ReadableChannel Input => _connectionContext.Transport; public virtual string ConnectionId => _connectionContext.ConnectionId; - public virtual ClaimsPrincipal User => _connectionContext.User; + public virtual ClaimsPrincipal User => Features.Get()?.User; - public virtual ConnectionMetadata Metadata => _connectionContext.Metadata; + public virtual IFeatureCollection Features => _connectionContext.Features; - public virtual IHubProtocol Protocol => _connectionContext.Metadata.Get(HubConnectionMetadataNames.HubProtocol); + public virtual IDictionary Metadata => _connectionContext.Metadata; + + public virtual IHubProtocol Protocol + { + get => HubFeature.Protocol; + set => HubFeature.Protocol = value; + } public virtual WritableChannel Output => _output; } diff --git a/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs deleted file mode 100644 index 702a762c01..0000000000 --- a/src/Microsoft.AspNetCore.SignalR/HubConnectionMetadataNames.cs +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -namespace Microsoft.AspNetCore.SignalR -{ - public static class HubConnectionMetadataNames - { - public static readonly string HubProtocol = nameof(HubProtocol); - public static readonly string Groups = nameof(Groups); - } -} diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 4ade79b872..2cb809a3f7 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -11,6 +11,7 @@ using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; @@ -48,6 +49,11 @@ namespace Microsoft.AspNetCore.SignalR public async Task OnConnectedAsync(ConnectionContext connection) { var output = Channel.CreateUnbounded(); + + // Set the hub feature before doing anything else. This stores + // all the relevant state for a SignalR Hub connection + connection.Features.Set(new HubFeature()); + var connectionContext = new HubConnectionContext(output, connection); await ProcessNegotiate(connectionContext); @@ -101,8 +107,7 @@ namespace Microsoft.AspNetCore.SignalR // Resolve the Hub Protocol for the connection and store it in metadata // Other components, outside the Hub, may need to know what protocol is in use // for a particular connection, so we store it here. - connection.Metadata[HubConnectionMetadataNames.HubProtocol] = - _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); + connection.Protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); return; } @@ -188,7 +193,7 @@ namespace Microsoft.AspNetCore.SignalR // is used to get the exception so we can bubble it up the stack var cts = new CancellationTokenSource(); var completion = new TaskCompletionSource(); - var protocol = connection.Metadata.Get(HubConnectionMetadataNames.HubProtocol); + var protocol = connection.Protocol; try { diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs index 56e233b8c7..7fa7278130 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionContext.cs @@ -1,8 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Security.Claims; +using System.Collections.Generic; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http.Features; @@ -10,16 +9,12 @@ namespace Microsoft.AspNetCore.Sockets { public abstract class ConnectionContext { - public abstract string ConnectionId { get; } + public abstract string ConnectionId { get; set; } public abstract IFeatureCollection Features { get; } - public abstract ClaimsPrincipal User { get; set; } + public abstract IDictionary Metadata { get; set; } - // REVIEW: Should this be changed to items - public abstract ConnectionMetadata Metadata { get; } - - // TEMPORARY public abstract Channel Transport { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs deleted file mode 100644 index 6ca39c25d7..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/ConnectionMetadata.cs +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Concurrent; - -namespace Microsoft.AspNetCore.Sockets -{ - public class ConnectionMetadata - { - private ConcurrentDictionary _metadata = new ConcurrentDictionary(); - - public object this[object key] - { - get - { - object value; - _metadata.TryGetValue(key, out value); - return value; - } - set - { - _metadata[key] = value; - } - } - - public T GetOrAdd(object key, Func factory) - { - return (T)_metadata.GetOrAdd(key, k => factory(k)); - } - - public T Get(object key) - { - return (T)this[key]; - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionIdFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionIdFeature.cs new file mode 100644 index 0000000000..5dccb32388 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionIdFeature.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionIdFeature + { + string ConnectionId { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionMetadataFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionMetadataFeature.cs new file mode 100644 index 0000000000..78b59cd668 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionMetadataFeature.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionMetadataFeature + { + IDictionary Metadata { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs new file mode 100644 index 0000000000..8eeb8879e0 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks.Channels; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionTransportFeature + { + Channel Transport { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs new file mode 100644 index 0000000000..7637468399 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionUserFeature.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Security.Claims; +using System.Text; + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface IConnectionUserFeature + { + ClaimsPrincipal User { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Http/ConnectionMetadataNames.cs b/src/Microsoft.AspNetCore.Sockets.Http/ConnectionMetadataNames.cs index 33ea27cd9c..5074b739c9 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/ConnectionMetadataNames.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/ConnectionMetadataNames.cs @@ -6,6 +6,5 @@ namespace Microsoft.AspNetCore.Sockets public static class ConnectionMetadataNames { public static readonly string Transport = nameof(Transport); - public static readonly string HttpContext = nameof(HttpContext); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Features/IHttpContextFeature.cs b/src/Microsoft.AspNetCore.Sockets.Http/Features/IHttpContextFeature.cs new file mode 100644 index 0000000000..d18ebb62d6 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Http/Features/IHttpContextFeature.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Sockets.Http.Features +{ + public interface IHttpContextFeature + { + HttpContext HttpContext { get; set; } + } + + public class HttpContextFeature : IHttpContextFeature + { + public HttpContext HttpContext { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs index 8a5b946d8c..74ce92ccd1 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Text; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Sockets.Http.Features; namespace Microsoft.AspNetCore.Sockets { @@ -12,7 +13,18 @@ namespace Microsoft.AspNetCore.Sockets { public static HttpContext GetHttpContext(this ConnectionContext connection) { - return connection.Metadata.Get(ConnectionMetadataNames.HttpContext); + return connection.Features.Get().HttpContext; + } + + public static void SetHttpContext(this ConnectionContext connection, HttpContext httpContext) + { + var feature = connection.Features.Get(); + if (feature == null) + { + feature = new HttpContextFeature(); + connection.Features.Set(feature); + } + feature.HttpContext = httpContext; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index e7db22909c..1a02aac7dd 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -241,7 +241,7 @@ namespace Microsoft.AspNetCore.Sockets connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive; - connection.Metadata[ConnectionMetadataNames.HttpContext] = null; + connection.SetHttpContext(null); // Dispose the cancellation token connection.Cancellation.Dispose(); @@ -403,7 +403,7 @@ namespace Microsoft.AspNetCore.Sockets return false; } - var transport = connection.Metadata.Get(ConnectionMetadataNames.Transport); + var transport = (TransportType?)connection.Metadata[ConnectionMetadataNames.Transport]; if (transport == null) { @@ -419,7 +419,7 @@ namespace Microsoft.AspNetCore.Sockets // Setup the connection state from the http context connection.User = context.User; - connection.Metadata[ConnectionMetadataNames.HttpContext] = context; + connection.SetHttpContext(context); return true; } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs new file mode 100644 index 0000000000..51a97e2b05 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs @@ -0,0 +1,118 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Sockets +{ + internal class ConnectionMetadata : IDictionary + { + public ConnectionMetadata() + : this(new Dictionary()) + { + } + + public ConnectionMetadata(IDictionary items) + { + Items = items; + } + + public IDictionary Items { get; } + + // Replace the indexer with one that returns null for missing values + object IDictionary.this[object key] + { + get + { + if (Items.TryGetValue(key, out var value)) + { + return value; + } + return null; + } + set { Items[key] = value; } + } + + void IDictionary.Add(object key, object value) + { + Items.Add(key, value); + } + + bool IDictionary.ContainsKey(object key) + { + return Items.ContainsKey(key); + } + + ICollection IDictionary.Keys + { + get { return Items.Keys; } + } + + bool IDictionary.Remove(object key) + { + return Items.Remove(key); + } + + bool IDictionary.TryGetValue(object key, out object value) + { + return Items.TryGetValue(key, out value); + } + + ICollection IDictionary.Values + { + get { return Items.Values; } + } + + void ICollection>.Add(KeyValuePair item) + { + Items.Add(item); + } + + void ICollection>.Clear() + { + Items.Clear(); + } + + bool ICollection>.Contains(KeyValuePair item) + { + return Items.Contains(item); + } + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + Items.CopyTo(array, arrayIndex); + } + + int ICollection>.Count + { + get { return Items.Count; } + } + + bool ICollection>.IsReadOnly + { + get { return Items.IsReadOnly; } + } + + bool ICollection>.Remove(KeyValuePair item) + { + object value; + if (Items.TryGetValue(item.Key, out value) && Equals(item.Value, value)) + { + return Items.Remove(item.Key); + } + return false; + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return Items.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return Items.GetEnumerator(); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs index a23b1224ac..4072a9e9ec 100644 --- a/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets/DefaultConnectionContext.cs @@ -2,15 +2,21 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Sockets.Features; namespace Microsoft.AspNetCore.Sockets { - public class DefaultConnectionContext : ConnectionContext + public class DefaultConnectionContext : ConnectionContext, + IConnectionIdFeature, + IConnectionMetadataFeature, + IConnectionTransportFeature, + IConnectionUserFeature { // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task @@ -22,6 +28,13 @@ namespace Microsoft.AspNetCore.Sockets Application = application; ConnectionId = id; LastSeenUtc = DateTime.UtcNow; + + // PERF: This type could just implement IFeatureCollection + Features = new FeatureCollection(); + Features.Set(this); + Features.Set(this); + Features.Set(this); + Features.Set(this); } public CancellationTokenSource Cancellation { get; set; } @@ -36,13 +49,13 @@ namespace Microsoft.AspNetCore.Sockets public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive; - public override string ConnectionId { get; } + public override string ConnectionId { get; set; } - public override IFeatureCollection Features { get; } = new FeatureCollection(); + public override IFeatureCollection Features { get; } - public override ClaimsPrincipal User { get; set; } + public ClaimsPrincipal User { get; set; } - public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); + public override IDictionary Metadata { get; set; } = new ConnectionMetadata(); public Channel Application { get; } @@ -121,7 +134,6 @@ namespace Microsoft.AspNetCore.Sockets } } - public enum ConnectionStatus { Inactive, diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index a47c7ad13c..f88907b019 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -954,7 +954,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public override Task OnConnectedAsync() { - Context.Connection.Metadata.Get>("ConnectedTask")?.TrySetResult(true); + var tcs = (TaskCompletionSource)Context.Connection.Metadata["ConnectedTask"]; + tcs?.TrySetResult(true); return base.OnConnectedAsync(); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 774884cbcb..a69df361ae 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public DefaultConnectionContext Connection { get; } public Channel Application { get; } - public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; + public Task Connected => ((TaskCompletionSource)Connection.Metadata["ConnectedTask"]).Task; public TestClient(bool synchronousCallbacks = false) {