rename getid to negotiate (#124)
* rename getid to negotiate * also change SSE and Long Polling to require a pre-established connection * disallow changing transports mid-connection; return a 400 response if the user attempts to do so
This commit is contained in:
parent
3d5fc9493a
commit
464077866c
|
|
@ -65,7 +65,7 @@
|
|||
xhr.send(data);
|
||||
};
|
||||
|
||||
xhr('GET', `${sock.url}/getid`).then(connectionId => {
|
||||
xhr('GET', `${sock.url}/negotiate`).then(connectionId => {
|
||||
sock.connectionId = connectionId;
|
||||
|
||||
sock.onopen();
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@
|
|||
};
|
||||
|
||||
// Start the connection
|
||||
xhr('GET', `${sock.url}/getid`).then(connectionId => {
|
||||
xhr('GET', `${sock.url}/negotiate`).then(connectionId => {
|
||||
sock.connectionId = connectionId;
|
||||
|
||||
let source = new EventSource(`${sock.url}/sse?id=${connectionId}`);
|
||||
|
|
|
|||
|
|
@ -79,19 +79,19 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
}
|
||||
|
||||
var logger = loggerFactory.CreateLogger<Connection>();
|
||||
var getIdUrl = Utils.AppendPath(url, "getid");
|
||||
var negotiateUrl = Utils.AppendPath(url, "negotiate");
|
||||
|
||||
string connectionId;
|
||||
try
|
||||
{
|
||||
// Get a connection ID from the server
|
||||
logger.LogDebug("Reserving Connection Id from: {0}", getIdUrl);
|
||||
connectionId = await httpClient.GetStringAsync(getIdUrl);
|
||||
logger.LogDebug("Reserved Connection Id: {0}", connectionId);
|
||||
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
|
||||
connectionId = await httpClient.GetStringAsync(negotiateUrl);
|
||||
logger.LogDebug("Connection Id: {0}", connectionId);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", getIdUrl, ex);
|
||||
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
|
||||
throw;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
using System;
|
||||
using System.Security.Claims;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
|
|
|
|||
|
|
@ -33,9 +33,9 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
// Get the end point mapped to this http connection
|
||||
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
|
||||
|
||||
if (context.Request.Path.StartsWithSegments(path + "/getid"))
|
||||
if (context.Request.Path.StartsWithSegments(path + "/negotiate"))
|
||||
{
|
||||
await ProcessGetId(context);
|
||||
await ProcessNegotiate(context);
|
||||
}
|
||||
else if (context.Request.Path.StartsWithSegments(path + "/send"))
|
||||
{
|
||||
|
|
@ -49,23 +49,25 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoint endpoint)
|
||||
{
|
||||
var format =
|
||||
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
|
||||
? Format.Binary
|
||||
: Format.Text;
|
||||
|
||||
var state = GetOrCreateConnection(context);
|
||||
|
||||
// Adapt the connection to a message-based transport if necessary, since all the HTTP transports are message-based.
|
||||
var application = state.Application;
|
||||
|
||||
// Server sent events transport
|
||||
if (context.Request.Path.StartsWithSegments(path + "/sse"))
|
||||
{
|
||||
InitializePersistentConnection(state, "sse", context, endpoint, format);
|
||||
// Connection must already exist
|
||||
var state = await GetConnectionAsync(context);
|
||||
if (state == null)
|
||||
{
|
||||
// No such connection, GetConnection already set the response status code
|
||||
return;
|
||||
}
|
||||
|
||||
if (!await EnsureConnectionStateAsync(state, context, ServerSentEventsTransport.Name))
|
||||
{
|
||||
// Bad connection state. It's already set the response status code.
|
||||
return;
|
||||
}
|
||||
|
||||
// We only need to provide the Input channel since writing to the application is handled through /send.
|
||||
var sse = new ServerSentEventsTransport(application.Input, _loggerFactory);
|
||||
var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory);
|
||||
|
||||
await DoPersistentConnection(endpoint, sse, context, state);
|
||||
|
||||
|
|
@ -73,9 +75,21 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
|
||||
{
|
||||
InitializePersistentConnection(state, "websockets", context, endpoint, format);
|
||||
// Connection can be established lazily
|
||||
var state = await GetOrCreateConnectionAsync(context);
|
||||
if (state == null)
|
||||
{
|
||||
// No such connection, GetOrCreateConnection already set the response status code
|
||||
return;
|
||||
}
|
||||
|
||||
var ws = new WebSocketsTransport(application, _loggerFactory);
|
||||
if (!await EnsureConnectionStateAsync(state, context, WebSocketsTransport.Name))
|
||||
{
|
||||
// Bad connection state. It's already set the response status code.
|
||||
return;
|
||||
}
|
||||
|
||||
var ws = new WebSocketsTransport(state.Application, _loggerFactory);
|
||||
|
||||
await DoPersistentConnection(endpoint, ws, context, state);
|
||||
|
||||
|
|
@ -83,14 +97,24 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
|
||||
{
|
||||
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
|
||||
var formatType = (string)context.Request.Query["formatType"];
|
||||
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
|
||||
// Connection must already exist
|
||||
var state = await GetConnectionAsync(context);
|
||||
if (state == null)
|
||||
{
|
||||
// No such connection, GetConnection already set the response status code
|
||||
return;
|
||||
}
|
||||
|
||||
if (!await EnsureConnectionStateAsync(state, context, LongPollingTransport.Name))
|
||||
{
|
||||
// Bad connection state. It's already set the response status code.
|
||||
return;
|
||||
}
|
||||
|
||||
// Mark the connection as active
|
||||
state.Active = true;
|
||||
|
||||
var longPolling = new LongPollingTransport(application.Input, _loggerFactory);
|
||||
var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory);
|
||||
|
||||
// Start the transport
|
||||
var transportTask = longPolling.ProcessRequestAsync(context);
|
||||
|
|
@ -102,7 +126,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
_logger.LogDebug("Establishing new Long Polling connection: {0}", state.Connection.ConnectionId);
|
||||
|
||||
// This will re-initialize formatType metadata, but meh...
|
||||
InitializePersistentConnection(state, "poll", context, endpoint, format);
|
||||
state.Connection.Metadata["transport"] = LongPollingTransport.Name;
|
||||
|
||||
// REVIEW: This is super gross, this all needs to be cleaned up...
|
||||
state.Close = async () =>
|
||||
|
|
@ -149,10 +173,15 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
}
|
||||
|
||||
private ConnectionState InitializePersistentConnection(ConnectionState state, string transport, HttpContext context, EndPoint endpoint, Format format)
|
||||
private ConnectionState CreateConnection(HttpContext context)
|
||||
{
|
||||
var format =
|
||||
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
|
||||
? Format.Binary
|
||||
: Format.Text;
|
||||
|
||||
var state = _manager.CreateConnection();
|
||||
state.Connection.User = context.User;
|
||||
state.Connection.Metadata["transport"] = transport;
|
||||
|
||||
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
|
||||
var formatType = (string)context.Request.Query["formatType"];
|
||||
|
|
@ -181,10 +210,10 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
await Task.WhenAll(endpointTask, transportTask);
|
||||
}
|
||||
|
||||
private Task ProcessGetId(HttpContext context)
|
||||
private Task ProcessNegotiate(HttpContext context)
|
||||
{
|
||||
// Establish the connection
|
||||
var state = _manager.CreateConnection();
|
||||
var state = CreateConnection(context);
|
||||
|
||||
// Get the bytes for the connection id
|
||||
var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId);
|
||||
|
|
@ -196,50 +225,85 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
private async Task ProcessSend(HttpContext context)
|
||||
{
|
||||
var connectionId = context.Request.Query["id"];
|
||||
if (StringValues.IsNullOrEmpty(connectionId))
|
||||
var state = await GetConnectionAsync(context);
|
||||
if (state == null)
|
||||
{
|
||||
throw new InvalidOperationException("Missing connection id");
|
||||
// No such connection, GetConnection already set the response status code
|
||||
return;
|
||||
}
|
||||
|
||||
ConnectionState state;
|
||||
if (_manager.TryGetConnection(connectionId, out state))
|
||||
// Collect the message and write it to the channel
|
||||
// TODO: Need to use some kind of pooled memory here.
|
||||
byte[] buffer;
|
||||
using (var stream = new MemoryStream())
|
||||
{
|
||||
// Collect the message and write it to the channel
|
||||
// TODO: Need to use some kind of pooled memory here.
|
||||
byte[] buffer;
|
||||
using (var stream = new MemoryStream())
|
||||
{
|
||||
await context.Request.Body.CopyToAsync(stream);
|
||||
buffer = stream.ToArray();
|
||||
}
|
||||
|
||||
var format =
|
||||
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
|
||||
? Format.Binary
|
||||
: Format.Text;
|
||||
|
||||
var message = new Message(
|
||||
ReadableBuffer.Create(buffer).Preserve(),
|
||||
format,
|
||||
endOfMessage: true);
|
||||
|
||||
// REVIEW: Do we want to return a specific status code here if the connection has ended?
|
||||
while (await state.Application.Output.WaitToWriteAsync())
|
||||
{
|
||||
if (state.Application.Output.TryWrite(message))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
await context.Request.Body.CopyToAsync(stream);
|
||||
buffer = stream.ToArray();
|
||||
}
|
||||
else
|
||||
|
||||
var format =
|
||||
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
|
||||
? Format.Binary
|
||||
: Format.Text;
|
||||
|
||||
var message = new Message(
|
||||
ReadableBuffer.Create(buffer).Preserve(),
|
||||
format,
|
||||
endOfMessage: true);
|
||||
|
||||
// REVIEW: Do we want to return a specific status code here if the connection has ended?
|
||||
while (await state.Application.Output.WaitToWriteAsync())
|
||||
{
|
||||
throw new InvalidOperationException("Unknown connection id");
|
||||
if (state.Application.Output.TryWrite(message))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private ConnectionState GetOrCreateConnection(HttpContext context)
|
||||
private async Task<bool> EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, string transportName)
|
||||
{
|
||||
connectionState.Connection.User = context.User;
|
||||
|
||||
var transport = connectionState.Connection.Metadata.Get<string>("transport");
|
||||
if (string.IsNullOrEmpty(transport))
|
||||
{
|
||||
connectionState.Connection.Metadata["transport"] = transportName;
|
||||
}
|
||||
else if (!string.Equals(transport, transportName, StringComparison.Ordinal))
|
||||
{
|
||||
context.Response.StatusCode = 400;
|
||||
await context.Response.WriteAsync("Cannot change transports mid-connection");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private async Task<ConnectionState> GetConnectionAsync(HttpContext context)
|
||||
{
|
||||
var connectionId = context.Request.Query["id"];
|
||||
ConnectionState connectionState;
|
||||
|
||||
if (StringValues.IsNullOrEmpty(connectionId))
|
||||
{
|
||||
// There's no connection ID: bad request
|
||||
context.Response.StatusCode = StatusCodes.Status400BadRequest;
|
||||
await context.Response.WriteAsync("Connection ID required");
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!_manager.TryGetConnection(connectionId, out connectionState))
|
||||
{
|
||||
// No connection with that ID: Not Found
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
await context.Response.WriteAsync("No Connection with that ID");
|
||||
return null;
|
||||
}
|
||||
|
||||
return connectionState;
|
||||
}
|
||||
|
||||
private async Task<ConnectionState> GetOrCreateConnectionAsync(HttpContext context)
|
||||
{
|
||||
var connectionId = context.Request.Query["id"];
|
||||
ConnectionState connectionState;
|
||||
|
|
@ -247,11 +311,14 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
// There's no connection id so this is a brand new connection
|
||||
if (StringValues.IsNullOrEmpty(connectionId))
|
||||
{
|
||||
connectionState = _manager.CreateConnection();
|
||||
connectionState = CreateConnection(context);
|
||||
}
|
||||
else if (!_manager.TryGetConnection(connectionId, out connectionState))
|
||||
{
|
||||
throw new InvalidOperationException("Unknown connection id");
|
||||
// No connection with that ID: Not Found
|
||||
context.Response.StatusCode = StatusCodes.Status404NotFound;
|
||||
await context.Response.WriteAsync("No Connection with that ID");
|
||||
return null;
|
||||
}
|
||||
|
||||
return connectionState;
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
|
|||
{
|
||||
public class LongPollingTransport : IHttpTransport
|
||||
{
|
||||
public static readonly string Name = "longPolling";
|
||||
private readonly IReadableChannel<Message> _application;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
|
|
@ -12,6 +11,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
|
|||
{
|
||||
public class ServerSentEventsTransport : IHttpTransport
|
||||
{
|
||||
public static readonly string Name = "serverSentEvents";
|
||||
private readonly IReadableChannel<Message> _application;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
using System;
|
||||
using System.Diagnostics;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.WebSockets.Internal;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
|
@ -15,8 +14,10 @@ namespace Microsoft.AspNetCore.Sockets.Transports
|
|||
{
|
||||
public class WebSocketsTransport : IHttpTransport
|
||||
{
|
||||
public static readonly string Name = "webSockets";
|
||||
|
||||
private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5);
|
||||
private static readonly WebSocketAcceptContext EmptyContext = new WebSocketAcceptContext();
|
||||
private static readonly WebSocketAcceptContext _emptyContext = new WebSocketAcceptContext();
|
||||
|
||||
private WebSocketOpcode _lastOpcode = WebSocketOpcode.Continuation;
|
||||
private bool _lastFrameIncomplete = false;
|
||||
|
|
@ -48,7 +49,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
|
|||
return;
|
||||
}
|
||||
|
||||
using (var ws = await feature.AcceptWebSocketConnectionAsync(EmptyContext))
|
||||
using (var ws = await feature.AcceptWebSocketConnectionAsync(_emptyContext))
|
||||
{
|
||||
_logger.LogInformation("Socket opened.");
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
|
|
@ -20,7 +19,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
public class HttpConnectionDispatcherTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task GetIdReservesConnectionIdAndReturnsIt()
|
||||
public async Task NegotiateReservesConnectionIdAndReturnsIt()
|
||||
{
|
||||
var manager = new ConnectionManager();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
|
|
@ -29,7 +28,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
services.AddSingleton<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
var ms = new MemoryStream();
|
||||
context.Request.Path = "/getid";
|
||||
context.Request.Path = "/negotiate";
|
||||
context.Response.Body = ms;
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
|
||||
|
||||
|
|
@ -40,41 +39,61 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
Assert.Equal(id, state.Connection.ConnectionId);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SendingToUnknownConnectionIdThrows()
|
||||
[Theory]
|
||||
[InlineData("/send")]
|
||||
[InlineData("/sse")]
|
||||
[InlineData("/poll")]
|
||||
[InlineData("/ws")]
|
||||
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvided(string path)
|
||||
{
|
||||
var manager = new ConnectionManager();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/send";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = "unknown";
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
|
||||
|
||||
using (var strm = new MemoryStream())
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
context.Response.Body = strm;
|
||||
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = path;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = "unknown";
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
|
||||
});
|
||||
|
||||
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
Assert.Equal("No Connection with that ID", Encoding.UTF8.GetString(strm.ToArray()));
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SendingWithoutConnectionIdThrows()
|
||||
[Theory]
|
||||
[InlineData("/send")]
|
||||
[InlineData("/sse")]
|
||||
[InlineData("/poll")]
|
||||
public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided(string path)
|
||||
{
|
||||
|
||||
var manager = new ConnectionManager();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/send";
|
||||
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
|
||||
using (var strm = new MemoryStream())
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
context.Response.Body = strm;
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<TestEndPoint>();
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = path;
|
||||
|
||||
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
|
||||
});
|
||||
|
||||
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
|
||||
await strm.FlushAsync();
|
||||
Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,8 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
|
|
|
|||
Loading…
Reference in New Issue