Features everywhere (#639)

* Features everywhere
- The goal here is to move things closer to the final design where
ConnectionContext represents a very low level primitive that represents
any connection like transport. As part of that change, we remove unnecessary
properties like User and move those into features. They temporarily live in the same
assembly but they are not required by ConnectionContext.
- Used features for Hubs instead of Metadata
- Metadata is no longer thread safe
This commit is contained in:
David Fowler 2017-07-06 11:27:16 -07:00 committed by GitHub
parent 3c0fb48fab
commit 595f783857
22 changed files with 335 additions and 95 deletions

View File

@ -21,6 +21,7 @@ namespace SocialWeather
public void OnConnectedAsync(ConnectionContext connection)
{
connection.Metadata["groups"] = new HashSet<string>();
connection.Metadata["format"] = "json";
_connectionList.Add(connection);
}
@ -35,7 +36,7 @@ namespace SocialWeather
foreach (var connection in _connectionList)
{
var context = connection.GetHttpContext();
var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>("format"));
var formatter = _formatterResolver.GetFormatter<T>((string)connection.Metadata["format"]);
var ms = new MemoryStream();
await formatter.WriteAsync(data, ms);
@ -60,7 +61,7 @@ namespace SocialWeather
public void AddGroupAsync(ConnectionContext connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
var groups = (HashSet<string>)connection.Metadata["groups"];
lock (groups)
{
groups.Add(groupName);
@ -69,7 +70,7 @@ namespace SocialWeather
public void RemoveGroupAsync(ConnectionContext connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
var groups = (HashSet<string>)connection.Metadata["groups"];
if (groups != null)
{
lock (groups)

View File

@ -32,7 +32,7 @@ namespace SocialWeather
public async Task ProcessRequests(ConnectionContext connection)
{
var formatter = _formatterResolver.GetFormatter<WeatherReport>(
connection.Metadata.Get<string>("formatType"));
(string)connection.Metadata["formatType"]);
while (await connection.Transport.In.WaitToReadAsync())
{

View File

@ -19,8 +19,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable
{
private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
private readonly HubConnectionList _connections = new HubConnectionList();
// TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
@ -129,7 +127,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public override Task OnConnectedAsync(HubConnectionContext connection)
{
var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());
var feature = new RedisFeature();
connection.Features.Set<IRedisFeature>(feature);
var redisSubscriptions = feature.Subscriptions;
var connectionTask = Task.CompletedTask;
var userTask = Task.CompletedTask;
@ -178,7 +179,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var tasks = new List<Task>();
var redisSubscriptions = connection.Metadata.Get<HashSet<string>>(RedisSubscriptionsMetadataName);
var feature = connection.Features.Get<IRedisFeature>();
var redisSubscriptions = feature.Subscriptions;
if (redisSubscriptions != null)
{
foreach (var subscription in redisSubscriptions)
@ -188,7 +191,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
var groupNames = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
var groupNames = feature.Groups;
if (groupNames != null)
{
@ -211,7 +214,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
return;
}
var groupNames = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
var feature = connection.Features.Get<IRedisFeature>();
var groupNames = feature.Groups;
lock (groupNames)
{
@ -274,7 +279,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return;
}
var groupNames = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
var feature = connection.Features.Get<IRedisFeature>();
var groupNames = feature.Groups;
if (groupNames != null)
{
lock (groupNames)
@ -363,5 +369,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public HubConnectionList Connections = new HubConnectionList();
}
private interface IRedisFeature
{
HashSet<string> Subscriptions { get; }
HashSet<string> Groups { get; }
}
private class RedisFeature : IRedisFeature
{
public HashSet<string> Subscriptions { get; } = new HashSet<string>();
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
}
}
}

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR
@ -22,7 +23,9 @@ namespace Microsoft.AspNetCore.SignalR
return Task.CompletedTask;
}
var groups = connection.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
var feature = connection.Features.Get<IHubGroupsFeature>();
var groups = feature.Groups;
lock (groups)
{
groups.Add(groupName);
@ -39,12 +42,8 @@ namespace Microsoft.AspNetCore.SignalR
return Task.CompletedTask;
}
var groups = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
if (groups == null)
{
return Task.CompletedTask;
}
var feature = connection.Features.Get<IHubGroupsFeature>();
var groups = feature.Groups;
lock (groups)
{
@ -91,8 +90,14 @@ namespace Microsoft.AspNetCore.SignalR
{
return InvokeAllWhere(methodName, args, connection =>
{
var groups = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
return groups?.Contains(groupName) == true;
var feature = connection.Features.Get<IHubGroupsFeature>();
var groups = feature.Groups;
// PERF: ...
lock (groups)
{
return groups.Contains(groupName) == true;
}
});
}
@ -106,6 +111,9 @@ namespace Microsoft.AspNetCore.SignalR
public override Task OnConnectedAsync(HubConnectionContext connection)
{
// Set the hub groups feature
connection.Features.Set<IHubGroupsFeature>(new HubGroupsFeature());
_connections.Add(connection);
return Task.CompletedTask;
}
@ -134,5 +142,15 @@ namespace Microsoft.AspNetCore.SignalR
var invocationId = Interlocked.Increment(ref _nextInvocationId);
return invocationId.ToString();
}
private interface IHubGroupsFeature
{
HashSet<string> Groups { get; }
}
private class HubGroupsFeature : IHubGroupsFeature
{
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
}
}
}

View File

@ -0,0 +1,20 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR.Features
{
public interface IHubFeature
{
IHubProtocol Protocol { get; set; }
}
public class HubFeature : IHubFeature
{
public IHubProtocol Protocol { get; set; }
}
}

View File

@ -2,10 +2,14 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Features;
namespace Microsoft.AspNetCore.SignalR
{
@ -20,16 +24,24 @@ namespace Microsoft.AspNetCore.SignalR
_connectionContext = connectionContext;
}
private IHubFeature HubFeature => Features.Get<IHubFeature>();
// Used by the HubEndPoint only
internal ReadableChannel<byte[]> Input => _connectionContext.Transport;
public virtual string ConnectionId => _connectionContext.ConnectionId;
public virtual ClaimsPrincipal User => _connectionContext.User;
public virtual ClaimsPrincipal User => Features.Get<IConnectionUserFeature>()?.User;
public virtual ConnectionMetadata Metadata => _connectionContext.Metadata;
public virtual IFeatureCollection Features => _connectionContext.Features;
public virtual IHubProtocol Protocol => _connectionContext.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
public virtual IDictionary<object, object> Metadata => _connectionContext.Metadata;
public virtual IHubProtocol Protocol
{
get => HubFeature.Protocol;
set => HubFeature.Protocol = value;
}
public virtual WritableChannel<byte[]> Output => _output;
}

View File

@ -1,11 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR
{
public static class HubConnectionMetadataNames
{
public static readonly string HubProtocol = nameof(HubProtocol);
public static readonly string Groups = nameof(Groups);
}
}

View File

@ -11,6 +11,7 @@ using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
@ -48,6 +49,11 @@ namespace Microsoft.AspNetCore.SignalR
public async Task OnConnectedAsync(ConnectionContext connection)
{
var output = Channel.CreateUnbounded<byte[]>();
// Set the hub feature before doing anything else. This stores
// all the relevant state for a SignalR Hub connection
connection.Features.Set<IHubFeature>(new HubFeature());
var connectionContext = new HubConnectionContext(output, connection);
await ProcessNegotiate(connectionContext);
@ -101,8 +107,7 @@ namespace Microsoft.AspNetCore.SignalR
// Resolve the Hub Protocol for the connection and store it in metadata
// Other components, outside the Hub, may need to know what protocol is in use
// for a particular connection, so we store it here.
connection.Metadata[HubConnectionMetadataNames.HubProtocol] =
_protocolResolver.GetProtocol(negotiationMessage.Protocol, connection);
connection.Protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection);
return;
}
@ -188,7 +193,7 @@ namespace Microsoft.AspNetCore.SignalR
// is used to get the exception so we can bubble it up the stack
var cts = new CancellationTokenSource();
var completion = new TaskCompletionSource<object>();
var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);
var protocol = connection.Protocol;
try
{

View File

@ -1,8 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Security.Claims;
using System.Collections.Generic;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http.Features;
@ -10,16 +9,12 @@ namespace Microsoft.AspNetCore.Sockets
{
public abstract class ConnectionContext
{
public abstract string ConnectionId { get; }
public abstract string ConnectionId { get; set; }
public abstract IFeatureCollection Features { get; }
public abstract ClaimsPrincipal User { get; set; }
public abstract IDictionary<object, object> Metadata { get; set; }
// REVIEW: Should this be changed to items
public abstract ConnectionMetadata Metadata { get; }
// TEMPORARY
public abstract Channel<byte[]> Transport { get; set; }
}
}

View File

@ -1,37 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionMetadata
{
private ConcurrentDictionary<object, object> _metadata = new ConcurrentDictionary<object, object>();
public object this[object key]
{
get
{
object value;
_metadata.TryGetValue(key, out value);
return value;
}
set
{
_metadata[key] = value;
}
}
public T GetOrAdd<T>(object key, Func<object, T> factory)
{
return (T)_metadata.GetOrAdd(key, k => factory(k));
}
public T Get<T>(object key)
{
return (T)this[key];
}
}
}

