rename getid to negotiate (#124)

* rename getid to negotiate
* also change SSE and Long Polling to require a pre-established connection
* disallow changing transports mid-connection; return a 400 response if the user attempts to do so
This commit is contained in:
Andrew Stanton-Nurse 2017-01-17 15:45:29 -08:00 committed by GitHub
parent 3d5fc9493a
commit 464077866c
10 changed files with 187 additions and 103 deletions

View File

@ -65,7 +65,7 @@
xhr.send(data);
};
xhr('GET', `${sock.url}/getid`).then(connectionId => {
xhr('GET', `${sock.url}/negotiate`).then(connectionId => {
sock.connectionId = connectionId;
sock.onopen();

View File

@ -53,7 +53,7 @@
};
// Start the connection
xhr('GET', `${sock.url}/getid`).then(connectionId => {
xhr('GET', `${sock.url}/negotiate`).then(connectionId => {
sock.connectionId = connectionId;
let source = new EventSource(`${sock.url}/sse?id=${connectionId}`);

View File

@ -79,19 +79,19 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
var logger = loggerFactory.CreateLogger<Connection>();
var getIdUrl = Utils.AppendPath(url, "getid");
var negotiateUrl = Utils.AppendPath(url, "negotiate");
string connectionId;
try
{
// Get a connection ID from the server
logger.LogDebug("Reserving Connection Id from: {0}", getIdUrl);
connectionId = await httpClient.GetStringAsync(getIdUrl);
logger.LogDebug("Reserved Connection Id: {0}", connectionId);
logger.LogDebug("Establishing Connection at: {0}", negotiateUrl);
connectionId = await httpClient.GetStringAsync(negotiateUrl);
logger.LogDebug("Connection Id: {0}", connectionId);
}
catch (Exception ex)
{
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", getIdUrl, ex);
logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", negotiateUrl, ex);
throw;
}

View File

@ -3,7 +3,6 @@
using System;
using System.Security.Claims;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{

View File

@ -33,9 +33,9 @@ namespace Microsoft.AspNetCore.Sockets
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
if (context.Request.Path.StartsWithSegments(path + "/getid"))
if (context.Request.Path.StartsWithSegments(path + "/negotiate"))
{
await ProcessGetId(context);
await ProcessNegotiate(context);
}
else if (context.Request.Path.StartsWithSegments(path + "/send"))
{
@ -49,23 +49,25 @@ namespace Microsoft.AspNetCore.Sockets
private async Task ExecuteEndpointAsync(string path, HttpContext context, EndPoint endpoint)
{
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var state = GetOrCreateConnection(context);
// Adapt the connection to a message-based transport if necessary, since all the HTTP transports are message-based.
var application = state.Application;
// Server sent events transport
if (context.Request.Path.StartsWithSegments(path + "/sse"))
{
InitializePersistentConnection(state, "sse", context, endpoint, format);
// Connection must already exist
var state = await GetConnectionAsync(context);
if (state == null)
{
// No such connection, GetConnection already set the response status code
return;
}
if (!await EnsureConnectionStateAsync(state, context, ServerSentEventsTransport.Name))
{
// Bad connection state. It's already set the response status code.
return;
}
// We only need to provide the Input channel since writing to the application is handled through /send.
var sse = new ServerSentEventsTransport(application.Input, _loggerFactory);
var sse = new ServerSentEventsTransport(state.Application.Input, _loggerFactory);
await DoPersistentConnection(endpoint, sse, context, state);
@ -73,9 +75,21 @@ namespace Microsoft.AspNetCore.Sockets
}
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
{
InitializePersistentConnection(state, "websockets", context, endpoint, format);
// Connection can be established lazily
var state = await GetOrCreateConnectionAsync(context);
if (state == null)
{
// No such connection, GetOrCreateConnection already set the response status code
return;
}
var ws = new WebSocketsTransport(application, _loggerFactory);
if (!await EnsureConnectionStateAsync(state, context, WebSocketsTransport.Name))
{
// Bad connection state. It's already set the response status code.
return;
}
var ws = new WebSocketsTransport(state.Application, _loggerFactory);
await DoPersistentConnection(endpoint, ws, context, state);
@ -83,14 +97,24 @@ namespace Microsoft.AspNetCore.Sockets
}
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
{
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
var formatType = (string)context.Request.Query["formatType"];
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
// Connection must already exist
var state = await GetConnectionAsync(context);
if (state == null)
{
// No such connection, GetConnection already set the response status code
return;
}
if (!await EnsureConnectionStateAsync(state, context, LongPollingTransport.Name))
{
// Bad connection state. It's already set the response status code.
return;
}
// Mark the connection as active
state.Active = true;
var longPolling = new LongPollingTransport(application.Input, _loggerFactory);
var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory);
// Start the transport
var transportTask = longPolling.ProcessRequestAsync(context);
@ -102,7 +126,7 @@ namespace Microsoft.AspNetCore.Sockets
_logger.LogDebug("Establishing new Long Polling connection: {0}", state.Connection.ConnectionId);
// This will re-initialize formatType metadata, but meh...
InitializePersistentConnection(state, "poll", context, endpoint, format);
state.Connection.Metadata["transport"] = LongPollingTransport.Name;
// REVIEW: This is super gross, this all needs to be cleaned up...
state.Close = async () =>
@ -149,10 +173,15 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private ConnectionState InitializePersistentConnection(ConnectionState state, string transport, HttpContext context, EndPoint endpoint, Format format)
private ConnectionState CreateConnection(HttpContext context)
{
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var state = _manager.CreateConnection();
state.Connection.User = context.User;
state.Connection.Metadata["transport"] = transport;
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
var formatType = (string)context.Request.Query["formatType"];
@ -181,10 +210,10 @@ namespace Microsoft.AspNetCore.Sockets
await Task.WhenAll(endpointTask, transportTask);
}
private Task ProcessGetId(HttpContext context)
private Task ProcessNegotiate(HttpContext context)
{
// Establish the connection
var state = _manager.CreateConnection();
var state = CreateConnection(context);
// Get the bytes for the connection id
var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId);
@ -196,50 +225,85 @@ namespace Microsoft.AspNetCore.Sockets
private async Task ProcessSend(HttpContext context)
{
var connectionId = context.Request.Query["id"];
if (StringValues.IsNullOrEmpty(connectionId))
var state = await GetConnectionAsync(context);
if (state == null)
{
throw new InvalidOperationException("Missing connection id");
// No such connection, GetConnection already set the response status code
return;
}
ConnectionState state;
if (_manager.TryGetConnection(connectionId, out state))
// Collect the message and write it to the channel
// TODO: Need to use some kind of pooled memory here.
byte[] buffer;
using (var stream = new MemoryStream())
{
// Collect the message and write it to the channel
// TODO: Need to use some kind of pooled memory here.
byte[] buffer;
using (var stream = new MemoryStream())
{
await context.Request.Body.CopyToAsync(stream);
buffer = stream.ToArray();
}
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var message = new Message(
ReadableBuffer.Create(buffer).Preserve(),
format,
endOfMessage: true);
// REVIEW: Do we want to return a specific status code here if the connection has ended?
while (await state.Application.Output.WaitToWriteAsync())
{
if (state.Application.Output.TryWrite(message))
{
break;
}
}
await context.Request.Body.CopyToAsync(stream);
buffer = stream.ToArray();
}
else
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var message = new Message(
ReadableBuffer.Create(buffer).Preserve(),
format,
endOfMessage: true);
// REVIEW: Do we want to return a specific status code here if the connection has ended?
while (await state.Application.Output.WaitToWriteAsync())
{
throw new InvalidOperationException("Unknown connection id");
if (state.Application.Output.TryWrite(message))
{
break;
}
}
}
private ConnectionState GetOrCreateConnection(HttpContext context)
private async Task<bool> EnsureConnectionStateAsync(ConnectionState connectionState, HttpContext context, string transportName)
{
connectionState.Connection.User = context.User;
var transport = connectionState.Connection.Metadata.Get<string>("transport");
if (string.IsNullOrEmpty(transport))
{
connectionState.Connection.Metadata["transport"] = transportName;
}
else if (!string.Equals(transport, transportName, StringComparison.Ordinal))
{
context.Response.StatusCode = 400;
await context.Response.WriteAsync("Cannot change transports mid-connection");
return false;
}
return true;
}
private async Task<ConnectionState> GetConnectionAsync(HttpContext context)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
if (StringValues.IsNullOrEmpty(connectionId))
{
// There's no connection ID: bad request
context.Response.StatusCode = StatusCodes.Status400BadRequest;
await context.Response.WriteAsync("Connection ID required");
return null;
}
if (!_manager.TryGetConnection(connectionId, out connectionState))
{
// No connection with that ID: Not Found
context.Response.StatusCode = StatusCodes.Status404NotFound;
await context.Response.WriteAsync("No Connection with that ID");
return null;
}
return connectionState;
}
private async Task<ConnectionState> GetOrCreateConnectionAsync(HttpContext context)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
@ -247,11 +311,14 @@ namespace Microsoft.AspNetCore.Sockets
// There's no connection id so this is a brand new connection
if (StringValues.IsNullOrEmpty(connectionId))
{
connectionState = _manager.CreateConnection();
connectionState = CreateConnection(context);
}
else if (!_manager.TryGetConnection(connectionId, out connectionState))
{
throw new InvalidOperationException("Unknown connection id");
// No connection with that ID: Not Found
context.Response.StatusCode = StatusCodes.Status404NotFound;
await context.Response.WriteAsync("No Connection with that ID");
return null;
}
return connectionState;

View File

@ -13,6 +13,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
public class LongPollingTransport : IHttpTransport
{
public static readonly string Name = "longPolling";
private readonly IReadableChannel<Message> _application;
private readonly ILogger _logger;

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
@ -12,6 +11,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
public class ServerSentEventsTransport : IHttpTransport
{
public static readonly string Name = "serverSentEvents";
private readonly IReadableChannel<Message> _application;
private readonly ILogger _logger;

View File

@ -4,7 +4,6 @@
using System;
using System.Diagnostics;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.Internal;
@ -15,8 +14,10 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
public class WebSocketsTransport : IHttpTransport
{
public static readonly string Name = "webSockets";
private static readonly TimeSpan _closeTimeout = TimeSpan.FromSeconds(5);
private static readonly WebSocketAcceptContext EmptyContext = new WebSocketAcceptContext();
private static readonly WebSocketAcceptContext _emptyContext = new WebSocketAcceptContext();
private WebSocketOpcode _lastOpcode = WebSocketOpcode.Continuation;
private bool _lastFrameIncomplete = false;
@ -48,7 +49,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
return;
}
using (var ws = await feature.AcceptWebSocketConnectionAsync(EmptyContext))
using (var ws = await feature.AcceptWebSocketConnectionAsync(_emptyContext))
{
_logger.LogInformation("Socket opened.");

View File

@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
@ -20,7 +19,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
public class HttpConnectionDispatcherTests
{
[Fact]
public async Task GetIdReservesConnectionIdAndReturnsIt()
public async Task NegotiateReservesConnectionIdAndReturnsIt()
{
var manager = new ConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
@ -29,7 +28,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
var ms = new MemoryStream();
context.Request.Path = "/getid";
context.Request.Path = "/negotiate";
context.Response.Body = ms;
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
@ -40,41 +39,61 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(id, state.Connection.ConnectionId);
}
[Fact]
public async Task SendingToUnknownConnectionIdThrows()
[Theory]
[InlineData("/send")]
[InlineData("/sse")]
[InlineData("/poll")]
[InlineData("/ws")]
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvided(string path)
{
var manager = new ConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
var values = new Dictionary<string, StringValues>();
values["id"] = "unknown";
var qs = new QueryCollection(values);
context.Request.Query = qs;
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
using (var strm = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Response.Body = strm;
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
var values = new Dictionary<string, StringValues>();
values["id"] = "unknown";
var qs = new QueryCollection(values);
context.Request.Query = qs;
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
});
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
await strm.FlushAsync();
Assert.Equal("No Connection with that ID", Encoding.UTF8.GetString(strm.ToArray()));
}
}
[Fact]
public async Task SendingWithoutConnectionIdThrows()
[Theory]
[InlineData("/send")]
[InlineData("/sse")]
[InlineData("/poll")]
public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided(string path)
{
var manager = new ConnectionManager();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/send";
await Assert.ThrowsAsync<InvalidOperationException>(async () =>
using (var strm = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Response.Body = strm;
var services = new ServiceCollection();
services.AddSingleton<TestEndPoint>();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = path;
await dispatcher.ExecuteAsync<TestEndPoint>("", context);
});
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray()));
}
}
}

View File

@ -1,11 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;