// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;

namespace Microsoft.AspNetCore.Sockets
{
    /// <summary>
    /// Routes HTTP requests for a socket endpoint: OPTIONS negotiates a new connection id,
    /// POST sends data to the application, and GET runs one of the transports
    /// (server-sent events, WebSockets or long polling).
    /// </summary>
    public class HttpConnectionDispatcher
    {
        private readonly ConnectionManager _manager;
        private readonly ILoggerFactory _loggerFactory;
        private readonly ILogger _logger;

        public HttpConnectionDispatcher(ConnectionManager manager, ILoggerFactory loggerFactory)
        {
            _manager = manager;
            _loggerFactory = loggerFactory;
            _logger = _loggerFactory.CreateLogger<HttpConnectionDispatcher>();
        }

        /// <summary>
        /// Authorizes the request and dispatches it by HTTP method. Methods other than
        /// OPTIONS, POST and GET receive a 405 response.
        /// </summary>
        /// <param name="context">The current HTTP request.</param>
        /// <param name="options">Endpoint options (authorization policies, enabled transports, WebSocket settings).</param>
        /// <param name="socketDelegate">The application code to run for each connection.</param>
        public async Task ExecuteAsync(HttpContext context, HttpSocketOptions options, SocketDelegate socketDelegate)
        {
            if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationPolicyNames))
            {
                return;
            }

