Fix WebSockets Negotiate Auth in Kestrel (#26480)

* Don't close connections after upgrade requests without a 101 response

* Add test

* Add DefautCredentials_WebSocket_Success
This commit is contained in:
Stephen Halter 2020-10-02 14:47:20 -07:00 committed by GitHub
parent 8eb9603a9c
commit 96c082f285
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 112 additions and 26 deletions

View File

@ -9,6 +9,7 @@
<Reference Include="Microsoft.AspNetCore.Authentication.Negotiate" />
<Reference Include="Microsoft.AspNetCore.Routing" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel" />
<Reference Include="Microsoft.AspNetCore.WebSockets" />
<Reference Include="Microsoft.Extensions.Hosting" />
<Reference Include="System.Net.Http.WinHttpHandler" />
</ItemGroup>

View File

@ -6,6 +6,9 @@ using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
@ -23,7 +26,7 @@ namespace Microsoft.AspNetCore.Authentication.Negotiate
{
// In theory this would work on Linux and Mac, but the client would require explicit credentials.
[OSSkipCondition(OperatingSystems.Linux | OperatingSystems.MacOSX)]
public class NegotiateHandlerFunctionalTests
public class NegotiateHandlerFunctionalTests : LoggedTest
{
private static readonly Version Http11Version = new Version(1, 1);
private static readonly Version Http2Version = new Version(2, 0);
@ -109,6 +112,34 @@ namespace Microsoft.AspNetCore.Authentication.Negotiate
Assert.Equal(Http11Version, result.Version); // HTTP/2 downgrades.
}
[ConditionalFact]
public async Task DefautCredentials_WebSocket_Success()
{
using var host = await CreateHostAsync();
var address = host.Services.GetRequiredService<IServer>().Features.Get<IServerAddressesFeature>().Addresses.First().Replace("https://", "wss://");
using var webSocket = new ClientWebSocket
{
Options =
{
RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true,
UseDefaultCredentials = true,
}
};
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
await webSocket.ConnectAsync(new Uri($"{address}/AuthenticateWebSocket"), cts.Token);
var receiveBuffer = new byte[13];
var receiveResult = await webSocket.ReceiveAsync(receiveBuffer, cts.Token);
Assert.True(receiveResult.EndOfMessage);
Assert.Equal(WebSocketMessageType.Text, receiveResult.MessageType);
Assert.Equal("Hello World!", Encoding.UTF8.GetString(receiveBuffer, 0, receiveResult.Count));
}
public static IEnumerable<object[]> HttpOrders =>
new List<object[]>
{
@ -232,9 +263,10 @@ namespace Microsoft.AspNetCore.Authentication.Negotiate
Assert.Equal(Http11Version, result.Version); // HTTP/2 downgrades.
}
private static Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOptions = null)
private Task<IHost> CreateHostAsync(Action<NegotiateOptions> configureOptions = null)
{
var builder = new HostBuilder()
.ConfigureServices(AddTestLogging)
.ConfigureServices(services => services
.AddRouting()
.AddAuthentication(NegotiateDefaults.AuthenticationScheme)
@ -252,6 +284,7 @@ namespace Microsoft.AspNetCore.Authentication.Negotiate
{
app.UseRouting();
app.UseAuthentication();
app.UseWebSockets();
app.UseEndpoints(ConfigureEndpoints);
});
});
@ -289,6 +322,27 @@ namespace Microsoft.AspNetCore.Authentication.Negotiate
await context.Response.WriteAsync(name);
});
builder.Map("/AuthenticateWebSocket", async context =>
{
if (!context.User.Identity.IsAuthenticated)
{
await context.ChallengeAsync();
return;
}
if (!context.WebSockets.IsWebSocketRequest)
{
context.Response.StatusCode = 400;
return;
}
Assert.False(string.IsNullOrEmpty(context.User.Identity.Name), "name");
WebSocket webSocket = await context.WebSockets.AcceptWebSocketAsync();
await webSocket.SendAsync(Encoding.UTF8.GetBytes("Hello World!"), WebSocketMessageType.Text, endOfMessage: true, context.RequestAborted);
});
builder.Map("/AlreadyAuthenticated", async context =>
{
Assert.Equal("HTTP/1.1", context.Request.Protocol); // Not HTTP/2

View File

@ -30,10 +30,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private readonly Pipe _requestBodyPipe;
private ReadResult _readResult;
public Http1ChunkedEncodingMessageBody(bool keepAlive, Http1Connection context)
: base(context)
public Http1ChunkedEncodingMessageBody(Http1Connection context, bool keepAlive)
: base(context, keepAlive)
{
RequestKeepAlive = keepAlive;
_requestBodyPipe = CreateRequestBodyPipe(context);
}

View File

@ -10,8 +10,6 @@ using Microsoft.AspNetCore.Connections;
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
using BadHttpRequestException = Microsoft.AspNetCore.Http.BadHttpRequestException;
internal sealed class Http1ContentLengthMessageBody : Http1MessageBody
{
private ReadResult _readResult;
@ -23,12 +21,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private bool _finalAdvanceCalled;
private bool _cannotResetInputPipe;
public Http1ContentLengthMessageBody(bool keepAlive, long contentLength, Http1Connection context)
: base(context)
public Http1ContentLengthMessageBody(Http1Connection context, long contentLength, bool keepAlive)
: base(context, keepAlive)
{
RequestKeepAlive = keepAlive;
_contentLength = contentLength;
_unexaminedInputLength = _contentLength;
_unexaminedInputLength = contentLength;
}
public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)

View File

@ -18,9 +18,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
protected readonly Http1Connection _context;
protected bool _completed;
protected Http1MessageBody(Http1Connection context) : base(context)
protected Http1MessageBody(Http1Connection context, bool keepAlive) : base(context)
{
_context = context;
RequestKeepAlive = keepAlive;
}
[StackTraceHidden]
@ -118,14 +119,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
// see also http://tools.ietf.org/html/rfc2616#section-4.4
var keepAlive = httpVersion != HttpVersion.Http10;
var upgrade = false;
if (headers.HasConnection)
{
var connectionOptions = HttpHeaders.ParseConnection(headers.HeaderConnection);
upgrade = (connectionOptions & ConnectionOptions.Upgrade) != 0;
keepAlive = (connectionOptions & ConnectionOptions.KeepAlive) != 0;
keepAlive = keepAlive || (connectionOptions & ConnectionOptions.KeepAlive) != 0;
keepAlive = keepAlive && (connectionOptions & ConnectionOptions.Close) == 0;
}
if (upgrade)
@ -136,7 +138,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
context.OnTrailersComplete(); // No trailers for these.
return new Http1UpgradeMessageBody(context);
return new Http1UpgradeMessageBody(context, keepAlive);
}
if (headers.HasTransferEncoding)
@ -157,7 +159,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
// TODO may push more into the wrapper rather than just calling into the message body
// NBD for now.
return new Http1ChunkedEncodingMessageBody(keepAlive, context);
return new Http1ChunkedEncodingMessageBody(context, keepAlive);
}
if (headers.ContentLength.HasValue)
@ -169,7 +171,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose;
}
return new Http1ContentLengthMessageBody(keepAlive, contentLength, context);
return new Http1ContentLengthMessageBody(context, contentLength, keepAlive);
}
// If we got here, request contains no Content-Length or Transfer-Encoding header.