View File

@ -0,0 +1,14 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.AspNetCore.Sockets.Features
{
public interface IConnectionIdFeature
{
string ConnectionId { get; set; }
}
}

View File

@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets.Features
{
public interface IConnectionMetadataFeature
{
IDictionary<object, object> Metadata { get; set; }
}
}

View File

@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks.Channels;
namespace Microsoft.AspNetCore.Sockets.Features
{
public interface IConnectionTransportFeature
{
Channel<byte[]> Transport { get; set; }
}
}

View File

@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Text;
namespace Microsoft.AspNetCore.Sockets.Features
{
public interface IConnectionUserFeature
{
ClaimsPrincipal User { get; set; }
}
}

View File

@ -6,6 +6,5 @@ namespace Microsoft.AspNetCore.Sockets
public static class ConnectionMetadataNames
{
public static readonly string Transport = nameof(Transport);
public static readonly string HttpContext = nameof(HttpContext);
}
}

View File

@ -0,0 +1,20 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Sockets.Http.Features
{
public interface IHttpContextFeature
{
HttpContext HttpContext { get; set; }
}
public class HttpContextFeature : IHttpContextFeature
{
public HttpContext HttpContext { get; set; }
}
}

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Http.Features;
namespace Microsoft.AspNetCore.Sockets
{
@ -12,7 +13,18 @@ namespace Microsoft.AspNetCore.Sockets
{
public static HttpContext GetHttpContext(this ConnectionContext connection)
{
return connection.Metadata.Get<HttpContext>(ConnectionMetadataNames.HttpContext);
return connection.Features.Get<IHttpContextFeature>().HttpContext;
}
public static void SetHttpContext(this ConnectionContext connection, HttpContext httpContext)
{
var feature = connection.Features.Get<IHttpContextFeature>();
if (feature == null)
{
feature = new HttpContextFeature();
connection.Features.Set(feature);
}
feature.HttpContext = httpContext;
}
}
}

