diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index c2cd314ed0..05afcc881d 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -16,6 +16,7 @@ namespace SocketsSample services.AddRouting(); services.AddSignalR(); + // .AddRedis(); services.AddSingleton(); services.AddSingleton(); diff --git a/samples/SocketsSample/project.json b/samples/SocketsSample/project.json index 8d4206376c..99ce0811a1 100644 --- a/samples/SocketsSample/project.json +++ b/samples/SocketsSample/project.json @@ -3,6 +3,9 @@ "Microsoft.AspNetCore.SignalR": { "target": "project" }, + "Microsoft.AspNetCore.SignalR.Redis": { + "target": "project" + }, "Microsoft.AspNetCore.Sockets": { "target": "project" }, diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisDependencyInjectionExtensions.cs new file mode 100644 index 0000000000..decaa12942 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisDependencyInjectionExtensions.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Redis; + +namespace Microsoft.Extensions.DependencyInjection +{ + public static class RedisDependencyInjectionExtensions + { + public static ISignalRBuilder AddRedis(this ISignalRBuilder builder) + { + return AddRedis(builder, o => { }); + } + + public static ISignalRBuilder AddRedis(this ISignalRBuilder builder, Action configure) + { + builder.Services.Configure(configure); + builder.Services.AddSingleton(typeof(HubLifetimeManager<>), typeof(RedisHubLifetimeManager<>)); + return builder; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index bc1ccf3216..da7733de7f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -1,13 +1,11 @@ using System; -using System.Collections.Concurrent; -using System.Collections.Generic; using System.IO; using System.Text; -using System.Threading; using System.Threading.Tasks; using Channels; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using StackExchange.Redis; namespace Microsoft.AspNetCore.SignalR.Redis @@ -15,17 +13,22 @@ namespace Microsoft.AspNetCore.SignalR.Redis public class RedisHubLifetimeManager : HubLifetimeManager, IDisposable { private readonly InvocationAdapterRegistry _registry; - private readonly ConnectionMultiplexer _redis; + private readonly ConnectionMultiplexer _redisServerConnection; private readonly ISubscriber _bus; private readonly ILoggerFactory _loggerFactory; + private readonly RedisOptions _options; - public RedisHubLifetimeManager(InvocationAdapterRegistry registry, ILoggerFactory loggerFactory) + public RedisHubLifetimeManager(InvocationAdapterRegistry registry, + ILoggerFactory loggerFactory, + IOptions options) { - var writer = new LoggerTextWriter(loggerFactory.CreateLogger>()); _loggerFactory = loggerFactory; - _redis = ConnectionMultiplexer.Connect("localhost", writer); - _bus = _redis.GetSubscriber(); _registry = registry; + _options = options.Value; + + var writer = new LoggerTextWriter(loggerFactory.CreateLogger>()); + _redisServerConnection = _options.Connect(writer); + _bus = _redisServerConnection.GetSubscriber(); } public override Task InvokeAll(string methodName, params object[] args) @@ -86,92 +89,74 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - public override Task OnConnectedAsync(Connection connection) + public override async Task OnConnectedAsync(Connection connection) { - var subs = connection.Metadata.GetOrAdd("subscriptions", k => new List()); - - subs.Add(Subscribe(typeof(THub).Name, connection)); - subs.Add(Subscribe(typeof(THub).Name + "." + connection.ConnectionId, connection)); - subs.Add(Subscribe(typeof(THub).Name + "." + connection.User.Identity.Name, connection)); - - return Task.CompletedTask; + await SubscribeAsync(typeof(THub).Name, connection); + await SubscribeAsync(typeof(THub).Name + "." + connection.ConnectionId, connection); + await SubscribeAsync(typeof(THub).Name + "." + connection.User.Identity.Name, connection); } public override Task OnDisconnectedAsync(Connection connection) { - var subs = connection.Metadata.Get>("subscriptions"); + var redisConnection = connection.Metadata.Get("redis"); - if (subs != null) + if (redisConnection == null) { - foreach (var sub in subs) - { - sub.Dispose(); - } + return Task.CompletedTask; } - connection.Metadata.Get("redis")?.Dispose(); + redisConnection.GetSubscriber().UnsubscribeAll(); + redisConnection.Close(allowCommandsToComplete: true); return Task.CompletedTask; } - public override void AddGroup(Connection connection, string groupName) + public override Task AddGroup(Connection connection, string groupName) { - var groups = connection.Metadata.GetOrAdd("groups", k => new ConcurrentDictionary()); var key = typeof(THub).Name + "." + groupName; - groups.TryAdd(key, Subscribe(key, connection)); + return SubscribeAsync(key, connection); } - public override void RemoveGroup(Connection connection, string groupName) + public override Task RemoveGroup(Connection connection, string groupName) { var key = typeof(THub) + "." + groupName; - var groups = connection.Metadata.Get>("groups"); - - IDisposable subscription; - if (groups != null && groups.TryRemove(key, out subscription)) - { - subscription.Dispose(); - } + return UnsubscribeAsync(key, connection); } - private IDisposable Subscribe(string channel, Connection connection) + private Task SubscribeAsync(string channel, Connection connection) { - var muxer = connection.Metadata.GetOrAdd("redis", k => + var redisConnection = connection.Metadata.GetOrAdd("redis", k => { var logger = _loggerFactory.CreateLogger("REDIS_" + connection.ConnectionId); - return ConnectionMultiplexer.Connect("localhost", new LoggerTextWriter(logger)); + // TODO: Async + return _options.Connect(new LoggerTextWriter(logger)); }); - var subscriber = muxer.GetSubscriber(); + var subscriber = redisConnection.GetSubscriber(); - subscriber.SubscribeAsync(channel, (c, data) => + return subscriber.SubscribeAsync(channel, (c, data) => { connection.Channel.Output.WriteAsync((byte[])data); }); + } - return new DisposableAction(() => + private Task UnsubscribeAsync(string channel, Connection connection) + { + var redisConnection = connection.Metadata.Get("redis"); + + if (redisConnection == null) { - subscriber.Unsubscribe(channel); - }); + return Task.CompletedTask; + } + + var subscriber = redisConnection.GetSubscriber(); + + return subscriber.UnsubscribeAsync(channel); } public void Dispose() { - _redis.Dispose(); - } - - private class DisposableAction : IDisposable - { - private Action _action; - - public DisposableAction(Action action) - { - _action = action; - } - - public void Dispose() - { - Interlocked.Exchange(ref _action, () => { }).Invoke(); - } + _redisServerConnection.Dispose(); } private class LoggerTextWriter : TextWriter diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisOptions.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisOptions.cs new file mode 100644 index 0000000000..5120e6fae6 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisOptions.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading.Tasks; +using StackExchange.Redis; + +namespace Microsoft.AspNetCore.SignalR.Redis +{ + public class RedisOptions + { + public ConfigurationOptions Options { get; set; } = new ConfigurationOptions(); + + public Func Factory { get; set; } + + // TODO: Async + internal ConnectionMultiplexer Connect(TextWriter log) + { + if (Factory == null) + { + // REVIEW: Should we do this? + if (Options.EndPoints.Count == 0) + { + Options.EndPoints.Add("localhost"); + } + return ConnectionMultiplexer.Connect(Options, log); + } + + return Factory(); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index a16becd557..1d5a8185e9 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.SignalR _registry = registry; } - public override void AddGroup(Connection connection, string groupName) + public override Task AddGroup(Connection connection, string groupName) { var groups = connection.Metadata.GetOrAdd("groups", k => new HashSet()); @@ -25,9 +25,11 @@ namespace Microsoft.AspNetCore.SignalR { groups.Add(groupName); } + + return Task.CompletedTask; } - public override void RemoveGroup(Connection connection, string groupName) + public override Task RemoveGroup(Connection connection, string groupName) { var groups = connection.Metadata.Get>("groups"); @@ -35,6 +37,8 @@ namespace Microsoft.AspNetCore.SignalR { groups.Remove(groupName); } + + return Task.CompletedTask; } public override Task InvokeAll(string methodName, params object[] args) diff --git a/src/Microsoft.AspNetCore.SignalR/DependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/DependencyInjectionExtensions.cs index c12234a1d1..7bc3ec6a30 100644 --- a/src/Microsoft.AspNetCore.SignalR/DependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/DependencyInjectionExtensions.cs @@ -2,21 +2,34 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; -using Microsoft.Extensions.DependencyInjection; +using Microsoft.AspNetCore.SignalR; -namespace Microsoft.AspNetCore.SignalR +namespace Microsoft.Extensions.DependencyInjection { public static class DependencyInjectionExtensions { - // TODO: We might need a builder here for things like scaleout - public static IServiceCollection AddSignalR(this IServiceCollection services) + public static ISignalRBuilder AddSignalR(this IServiceCollection services) { services.AddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>)); - // services.AddSingleton(typeof(HubLifetimeManager<>), typeof(RedisHubLifetimeManager<>)); services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>)); services.AddSingleton(typeof(RpcEndpoint<>), typeof(RpcEndpoint<>)); - return services; + return new SignalRBuilder(services); } } + + public interface ISignalRBuilder + { + IServiceCollection Services { get; } + } + + public class SignalRBuilder : ISignalRBuilder + { + public SignalRBuilder(IServiceCollection services) + { + Services = services; + } + + public IServiceCollection Services { get; } + } } diff --git a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs index 87255ef4ca..2ee18cf6d5 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs @@ -17,9 +17,9 @@ namespace Microsoft.AspNetCore.SignalR public abstract Task InvokeUser(string userId, string methodName, params object[] args); - public abstract void AddGroup(Connection connection, string groupName); + public abstract Task AddGroup(Connection connection, string groupName); - public abstract void RemoveGroup(Connection connection, string groupName); + public abstract Task RemoveGroup(Connection connection, string groupName); } } diff --git a/src/Microsoft.AspNetCore.SignalR/IGroupManager.cs b/src/Microsoft.AspNetCore.SignalR/IGroupManager.cs index 98e10eb5dc..b3f783f06e 100644 --- a/src/Microsoft.AspNetCore.SignalR/IGroupManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/IGroupManager.cs @@ -1,8 +1,10 @@ -namespace Microsoft.AspNetCore.SignalR +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR { public interface IGroupManager { - void Add(string groupName); - void Remove(string groupName); + Task Add(string groupName); + Task Remove(string groupName); } } diff --git a/src/Microsoft.AspNetCore.SignalR/Proxies.cs b/src/Microsoft.AspNetCore.SignalR/Proxies.cs index b945444957..c8c9c3180f 100644 --- a/src/Microsoft.AspNetCore.SignalR/Proxies.cs +++ b/src/Microsoft.AspNetCore.SignalR/Proxies.cs @@ -83,14 +83,14 @@ namespace Microsoft.AspNetCore.SignalR _lifetimeManager = lifetimeManager; } - public void Add(string groupName) + public Task Add(string groupName) { - _lifetimeManager.AddGroup(_connection, groupName); + return _lifetimeManager.AddGroup(_connection, groupName); } - public void Remove(string groupName) + public Task Remove(string groupName) { - _lifetimeManager.RemoveGroup(_connection, groupName); + return _lifetimeManager.RemoveGroup(_connection, groupName); } } }