diff --git a/src/Microsoft.AspNetCore.Sockets/Connection.cs b/src/Microsoft.AspNetCore.Sockets/Connection.cs index 10746ae64d..77a16f4220 100644 --- a/src/Microsoft.AspNetCore.Sockets/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets/Connection.cs @@ -12,6 +12,6 @@ namespace Microsoft.AspNetCore.Sockets public string ConnectionId { get; set; } public ClaimsPrincipal User { get; set; } public IChannel Channel { get; set; } - public IDictionary Metadata { get; } = new Dictionary(); + public ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); } } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs new file mode 100644 index 0000000000..18b4e8eccd --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs @@ -0,0 +1,24 @@ + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Sockets +{ + public class ConnectionMetadata + { + private IDictionary _metadata = new Dictionary(); + + public Format Format { get; set; } = Format.Text; + + public object this[string key] + { + get + { + return _metadata[key]; + } + set + { + _metadata[key] = value; + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/Format.cs b/src/Microsoft.AspNetCore.Sockets/Format.cs new file mode 100644 index 0000000000..3e86022ec5 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/Format.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets +{ + public enum Format + { + Text, + Binary + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 08febfc486..e60f45ff5f 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -46,6 +46,10 @@ namespace Microsoft.AspNetCore.Sockets // Get the end point mapped to this http connection var endpoint = (EndPoint)context.RequestServices.GetRequiredService(); + var format = + string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase) + ? Format.Binary + : Format.Text; // Server sent events transport if (context.Request.Path.StartsWithSegments(path + "/sse")) @@ -54,6 +58,7 @@ namespace Microsoft.AspNetCore.Sockets var connectionState = GetOrCreateConnection(context); connectionState.Connection.User = context.User; connectionState.Connection.Metadata["transport"] = "sse"; + connectionState.Connection.Metadata.Format = format; var sse = new ServerSentEvents(connectionState.Connection); // Register this transport for disconnect @@ -82,6 +87,7 @@ namespace Microsoft.AspNetCore.Sockets var connectionState = GetOrCreateConnection(context); connectionState.Connection.User = context.User; connectionState.Connection.Metadata["transport"] = "websockets"; + connectionState.Connection.Metadata.Format = format; var ws = new WebSockets(connectionState.Connection); // Register this transport for disconnect @@ -136,6 +142,7 @@ namespace Microsoft.AspNetCore.Sockets if (isNewConnection) { connectionState.Connection.Metadata["transport"] = "poll"; + connectionState.Connection.Metadata.Format = format; connectionState.Connection.User = context.User; endpointTask = endpoint.OnConnected(connectionState.Connection); connectionState.Connection.Metadata["endpoint"] = endpointTask;