View File

@ -14,8 +14,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
/// </summary>
internal sealed class Http1UpgradeMessageBody : Http1MessageBody
{
public Http1UpgradeMessageBody(Http1Connection context)
: base(context)
public Http1UpgradeMessageBody(Http1Connection context, bool keepAlive)
: base(context, keepAlive)
{
RequestUpgrade = true;
}

View File

@ -1113,13 +1113,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
RejectNonBodyTransferEncodingResponse(appCompleted);
}
else if (StatusCode == StatusCodes.Status101SwitchingProtocols)
{
_keepAlive = false;
}
else if (!hasTransferEncoding && !responseHeaders.ContentLength.HasValue)
{
if (StatusCode == StatusCodes.Status101SwitchingProtocols)
{
_keepAlive = false;
}
else if ((appCompleted || !_canWriteResponseBody) && !_hasAdvanced) // Avoid setting contentLength of 0 if we wrote data before calling CreateResponseHeaders
if ((appCompleted || !_canWriteResponseBody) && !_hasAdvanced) // Avoid setting contentLength of 0 if we wrote data before calling CreateResponseHeaders
{
// Don't set the Content-Length header automatically for HEAD requests, 204 responses, or 304 responses.
if (CanAutoSetContentLengthZeroResponseHeader())

View File

@ -12,6 +12,7 @@ using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Testing;
@ -112,7 +113,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Performance
});
http1Connection.Reset();
http1Connection.InitializeBodyControl(new Http1ContentLengthMessageBody(keepAlive: true, 100, http1Connection));
http1Connection.InitializeBodyControl(new Http1ContentLengthMessageBody(http1Connection, contentLength: 100, keepAlive: true));
serviceContext.DateHeaderValueManager.OnHeartbeat(DateTimeOffset.UtcNow);
return http1Connection;

View File

@ -10,7 +10,6 @@ using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport;
using Microsoft.AspNetCore.Server.Kestrel.Tests;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
@ -343,5 +342,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
await appCompletedTcs.Task.DefaultTimeout();
}
}
[Fact]
public async Task DoesNotCloseConnectionWithout101Response()
{
var requestCount = 0;
await using (var server = new TestServer(async context =>
{
if (requestCount++ > 0)
{
await context.Features.Get<IHttpUpgradeFeature>().UpgradeAsync();
}
}, new TestServiceContext(LoggerFactory)))
{
using (var connection = server.CreateConnection())
{
await connection.SendEmptyGetWithUpgrade();
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
"Content-Length: 0",
"",
"");
await connection.SendEmptyGetWithUpgrade();
await connection.Receive("HTTP/1.1 101 Switching Protocols",
"Connection: Upgrade",
$"Date: {server.Context.DateHeaderValue}",
"",
"");
}
}
}
}
}