115 lines
4.7 KiB
C#
115 lines
4.7 KiB
C#
// Copyright (c) .NET Foundation. All rights reserved.
|
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
|
|
|
using System.Collections.Generic;
|
|
using System.Net;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.AspNetCore.Connections;
|
|
using Microsoft.AspNetCore.Connections.Features;
|
|
using Microsoft.AspNetCore.Hosting.Server;
|
|
using Microsoft.AspNetCore.Http.Features;
|
|
using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal;
|
|
using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal;
|
|
|
|
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal
|
|
{
|
|
/// <summary>
/// Connection middleware that wraps each transport-level <see cref="ConnectionContext"/> in an
/// <see cref="HttpConnection"/>, registers it with the service's connection manager, and runs
/// request processing for the hosted <see cref="IHttpApplication{TContext}"/> until the transport
/// reports the connection closed and processing has finished.
/// </summary>
/// <typeparam name="TContext">The per-request context type used by the hosted application.</typeparam>
public class HttpConnectionMiddleware<TContext>
{
    // Incremented atomically to hand out a distinct id to every HTTP connection. Being a static
    // field of a generic class, there is one counter per closed TContext type.
    private static long _lastHttpConnectionId = long.MinValue;

    private readonly IList<IConnectionAdapter> _connectionAdapters;
    private readonly ServiceContext _serviceContext;
    private readonly IHttpApplication<TContext> _application;
    private readonly HttpProtocols _protocols;

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="adapters">Connection adapters copied onto each connection's <see cref="HttpConnectionContext"/>.</param>
    /// <param name="serviceContext">Server services; its <c>ConnectionManager</c> tracks live connections.</param>
    /// <param name="application">The application whose requests are processed on each connection.</param>
    /// <param name="protocols">The <see cref="HttpProtocols"/> value copied onto each connection's context.</param>
    public HttpConnectionMiddleware(IList<IConnectionAdapter> adapters, ServiceContext serviceContext, IHttpApplication<TContext> application, HttpProtocols protocols)
    {
        _serviceContext = serviceContext;
        _application = application;
        _protocols = protocols;

        // Keeping these around for now so progress can be made without updating tests
        _connectionAdapters = adapters;
    }

    /// <summary>
    /// Serves HTTP requests on <paramref name="connectionContext"/>: starts request processing, waits
    /// for the transport's ConnectionClosed signal, notifies the connection, then waits for request
    /// processing to finish. The connection is always removed from the connection manager on exit.
    /// </summary>
    /// <param name="connectionContext">The transport connection to serve.</param>
    /// <returns>A task that completes once request processing for the connection has completed.</returns>
    /// <exception cref="System.InvalidOperationException">
    /// The connection's feature collection lacks an <see cref="IMemoryPoolFeature"/> or an
    /// <see cref="IConnectionLifetimeFeature"/>.
    /// </exception>
    public async Task OnConnectionAsync(ConnectionContext connectionContext)
    {
        // Both features are required further down. Resolve them before anything is registered or
        // started so a transport lacking one fails with a clear message, rather than with a
        // NullReferenceException after request processing is already running unobserved.
        var memoryPoolFeature = GetRequiredFeature<IMemoryPoolFeature>(connectionContext);
        var lifetimeFeature = GetRequiredFeature<IConnectionLifetimeFeature>(connectionContext);

        var httpConnectionId = Interlocked.Increment(ref _lastHttpConnectionId);
        var httpConnectionContext = CreateHttpConnectionContext(connectionContext, httpConnectionId, memoryPoolFeature);

        var connection = new HttpConnection(httpConnectionContext);
        _serviceContext.ConnectionManager.AddConnection(httpConnectionId, connection);

        try
        {
            var processingTask = connection.ProcessRequestsAsync(_application);

            // Completion of either pipe direction is forwarded to the connection.
            connectionContext.Transport.Input.OnWriterCompleted(
                (_, state) => ((HttpConnection)state).OnInputOrOutputCompleted(),
                connection);

            connectionContext.Transport.Output.OnReaderCompleted(
                (_, state) => ((HttpConnection)state).OnInputOrOutputCompleted(),
                connection);

            // Wait for the transport to trip ConnectionClosed, tell the connection, and only then
            // wait for request processing to finish.
            await CancellationTokenAsTask(lifetimeFeature.ConnectionClosed);

            connection.OnConnectionClosed();

            await processingTask;
        }
        finally
        {
            // Untrack the connection even when processing faults.
            _serviceContext.ConnectionManager.RemoveConnection(httpConnectionId);
        }
    }

    // Builds the context handed to HttpConnection. Local/remote endpoints are filled in only when
    // the transport exposes an IHttpConnectionFeature with the corresponding IP address set.
    private HttpConnectionContext CreateHttpConnectionContext(ConnectionContext connectionContext, long httpConnectionId, IMemoryPoolFeature memoryPoolFeature)
    {
        var httpConnectionContext = new HttpConnectionContext
        {
            ConnectionId = connectionContext.ConnectionId,
            ConnectionContext = connectionContext,
            HttpConnectionId = httpConnectionId,
            Protocols = _protocols,
            ServiceContext = _serviceContext,
            ConnectionFeatures = connectionContext.Features,
            MemoryPool = memoryPoolFeature.MemoryPool,
            ConnectionAdapters = _connectionAdapters,
            Transport = connectionContext.Transport
        };

        var connectionFeature = connectionContext.Features.Get<IHttpConnectionFeature>();

        if (connectionFeature != null)
        {
            if (connectionFeature.LocalIpAddress != null)
            {
                httpConnectionContext.LocalEndPoint = new IPEndPoint(connectionFeature.LocalIpAddress, connectionFeature.LocalPort);
            }

            if (connectionFeature.RemoteIpAddress != null)
            {
                httpConnectionContext.RemoteEndPoint = new IPEndPoint(connectionFeature.RemoteIpAddress, connectionFeature.RemotePort);
            }
        }

        return httpConnectionContext;
    }

    // Looks up a feature this middleware cannot run without; throws a descriptive
    // InvalidOperationException when the connection does not provide it.
    private static TFeature GetRequiredFeature<TFeature>(ConnectionContext connectionContext)
        where TFeature : class
    {
        var feature = connectionContext.Features.Get<TFeature>();

        if (feature == null)
        {
            throw new System.InvalidOperationException(
                $"The connection does not provide the required feature '{typeof(TFeature).FullName}'.");
        }

        return feature;
    }

    // Returns a task that completes when the token is cancelled, or an already-completed task if it
    // is cancelled already.
    // NOTE(review): a token that can never be cancelled (e.g. CancellationToken.None) yields a task
    // that never completes.
    private static Task CancellationTokenAsTask(CancellationToken token)
    {
        if (token.IsCancellationRequested)
        {
            return Task.CompletedTask;
        }

        // Transports already dispatch prior to tripping ConnectionClosed
        // since application code can register to this token, so the continuation
        // is deliberately allowed to run inline (no RunContinuationsAsynchronously).
        var tcs = new TaskCompletionSource<object>();
        token.Register(() => tcs.SetResult(null));
        return tcs.Task;
    }
}
|
|
}
|