diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index 4a98254367..3b3d332e66 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -125,7 +125,7 @@ namespace Microsoft.AspNetCore.Sockets else if (context.WebSockets.IsWebSocketRequest) { // Connection can be established lazily - var connection = await GetOrCreateConnectionAsync(context); + var connection = await GetOrCreateConnectionAsync(context, options); if (connection == null) { // No such connection, GetOrCreateConnection already set the response status code @@ -364,7 +364,7 @@ namespace Microsoft.AspNetCore.Sockets context.Response.ContentType = "application/json"; // Establish the connection - var connection = _manager.CreateConnection(); + var connection = CreateConnectionInternal(options); // Set the Connection ID on the logging scope so that logs from now on will have the // Connection ID metadata set. @@ -515,7 +515,14 @@ namespace Microsoft.AspNetCore.Sockets return connection; } - private async Task GetOrCreateConnectionAsync(HttpContext context) + private DefaultConnectionContext CreateConnectionInternal(HttpSocketOptions options) + { + var transportPipeOptions = new PipeOptions(pauseWriterThreshold: options.TransportMaxBufferSize, resumeWriterThreshold: options.TransportMaxBufferSize / 2); + var appPipeOptions = new PipeOptions(pauseWriterThreshold: options.ApplicationMaxBufferSize, resumeWriterThreshold: options.ApplicationMaxBufferSize / 2); + return _manager.CreateConnection(transportPipeOptions, appPipeOptions); + } + + private async Task GetOrCreateConnectionAsync(HttpContext context, HttpSocketOptions options) { var connectionId = GetConnectionId(context); DefaultConnectionContext connection; @@ -523,7 +530,7 @@ namespace Microsoft.AspNetCore.Sockets // There's no connection id so this is a brand new connection if (StringValues.IsNullOrEmpty(connectionId)) { - connection = _manager.CreateConnection(); + connection = CreateConnectionInternal(options); } else if (!_manager.TryGetConnection(connectionId, out connection)) { diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpSocketOptions.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpSocketOptions.cs index 51dc9c4d5f..ffc61d563f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpSocketOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpSocketOptions.cs @@ -1,7 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Collections.Generic; using Microsoft.AspNetCore.Authorization; @@ -16,5 +15,9 @@ namespace Microsoft.AspNetCore.Sockets public WebSocketOptions WebSockets { get; } = new WebSocketOptions(); public LongPollingOptions LongPolling { get; } = new LongPollingOptions(); + + public long TransportMaxBufferSize { get; set; } = 0; + + public long ApplicationMaxBufferSize { get; set; } = 0; } } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 939982c336..c9b956fa3a 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -56,15 +56,14 @@ namespace Microsoft.AspNetCore.Sockets return _connections.TryGetValue(id, out connection); } - public DefaultConnectionContext CreateConnection() + public DefaultConnectionContext CreateConnection(PipeOptions transportPipeOptions, PipeOptions appPipeOptions) { var id = MakeNewConnectionId(); _logger.CreatedNewConnection(id); var connectionTimer = SocketEventSource.Log.ConnectionStart(id); + var pair = DuplexPipe.CreateConnectionPair(transportPipeOptions, appPipeOptions); - var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); - var connection = new DefaultConnectionContext(id, pair.Application, pair.Transport); connection.ConnectionTimer = connectionTimer; @@ -72,6 +71,11 @@ namespace Microsoft.AspNetCore.Sockets return connection; } + public DefaultConnectionContext CreateConnection() + { + return CreateConnection(PipeOptions.Default, PipeOptions.Default); + } + public void RemoveConnection(string id) { if (_connections.TryRemove(id, out var connection)) diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs index 5098d4b5df..87575a3252 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs @@ -51,6 +51,21 @@ namespace System.IO.Pipelines } } + public static async Task ConsumeAsync(this PipeReader pipeReader, int numBytes) + { + while (true) + { + var result = await pipeReader.ReadAsync(); + if (result.Buffer.Length < numBytes) + { + pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End); + continue; + } + pipeReader.AdvanceTo(result.Buffer.GetPosition(result.Buffer.Start, numBytes)); + break; + } + } + public static async Task ReadAllAsync(this PipeReader pipeReader) { while (true) diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 11998ab298..e95d3147f0 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -57,6 +57,90 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } + [Fact] + public async Task CheckThatThresholdValuesAreEnforced() + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddEndPoint(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + context.Response.Body = ms; + var httpSocketOptions = new HttpSocketOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 }; + await dispatcher.ExecuteNegotiateAsync(context, httpSocketOptions); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + var connectionId = negotiateResponse.Value("connectionId"); + Assert.True(manager.TryGetConnection(connectionId, out var connection)); + + // This write should complete immediately but it exceeds the writer threshold + var writeTask = connection.Application.Output.WriteAsync(new byte[] { (byte)'b', (byte)'y', (byte)'t', (byte)'e', (byte)'s' }); + + Assert.False(writeTask.IsCompleted); + + // Reading here puts us below the threshold + await connection.Transport.Input.ConsumeAsync(5); + + await writeTask.OrTimeout(); + } + } + + [Theory] + [InlineData(TransportType.LongPolling)] + [InlineData(TransportType.ServerSentEvents)] + public async Task CheckThatThresholdValuesAreEnforcedWithSends(TransportType transportType) + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var pipeOptions = new PipeOptions(pauseWriterThreshold: 8, resumeWriterThreshold: 4); + var connection = manager.CreateConnection(pipeOptions, pipeOptions); + connection.Metadata[ConnectionMetadataNames.Transport] = transportType; + + using (var requestBody = new MemoryStream()) + using (var responseBody = new MemoryStream()) + { + var bytes = Encoding.UTF8.GetBytes("EXTRADATA Hi"); + requestBody.Write(bytes, 0, bytes.Length); + requestBody.Seek(0, SeekOrigin.Begin); + + var context = new DefaultHttpContext(); + context.Request.Body = requestBody; + context.Response.Body = responseBody; + + var services = new ServiceCollection(); + services.AddEndPoint(); + services.AddOptions(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + + // This task should complete immediately but it exceeds the writer threshold + var executeTask = dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); + Assert.False(executeTask.IsCompleted); + await connection.Transport.Input.ConsumeAsync(10); + await executeTask.OrTimeout(); + + Assert.True(connection.Transport.Input.TryRead(out var result)); + Assert.Equal("Hi", Encoding.UTF8.GetString(result.Buffer.ToArray())); + connection.Transport.Input.AdvanceTo(result.Buffer.End); + } + } + } + [Theory] [InlineData(TransportType.All)] [InlineData((TransportType)0)]