View File

@ -241,7 +241,7 @@ namespace Microsoft.AspNetCore.Sockets
connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive;
connection.Metadata[ConnectionMetadataNames.HttpContext] = null;
connection.SetHttpContext(null);
// Dispose the cancellation token
connection.Cancellation.Dispose();
@ -403,7 +403,7 @@ namespace Microsoft.AspNetCore.Sockets
return false;
}
var transport = connection.Metadata.Get<TransportType?>(ConnectionMetadataNames.Transport);
var transport = (TransportType?)connection.Metadata[ConnectionMetadataNames.Transport];
if (transport == null)
{
@ -419,7 +419,7 @@ namespace Microsoft.AspNetCore.Sockets
// Setup the connection state from the http context
connection.User = context.User;
connection.Metadata[ConnectionMetadataNames.HttpContext] = context;
connection.SetHttpContext(context);
return true;
}

View File

@ -0,0 +1,118 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets
{
internal class ConnectionMetadata : IDictionary<object, object>
{
public ConnectionMetadata()
: this(new Dictionary<object, object>())
{
}
public ConnectionMetadata(IDictionary<object, object> items)
{
Items = items;
}
public IDictionary<object, object> Items { get; }
// Replace the indexer with one that returns null for missing values
object IDictionary<object, object>.this[object key]
{
get
{
if (Items.TryGetValue(key, out var value))
{
return value;
}
return null;
}
set { Items[key] = value; }
}
void IDictionary<object, object>.Add(object key, object value)
{
Items.Add(key, value);
}
bool IDictionary<object, object>.ContainsKey(object key)
{
return Items.ContainsKey(key);
}
ICollection<object> IDictionary<object, object>.Keys
{
get { return Items.Keys; }
}
bool IDictionary<object, object>.Remove(object key)
{
return Items.Remove(key);
}
bool IDictionary<object, object>.TryGetValue(object key, out object value)
{
return Items.TryGetValue(key, out value);
}
ICollection<object> IDictionary<object, object>.Values
{
get { return Items.Values; }
}
void ICollection<KeyValuePair<object, object>>.Add(KeyValuePair<object, object> item)
{
Items.Add(item);
}
void ICollection<KeyValuePair<object, object>>.Clear()
{
Items.Clear();
}
bool ICollection<KeyValuePair<object, object>>.Contains(KeyValuePair<object, object> item)
{
return Items.Contains(item);
}
void ICollection<KeyValuePair<object, object>>.CopyTo(KeyValuePair<object, object>[] array, int arrayIndex)
{
Items.CopyTo(array, arrayIndex);
}
int ICollection<KeyValuePair<object, object>>.Count
{
get { return Items.Count; }
}
bool ICollection<KeyValuePair<object, object>>.IsReadOnly
{
get { return Items.IsReadOnly; }
}
bool ICollection<KeyValuePair<object, object>>.Remove(KeyValuePair<object, object> item)
{
object value;
if (Items.TryGetValue(item.Key, out value) && Equals(item.Value, value))
{
return Items.Remove(item.Key);
}
return false;
}
IEnumerator<KeyValuePair<object, object>> IEnumerable<KeyValuePair<object, object>>.GetEnumerator()
{
return Items.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return Items.GetEnumerator();
}
}
}