            if (HttpMethods.IsOptions(context.Request.Method))
            {
                // OPTIONS /{path}
                await ProcessNegotiate(context, options);
            }
            else if (HttpMethods.IsPost(context.Request.Method))
            {
                // POST /{path}
                await ProcessSend(context);
            }
            else if (HttpMethods.IsGet(context.Request.Method))
            {
                // GET /{path}
                await ExecuteEndpointAsync(context, socketDelegate, options);
            }
            else
            {
                context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
            }
        }

        /// <summary>
        /// Handles GET requests by selecting a transport. An <c>Accept: text/event-stream</c> header
        /// selects server-sent events, a WebSocket upgrade selects WebSockets, and anything else is
        /// treated as a long poll.
        /// </summary>
        private async Task ExecuteEndpointAsync(HttpContext context, SocketDelegate socketDelegate, HttpSocketOptions options)
        {
            var supportedTransports = options.Transports;

            // Server sent events transport
            // GET /{path}
            // Accept: text/event-stream
            var headers = context.Request.GetTypedHeaders();
            if (headers.Accept?.Contains(new Net.Http.Headers.MediaTypeHeaderValue("text/event-stream")) == true)
            {
                // Connection must already exist
                var connection = await GetConnectionAsync(context);
                if (connection == null)
                {
                    // No such connection, GetConnection already set the response status code
                    return;
                }

                if (!await EnsureConnectionStateAsync(connection, context, TransportType.ServerSentEvents, supportedTransports))
                {
                    // Bad connection state. It's already set the response status code.
                    return;
                }

                // We only need to provide the Input channel since writing to the application is handled through /send.
                var sse = new ServerSentEventsTransport(connection.Application.Input, _loggerFactory);

                await DoPersistentConnection(socketDelegate, sse, context, connection);
            }
            else if (context.WebSockets.IsWebSocketRequest)
            {
                // Connection can be established lazily
                var connection = await GetOrCreateConnectionAsync(context);
                if (connection == null)
                {
                    // No such connection, GetOrCreateConnection already set the response status code
                    return;
                }

                if (!await EnsureConnectionStateAsync(connection, context, TransportType.WebSockets, supportedTransports))
                {
                    // Bad connection state. It's already set the response status code.
                    return;
                }

                var ws = new WebSocketsTransport(options.WebSockets, connection.Application, _loggerFactory);

                await DoPersistentConnection(socketDelegate, ws, context, connection);
            }
            else
            {
                // GET /{path} maps to long polling

                // Connection must already exist
                var connection = await GetConnectionAsync(context);
                if (connection == null)
                {
                    // No such connection, GetConnection already set the response status code
                    return;
                }

                if (!await EnsureConnectionStateAsync(connection, context, TransportType.LongPolling, supportedTransports))
                {
                    // Bad connection state. It's already set the response status code.
                    return;
                }

                await ExecuteLongPollingAsync(context, socketDelegate, connection);
            }
        }

        /// <summary>
        /// Runs a single long-polling request for an existing connection. A poll that arrives while
        /// another poll is active cancels the earlier poll and waits for it to drain. The first poll
        /// starts the application. When the poll finishes, the connection either goes back to Inactive
        /// or, if the application ended and nothing was written (204), is disposed and removed.
        /// </summary>
        private async Task ExecuteLongPollingAsync(HttpContext context, SocketDelegate socketDelegate, DefaultConnectionContext connection)
        {
            // Linked token for this poll only; disposed once this poll's transport has completed.
            CancellationTokenSource pollTokenSource = null;

            // The transport task started by THIS request. connection.TransportTask is shared state
            // and may be replaced by a newer poll, so never re-read it after releasing the lock.
            Task currentTransportTask;

            // Acquire before the try so a failed wait never releases a semaphore we don't hold.
            await connection.Lock.WaitAsync();
            try
            {
                if (connection.Status == DefaultConnectionContext.ConnectionStatus.Disposed)
                {
                    _logger.LogDebug("Connection {connectionId} was disposed,", connection.ConnectionId);

                    // The connection was disposed
                    context.Response.StatusCode = StatusCodes.Status404NotFound;
                    return;
                }

                if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active)
                {
                    // NOTE(review): EnsureConnectionStateAsync has already stored the *current* request's
                    // HttpContext in the connection metadata before this point, so {requestId} here is
                    // presumably this request's TraceIdentifier rather than the previous one's — confirm.
                    _logger.LogDebug("Connection {connectionId} is already active via {requestId}. Cancelling previous request.", connection.ConnectionId, connection.GetHttpContext().TraceIdentifier);

                    using (connection.Cancellation)
                    {
                        // Cancel the previous request
                        connection.Cancellation.Cancel();

                        try
                        {
                            // Wait for the previous request to drain
                            await connection.TransportTask;
                        }
                        catch (OperationCanceledException)
                        {
                            // Should be a cancelled task
                        }

                        _logger.LogDebug("Previous poll cancelled for {connectionId} on {requestId}.", connection.ConnectionId, connection.GetHttpContext().TraceIdentifier);
                    }
                }

                // Mark the connection as active
                connection.Status = DefaultConnectionContext.ConnectionStatus.Active;

                // Raise OnConnected for new connections only since polls happen all the time
                if (connection.ApplicationTask == null)
                {
                    _logger.LogDebug("Establishing new connection: {connectionId} on {requestId}", connection.ConnectionId, connection.GetHttpContext().TraceIdentifier);

                    connection.Metadata[ConnectionMetadataNames.Transport] = TransportType.LongPolling;

                    connection.ApplicationTask = ExecuteApplication(socketDelegate, connection);
                }
                else
                {
                    _logger.LogDebug("Resuming existing connection: {connectionId} on {requestId}", connection.ConnectionId, connection.GetHttpContext().TraceIdentifier);
                }

                var longPolling = new LongPollingTransport(connection.Application.Input, _loggerFactory);

                connection.Cancellation = new CancellationTokenSource();

                // REVIEW: Performance of this isn't great as this does a bunch of per request allocations
                pollTokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted);

                // Start the transport
                currentTransportTask = longPolling.ProcessRequestAsync(context, pollTokenSource.Token);
                connection.TransportTask = currentTransportTask;
            }
            finally
            {
                connection.Lock.Release();
            }

            try
            {
                var applicationTask = connection.ApplicationTask;
                var resultTask = await Task.WhenAny(applicationTask, currentTransportTask);

                var pollAgain = true;

                // If the application ended before the transport task then we potentially need to end the connection
                if (resultTask == applicationTask)
                {
                    // Complete the transport (notifying it of the application error if there is one)
                    connection.Transport.Output.TryComplete(applicationTask.Exception);

                    // Wait for the transport to run
                    await currentTransportTask;

                    // If the status code is a 204 it means we didn't write anything
                    if (context.Response.StatusCode == StatusCodes.Status204NoContent)
                    {
                        // We should be able to safely dispose because there's no more data being written
                        await _manager.DisposeAndRemoveAsync(connection);

                        // Don't poll again if we've removed the connection completely
                        pollAgain = false;
                    }
                }
                else if (resultTask.IsCanceled)
                {
                    // Don't poll if the transport task was cancelled
                    pollAgain = false;
                }

                if (pollAgain)
                {
                    // Otherwise, we update the state to inactive again and wait for the next poll
                    await connection.Lock.WaitAsync();
                    try
                    {
                        // Only reset the shared state if no newer poll has taken over the connection.
                        // A newer poll replaces connection.TransportTask (under the lock) and owns the
                        // current Cancellation; resetting here would mark it Inactive and dispose its token.
                        if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active &&
                            connection.TransportTask == currentTransportTask)
                        {
                            // Mark the connection as inactive
                            connection.LastSeenUtc = DateTime.UtcNow;
                            connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive;

                            connection.Metadata[ConnectionMetadataNames.HttpContext] = null;

                            // Dispose the cancellation token
                            connection.Cancellation.Dispose();
                            connection.Cancellation = null;
                        }
                    }
                    finally
                    {
                        connection.Lock.Release();
                    }
                }
            }
            finally
            {
                // Every path above has observed the transport task completing, so the linked
                // token is no longer in use.
                pollTokenSource.Dispose();
            }
        }

        /// <summary>
        /// Runs a persistent transport (server-sent events or WebSockets) for the lifetime of the
        /// request. Responds 404 if the connection was disposed and 409 if another request already
        /// holds it. Otherwise starts the application and the transport, waits for either to finish,
        /// then disposes and removes the connection.
        /// </summary>
        private async Task DoPersistentConnection(SocketDelegate socketDelegate,
                                                  IHttpTransport transport,
                                                  HttpContext context,
                                                  DefaultConnectionContext connection)
        {
            // Acquire before the try so a failed wait never releases a semaphore we don't hold.
            await connection.Lock.WaitAsync();
            try
            {
                if (connection.Status == DefaultConnectionContext.ConnectionStatus.Disposed)
                {
                    _logger.LogDebug("Connection {connectionId} was disposed,", connection.ConnectionId);

                    // Connection was disposed
                    context.Response.StatusCode = StatusCodes.Status404NotFound;
                    return;
                }

                // There's already an active request
                if (connection.Status == DefaultConnectionContext.ConnectionStatus.Active)
                {
                    // NOTE(review): EnsureConnectionStateAsync has already overwritten connection.User and
                    // the HttpContext metadata with this (rejected) request before we get here — confirm
                    // whether the active request should keep its own context.
                    _logger.LogDebug("Connection {connectionId} is already active via {requestId}.", connection.ConnectionId, connection.GetHttpContext().TraceIdentifier);

                    // Reject the request with a 409 conflict
                    context.Response.StatusCode = StatusCodes.Status409Conflict;
                    return;
                }

                // Mark the connection as active
                connection.Status = DefaultConnectionContext.ConnectionStatus.Active;

                // Call into the end point passing the connection
                connection.ApplicationTask = ExecuteApplication(socketDelegate, connection);

                // Start the transport
                connection.TransportTask = transport.ProcessRequestAsync(context, context.RequestAborted);
            }
            finally
            {
                connection.Lock.Release();
            }

            // Wait for any of them to end
            await Task.WhenAny(connection.ApplicationTask, connection.TransportTask);

            await _manager.DisposeAndRemoveAsync(connection);
        }

        /// <summary>
        /// Invokes the application delegate for a connection on a thread-pool thread.
        /// </summary>
        private async Task ExecuteApplication(SocketDelegate socketDelegate, ConnectionContext connection)
        {
            // Jump onto the thread pool thread so blocking user code doesn't block the setup of the
            // connection and transport
            await AwaitableThreadPool.Yield();

            // Running this in an async method turns sync exceptions into async ones
            await socketDelegate(connection);
        }

        /// <summary>
        /// Handles OPTIONS: creates a new connection and writes its id to the response body as
        /// UTF-8 plain text, advertising the allowed methods in the <c>Allow</c> header.
        /// </summary>
        private Task ProcessNegotiate(HttpContext context, HttpSocketOptions options)
        {
            // Set the allowed headers for this resource
            context.Response.Headers.AppendCommaSeparatedValues("Allow", "GET", "POST", "OPTIONS");

            context.Response.ContentType = "text/plain";

            // Establish the connection
            var connection = _manager.CreateConnection();

            // Get the bytes for the connection id
            var connectionIdBuffer = Encoding.UTF8.GetBytes(connection.ConnectionId);

            // Write it out to the response with the right content length
            context.Response.ContentLength = connectionIdBuffer.Length;
            return context.Response.Body.WriteAsync(connectionIdBuffer, 0, connectionIdBuffer.Length);
        }

        /// <summary>
        /// Handles POST: reads the entire request body and writes it as one message to the
        /// connection's application channel.
        /// </summary>
        private async Task ProcessSend(HttpContext context)
        {
            var connection = await GetConnectionAsync(context);
            if (connection == null)
            {
                // No such connection, GetConnection already set the response status code
                return;
            }

            // TODO: Use a pool here

            byte[] buffer;
            using (var stream = new MemoryStream())
            {
                // MemoryStream has nothing to flush, so the buffer can be taken straight after the copy.
                await context.Request.Body.CopyToAsync(stream);
                buffer = stream.ToArray();
            }

            // Retry until the channel accepts the message. Stop silently when WaitToWriteAsync reports
            // false (presumably the channel was completed and will never accept more data).
            while (!connection.Application.Output.TryWrite(buffer))
            {
                if (!await connection.Application.Output.WaitToWriteAsync())
                {
                    return;
                }
            }
        }

        /// <summary>
        /// Validates that <paramref name="transportType"/> is enabled for the endpoint and matches the
        /// transport already recorded for the connection (recording it on first use). On success, copies
        /// the request's user and HttpContext onto the connection.
        /// </summary>
        /// <returns>
        /// <c>true</c> if the request may proceed. <c>false</c> if an error response (404 or 400) has
        /// already been written.
        /// </returns>
        private async Task<bool> EnsureConnectionStateAsync(DefaultConnectionContext connection, HttpContext context, TransportType transportType, TransportType supportedTransports)
        {
            if ((supportedTransports & transportType) == 0)
            {
                context.Response.StatusCode = StatusCodes.Status404NotFound;
                await context.Response.WriteAsync($"{transportType} transport not supported by this end point type");
                return false;
            }

            var transport = connection.Metadata.Get<TransportType?>(ConnectionMetadataNames.Transport);

            if (transport == null)
            {
                connection.Metadata[ConnectionMetadataNames.Transport] = transportType;
            }
            else if (transport != transportType)
            {
                context.Response.StatusCode = StatusCodes.Status400BadRequest;
                await context.Response.WriteAsync("Cannot change transports mid-connection");
                return false;
            }

            // Setup the connection state from the http context.
            // NOTE(review): this runs before the connection lock is taken by the caller.
            connection.User = context.User;
            connection.Metadata[ConnectionMetadataNames.HttpContext] = context;
            return true;
        }

        /// <summary>
        /// Looks up the connection named by the <c>id</c> query parameter.
        /// </summary>
        /// <returns>
        /// The connection, or <c>null</c> after writing 400 (missing id) or 404 (unknown id).
        /// </returns>
        private async Task<DefaultConnectionContext> GetConnectionAsync(HttpContext context)
        {
            var connectionId = context.Request.Query["id"];

            if (StringValues.IsNullOrEmpty(connectionId))
            {
                // There's no connection ID: bad request
                context.Response.StatusCode = StatusCodes.Status400BadRequest;
                await context.Response.WriteAsync("Connection ID required");
                return null;
            }

            if (!_manager.TryGetConnection(connectionId, out var connection))
            {
                // No connection with that ID: Not Found
                context.Response.StatusCode = StatusCodes.Status404NotFound;
                await context.Response.WriteAsync("No Connection with that ID");
                return null;
            }

            return connection;
        }

        /// <summary>
        /// Looks up the connection named by the <c>id</c> query parameter, or creates a new one when
        /// no id is supplied.
        /// </summary>
        /// <returns>The connection, or <c>null</c> after writing 404 for an unknown id.</returns>
        private async Task<DefaultConnectionContext> GetOrCreateConnectionAsync(HttpContext context)
        {
            var connectionId = context.Request.Query["id"];
            DefaultConnectionContext connection;

            // There's no connection id so this is a brand new connection
            if (StringValues.IsNullOrEmpty(connectionId))
            {
                connection = _manager.CreateConnection();
            }
            else if (!_manager.TryGetConnection(connectionId, out connection))
            {
                // No connection with that ID: Not Found
                context.Response.StatusCode = StatusCodes.Status404NotFound;
                await context.Response.WriteAsync("No Connection with that ID");
                return null;
            }

            return connection;
        }
    }
}