252 lines
9.6 KiB
C#
252 lines
9.6 KiB
C#
// Copyright (c) .NET Foundation. All rights reserved.
|
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.Diagnostics.CodeAnalysis;
|
|
using System.Diagnostics.Contracts;
|
|
using System.IO;
|
|
using System.Linq;
|
|
using System.Net;
|
|
using System.Net.Http;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.AspNet.Http;
|
|
using Microsoft.AspNet.Http.Features;
|
|
using Microsoft.AspNet.Http.Internal;
|
|
|
|
namespace Microsoft.AspNet.TestHost
|
|
{
|
|
/// <summary>
|
|
/// This adapts HttpRequestMessages to ASP.NET requests, dispatches them through the pipeline, and returns the
|
|
/// associated HttpResponseMessage.
|
|
/// </summary>
|
|
public class ClientHandler : HttpMessageHandler
|
|
{
|
|
private readonly RequestDelegate _next;
|
|
private readonly PathString _pathBase;
|
|
private readonly IHttpContextFactory _factory;
|
|
|
|
/// <summary>
|
|
/// Create a new handler.
|
|
/// </summary>
|
|
/// <param name="next">The pipeline entry point.</param>
|
|
public ClientHandler(RequestDelegate next, PathString pathBase, IHttpContextFactory httpContextFactory)
|
|
{
|
|
if (next == null)
|
|
{
|
|
throw new ArgumentNullException(nameof(next));
|
|
}
|
|
if (httpContextFactory == null)
|
|
{
|
|
throw new ArgumentNullException(nameof(httpContextFactory));
|
|
}
|
|
|
|
_next = next;
|
|
_factory = httpContextFactory;
|
|
|
|
// PathString.StartsWithSegments that we use below requires the base path to not end in a slash.
|
|
if (pathBase.HasValue && pathBase.Value.EndsWith("/"))
|
|
{
|
|
pathBase = new PathString(pathBase.Value.Substring(0, pathBase.Value.Length - 1));
|
|
}
|
|
_pathBase = pathBase;
|
|
}
|
|
|
|
/// <summary>
|
|
/// This adapts HttpRequestMessages to ASP.NET requests, dispatches them through the pipeline, and returns the
|
|
/// associated HttpResponseMessage.
|
|
/// </summary>
|
|
/// <param name="request"></param>
|
|
/// <param name="cancellationToken"></param>
|
|
/// <returns></returns>
|
|
protected override async Task<HttpResponseMessage> SendAsync(
|
|
HttpRequestMessage request,
|
|
CancellationToken cancellationToken)
|
|
{
|
|
if (request == null)
|
|
{
|
|
throw new ArgumentNullException(nameof(request));
|
|
}
|
|
|
|
var state = new RequestState(request, _pathBase, _factory);
|
|
var requestContent = request.Content ?? new StreamContent(Stream.Null);
|
|
var body = await requestContent.ReadAsStreamAsync();
|
|
if (body.CanSeek)
|
|
{
|
|
// This body may have been consumed before, rewind it.
|
|
body.Seek(0, SeekOrigin.Begin);
|
|
}
|
|
state.HttpContext.Request.Body = body;
|
|
var registration = cancellationToken.Register(state.AbortRequest);
|
|
|
|
// Async offload, don't let the test code block the caller.
|
|
var offload = Task.Factory.StartNew(async () =>
|
|
{
|
|
try
|
|
{
|
|
await _next(state.HttpContext);
|
|
state.CompleteResponse();
|
|
}
|
|
catch (Exception ex)
|
|
{
|
|
state.Abort(ex);
|
|
}
|
|
finally
|
|
{
|
|
state.ServerCleanup();
|
|
registration.Dispose();
|
|
}
|
|
});
|
|
|
|
return await state.ResponseTask;
|
|
}
|
|
|
|
private class RequestState
|
|
{
|
|
private readonly HttpRequestMessage _request;
|
|
private TaskCompletionSource<HttpResponseMessage> _responseTcs;
|
|
private ResponseStream _responseStream;
|
|
private ResponseFeature _responseFeature;
|
|
private CancellationTokenSource _requestAbortedSource;
|
|
private IHttpContextFactory _factory;
|
|
private bool _pipelineFinished;
|
|
|
|
internal RequestState(HttpRequestMessage request, PathString pathBase, IHttpContextFactory factory)
|
|
{
|
|
_request = request;
|
|
_responseTcs = new TaskCompletionSource<HttpResponseMessage>();
|
|
_requestAbortedSource = new CancellationTokenSource();
|
|
_pipelineFinished = false;
|
|
_factory = factory;
|
|
|
|
if (request.RequestUri.IsDefaultPort)
|
|
{
|
|
request.Headers.Host = request.RequestUri.Host;
|
|
}
|
|
else
|
|
{
|
|
request.Headers.Host = request.RequestUri.GetComponents(UriComponents.HostAndPort, UriFormat.UriEscaped);
|
|
}
|
|
|
|
HttpContext = _factory.Create(new FeatureCollection());
|
|
|
|
HttpContext.Features.Set<IHttpRequestFeature>(new RequestFeature());
|
|
_responseFeature = new ResponseFeature();
|
|
HttpContext.Features.Set<IHttpResponseFeature>(_responseFeature);
|
|
var serverRequest = HttpContext.Request;
|
|
serverRequest.Protocol = "HTTP/" + request.Version.ToString(2);
|
|
serverRequest.Scheme = request.RequestUri.Scheme;
|
|
serverRequest.Method = request.Method.ToString();
|
|
|
|
var fullPath = PathString.FromUriComponent(request.RequestUri);
|
|
PathString remainder;
|
|
if (fullPath.StartsWithSegments(pathBase, out remainder))
|
|
{
|
|
serverRequest.PathBase = pathBase;
|
|
serverRequest.Path = remainder;
|
|
}
|
|
else
|
|
{
|
|
serverRequest.PathBase = PathString.Empty;
|
|
serverRequest.Path = fullPath;
|
|
}
|
|
|
|
serverRequest.QueryString = QueryString.FromUriComponent(request.RequestUri);
|
|
|
|
foreach (var header in request.Headers)
|
|
{
|
|
serverRequest.Headers.Append(header.Key, header.Value.ToArray());
|
|
}
|
|
var requestContent = request.Content;
|
|
if (requestContent != null)
|
|
{
|
|
foreach (var header in request.Content.Headers)
|
|
{
|
|
serverRequest.Headers.Append(header.Key, header.Value.ToArray());
|
|
}
|
|
}
|
|
|
|
_responseStream = new ResponseStream(ReturnResponseMessage, AbortRequest);
|
|
HttpContext.Response.Body = _responseStream;
|
|
HttpContext.Response.StatusCode = 200;
|
|
HttpContext.RequestAborted = _requestAbortedSource.Token;
|
|
}
|
|
|
|
public HttpContext HttpContext { get; private set; }
|
|
|
|
public Task<HttpResponseMessage> ResponseTask
|
|
{
|
|
get { return _responseTcs.Task; }
|
|
}
|
|
|
|
internal void AbortRequest()
|
|
{
|
|
if (!_pipelineFinished)
|
|
{
|
|
_requestAbortedSource.Cancel();
|
|
}
|
|
_responseStream.Complete();
|
|
}
|
|
|
|
internal void CompleteResponse()
|
|
{
|
|
_pipelineFinished = true;
|
|
ReturnResponseMessage();
|
|
_responseStream.Complete();
|
|
}
|
|
|
|
internal void ReturnResponseMessage()
|
|
{
|
|
if (!_responseTcs.Task.IsCompleted)
|
|
{
|
|
var response = GenerateResponse();
|
|
_responseFeature.FireOnResponseCompleted();
|
|
// Dispatch, as TrySetResult will synchronously execute the waiters callback and block our Write.
|
|
Task.Factory.StartNew(() => _responseTcs.TrySetResult(response));
|
|
}
|
|
}
|
|
|
|
[SuppressMessage("Microsoft.Reliability", "CA2000:DisposeObjectsBeforeLosingScope",
|
|
Justification = "HttpResposneMessage must be returned to the caller.")]
|
|
private HttpResponseMessage GenerateResponse()
|
|
{
|
|
_responseFeature.FireOnSendingHeaders();
|
|
|
|
var response = new HttpResponseMessage();
|
|
response.StatusCode = (HttpStatusCode)HttpContext.Response.StatusCode;
|
|
response.ReasonPhrase = HttpContext.Features.Get<IHttpResponseFeature>().ReasonPhrase;
|
|
response.RequestMessage = _request;
|
|
// response.Version = owinResponse.Protocol;
|
|
|
|
response.Content = new StreamContent(_responseStream);
|
|
|
|
foreach (var header in HttpContext.Response.Headers)
|
|
{
|
|
if (!response.Headers.TryAddWithoutValidation(header.Key, (IEnumerable<string>)header.Value))
|
|
{
|
|
bool success = response.Content.Headers.TryAddWithoutValidation(header.Key, (IEnumerable<string>)header.Value);
|
|
Contract.Assert(success, "Bad header");
|
|
}
|
|
}
|
|
return response;
|
|
}
|
|
|
|
internal void Abort(Exception exception)
|
|
{
|
|
_pipelineFinished = true;
|
|
_responseStream.Abort(exception);
|
|
_responseTcs.TrySetException(exception);
|
|
}
|
|
|
|
internal void ServerCleanup()
|
|
{
|
|
if (HttpContext != null)
|
|
{
|
|
_factory.Dispose(HttpContext);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|