View File

@ -2,15 +2,21 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Sockets.Features;
namespace Microsoft.AspNetCore.Sockets
{
public class DefaultConnectionContext : ConnectionContext
public class DefaultConnectionContext : ConnectionContext,
IConnectionIdFeature,
IConnectionMetadataFeature,
IConnectionTransportFeature,
IConnectionUserFeature
{
// This tcs exists so that multiple calls to DisposeAsync all wait asynchronously
// on the same task
@ -22,6 +28,13 @@ namespace Microsoft.AspNetCore.Sockets
Application = application;
ConnectionId = id;
LastSeenUtc = DateTime.UtcNow;
// PERF: This type could just implement IFeatureCollection
Features = new FeatureCollection();
Features.Set<IConnectionUserFeature>(this);
Features.Set<IConnectionMetadataFeature>(this);
Features.Set<IConnectionIdFeature>(this);
Features.Set<IConnectionTransportFeature>(this);
}
public CancellationTokenSource Cancellation { get; set; }
@ -36,13 +49,13 @@ namespace Microsoft.AspNetCore.Sockets
public ConnectionStatus Status { get; set; } = ConnectionStatus.Inactive;
public override string ConnectionId { get; }
public override string ConnectionId { get; set; }
public override IFeatureCollection Features { get; } = new FeatureCollection();
public override IFeatureCollection Features { get; }
public override ClaimsPrincipal User { get; set; }
public ClaimsPrincipal User { get; set; }
public override ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
public override IDictionary<object, object> Metadata { get; set; } = new ConnectionMetadata();
public Channel<byte[]> Application { get; }
@ -121,7 +134,6 @@ namespace Microsoft.AspNetCore.Sockets
}
}
public enum ConnectionStatus
{
Inactive,

View File

@ -954,7 +954,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
public override Task OnConnectedAsync()
{
Context.Connection.Metadata.Get<TaskCompletionSource<bool>>("ConnectedTask")?.TrySetResult(true);
var tcs = (TaskCompletionSource<bool>)Context.Connection.Metadata["ConnectedTask"];
tcs?.TrySetResult(true);
return base.OnConnectedAsync();
}
}

View File

@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public DefaultConnectionContext Connection { get; }
public Channel<byte[]> Application { get; }
public Task Connected => Connection.Metadata.Get<TaskCompletionSource<bool>>("ConnectedTask").Task;
public Task Connected => ((TaskCompletionSource<bool>)Connection.Metadata["ConnectedTask"]).Task;
public TestClient(bool synchronousCallbacks = false)
{