#216 Lazy feature initialization

This commit is contained in:
Chris R 2016-08-19 13:57:42 -07:00
parent 9f1476aea8
commit efef52a0ad
2 changed files with 277 additions and 157 deletions

View File

@ -19,7 +19,6 @@ using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.WebSockets;
using System.Security.Claims;
@ -63,17 +62,19 @@ namespace Microsoft.AspNetCore.Server.WebListener
private string _rawTarget;
private IPAddress _remoteIpAddress;
private IPAddress _localIpAddress;
private int? _remotePort;
private int? _localPort;
private int _remotePort;
private int _localPort;
private string _connectionId;
private string _requestId;
private string _traceIdentitfier;
private X509Certificate2 _clientCert;
private ClaimsPrincipal _user;
private IAuthenticationHandler _authHandler;
private CancellationToken? _disconnectToken;
private CancellationToken _disconnectToken;
private Stream _responseStream;
private IHeaderDictionary _responseHeaders;
private Fields _initializedFields;
private List<Tuple<Func<object, Task>, object>> _onStartingActions = new List<Tuple<Func<object, Task>, object>>();
private List<Tuple<Func<object, Task>, object>> _onCompletedActions = new List<Tuple<Func<object, Task>, object>>();
private bool _responseStarted;
@ -85,91 +86,95 @@ namespace Microsoft.AspNetCore.Server.WebListener
_features = new FeatureCollection(new StandardFeatureCollection(this));
_authHandler = new AuthenticationHandler(requestContext);
_enableResponseCaching = enableResponseCaching;
// Pre-initialize any fields that are not lazy at the lower level.
_requestHeaders = new HeaderDictionary(Request.Headers);
_httpMethod = Request.Method;
_path = Request.Path;
_pathBase = Request.PathBase;
_query = Request.QueryString;
_rawTarget = Request.RawUrl;
_scheme = Request.Scheme;
_user = _requestContext.User;
_responseStream = new ResponseStream(requestContext.Response.Body, OnStart);
_responseHeaders = new HeaderDictionary(Response.Headers);
}
internal IFeatureCollection Features
internal IFeatureCollection Features => _features;
internal object RequestContext => _requestContext;
private Request Request => _requestContext.Request;
private Response Response => _requestContext.Response;
[Flags]
// Fields that may be lazy-initialized
private enum Fields
{
get { return _features; }
None = 0x0,
Protocol = 0x1,
RequestBody = 0x2,
RequestAborted = 0x4,
LocalIpAddress = 0x8,
RemoteIpAddress = 0x10,
LocalPort = 0x20,
RemotePort = 0x40,
ConnectionId = 0x80,
ClientCertificate = 0x100,
TraceIdentifier = 0x200,
}
internal object RequestContext
private bool IsNotInitialized(Fields field)
{
get { return _requestContext; }
return (_initializedFields & field) != field;
}
private Request Request
private void SetInitialized(Fields field)
{
get { return _requestContext.Request; }
}
private Response Response
{
get { return _requestContext.Response; }
_initializedFields |= field;
}
Stream IHttpRequestFeature.Body
{
get
{
if (_requestBody == null)
if (IsNotInitialized(Fields.RequestBody))
{
_requestBody = Request.Body;
SetInitialized(Fields.RequestBody);
}
return _requestBody;
}
set { _requestBody = value; }
set
{
_requestBody = value;
SetInitialized(Fields.RequestBody);
}
}
IHeaderDictionary IHttpRequestFeature.Headers
{
get
{
if (_requestHeaders == null)
{
_requestHeaders = new HeaderDictionary(Request.Headers);
}
return _requestHeaders;
}
get { return _requestHeaders; }
set { _requestHeaders = value; }
}
string IHttpRequestFeature.Method
{
get
{
if (_httpMethod == null)
{
_httpMethod = Request.Method;
}
return _httpMethod;
}
get { return _httpMethod; }
set { _httpMethod = value; }
}
string IHttpRequestFeature.Path
{
get
{
if (_path == null)
{
_path = Request.Path;
}
return _path;
}
get { return _path; }
set { _path = value; }
}
string IHttpRequestFeature.PathBase
{
get
{
if (_pathBase == null)
{
_pathBase = Request.PathBase;
}
return _pathBase;
}
get { return _pathBase; }
set { _pathBase = value; }
}
@ -177,63 +182,47 @@ namespace Microsoft.AspNetCore.Server.WebListener
{
get
{
if (_httpProtocolVersion == null)
if (IsNotInitialized(Fields.Protocol))
{
if (Request.ProtocolVersion.Major == 1)
var protocol = Request.ProtocolVersion;
if (protocol.Major == 1 && protocol.Minor == 1)
{
if (Request.ProtocolVersion.Minor == 1)
{
_httpProtocolVersion = "HTTP/1.1";
}
else if (Request.ProtocolVersion.Minor == 0)
{
_httpProtocolVersion = "HTTP/1.0";
}
_httpProtocolVersion = "HTTP/1.1";
}
_httpProtocolVersion = "HTTP/" + Request.ProtocolVersion.ToString(2);
else if (protocol.Major == 1 && protocol.Minor == 0)
{
_httpProtocolVersion = "HTTP/1.0";
}
else
{
_httpProtocolVersion = "HTTP/" + protocol.ToString(2);
}
SetInitialized(Fields.Protocol);
}
return _httpProtocolVersion;
}
set { _httpProtocolVersion = value; }
set
{
_httpProtocolVersion = value;
SetInitialized(Fields.Protocol);
}
}
string IHttpRequestFeature.QueryString
{
get
{
if (_query == null)
{
_query = Request.QueryString;
}
return _query;
}
get { return _query; }
set { _query = value; }
}
string IHttpRequestFeature.RawTarget
{
get
{
if (_rawTarget == null)
{
_rawTarget = Request.RawUrl;
}
return _rawTarget;
}
get { return _rawTarget; }
set { _rawTarget = value; }
}
string IHttpRequestFeature.Scheme
{
get
{
if (_scheme == null)
{
_scheme = Request.Scheme;
}
return _scheme;
}
get { return _scheme; }
set { _scheme = value; }
}
@ -241,85 +230,116 @@ namespace Microsoft.AspNetCore.Server.WebListener
{
get
{
if (_localIpAddress == null)
if (IsNotInitialized(Fields.LocalIpAddress))
{
_localIpAddress = Request.LocalIpAddress;
SetInitialized(Fields.LocalIpAddress);
}
return _localIpAddress;
}
set { _localIpAddress = value; }
set
{
_localIpAddress = value;
SetInitialized(Fields.LocalIpAddress);
}
}
IPAddress IHttpConnectionFeature.RemoteIpAddress
{
get
{
if (_remoteIpAddress == null)
if (IsNotInitialized(Fields.RemoteIpAddress))
{
_remoteIpAddress = Request.RemoteIpAddress;
SetInitialized(Fields.RemoteIpAddress);
}
return _remoteIpAddress;
}
set { _remoteIpAddress = value; }
set
{
_remoteIpAddress = value;
SetInitialized(Fields.RemoteIpAddress);
}
}
int IHttpConnectionFeature.LocalPort
{
get
{
if (_localPort == null)
if (IsNotInitialized(Fields.LocalPort))
{
_localPort = Request.LocalPort;
SetInitialized(Fields.LocalPort);
}
return _localPort.Value;
return _localPort;
}
set
{
_localPort = value;
SetInitialized(Fields.LocalPort);
}
set { _localPort = value; }
}
int IHttpConnectionFeature.RemotePort
{
get
{
if (_remotePort == null)
if (IsNotInitialized(Fields.RemotePort))
{
_remotePort = Request.RemotePort;
SetInitialized(Fields.RemotePort);
}
return _remotePort.Value;
return _remotePort;
}
set
{
_remotePort = value;
SetInitialized(Fields.RemotePort);
}
set { _remotePort = value; }
}
string IHttpConnectionFeature.ConnectionId
{
get
{
if (_connectionId == null)
if (IsNotInitialized(Fields.ConnectionId))
{
_connectionId = Request.ConnectionId.ToString(CultureInfo.InvariantCulture);
SetInitialized(Fields.ConnectionId);
}
return _connectionId;
}
set { _connectionId = value; }
set
{
_connectionId = value;
SetInitialized(Fields.ConnectionId);
}
}
X509Certificate2 ITlsConnectionFeature.ClientCertificate
{
get
{
if (_clientCert == null)
if (IsNotInitialized(Fields.ClientCertificate))
{
_clientCert = Request.GetClientCertificateAsync().Result; // TODO: Sync;
SetInitialized(Fields.ClientCertificate);
}
return _clientCert;
}
set { _clientCert = value; }
set
{
_clientCert = value;
SetInitialized(Fields.ClientCertificate);
}
}
async Task<X509Certificate2> ITlsConnectionFeature.GetClientCertificateAsync(CancellationToken cancellationToken)
{
if (_clientCert == null)
if (IsNotInitialized(Fields.ClientCertificate))
{
_clientCert = await Request.GetClientCertificateAsync(cancellationToken);
SetInitialized(Fields.ClientCertificate);
}
return _clientCert;
}
@ -329,15 +349,9 @@ namespace Microsoft.AspNetCore.Server.WebListener
return Request.IsHttps ? this : null;
}
/* TODO: https://github.com/aspnet/WebListener/issues/231
byte[] ITlsTokenBindingFeature.GetProvidedTokenBindingId()
{
return Request.GetProvidedTokenBindingId();
}
byte[] ITlsTokenBindingFeature.GetProvidedTokenBindingId() => Request.GetProvidedTokenBindingId();
byte[] ITlsTokenBindingFeature.GetReferredTokenBindingId()
{
return Request.GetReferredTokenBindingId();
}
byte[] ITlsTokenBindingFeature.GetReferredTokenBindingId() => Request.GetReferredTokenBindingId();
internal ITlsTokenBindingFeature GetTlsTokenBindingFeature()
{
@ -356,34 +370,17 @@ namespace Microsoft.AspNetCore.Server.WebListener
Stream IHttpResponseFeature.Body
{
get
{
if (_responseStream == null)
{
_responseStream = Response.Body;
}
return _responseStream;
}
get { return _responseStream; }
set { _responseStream = value; }
}
IHeaderDictionary IHttpResponseFeature.Headers
{
get
{
if (_responseHeaders == null)
{
_responseHeaders = new HeaderDictionary(Response.Headers);
}
return _responseHeaders;
}
get { return _responseHeaders; }
set { _responseHeaders = value; }
}
bool IHttpResponseFeature.HasStarted
{
get { return Response.HasStarted; }
}
bool IHttpResponseFeature.HasStarted => Response.HasStarted;
void IHttpResponseFeature.OnStarting(Func<object, Task> callback, object state)
{
@ -435,24 +432,23 @@ namespace Microsoft.AspNetCore.Server.WebListener
{
get
{
if (!_disconnectToken.HasValue)
if (IsNotInitialized(Fields.RequestAborted))
{
_disconnectToken = _requestContext.DisconnectToken;
SetInitialized(Fields.RequestAborted);
}
return _disconnectToken.Value;
return _disconnectToken;
}
set
{
_disconnectToken = value;
SetInitialized(Fields.RequestAborted);
}
set { _disconnectToken = value; }
}
void IHttpRequestLifetimeFeature.Abort()
{
_requestContext.Abort();
}
void IHttpRequestLifetimeFeature.Abort() => _requestContext.Abort();
bool IHttpUpgradeFeature.IsUpgradableRequest
{
get { return _requestContext.IsUpgradableRequest; }
}
bool IHttpUpgradeFeature.IsUpgradableRequest => _requestContext.IsUpgradableRequest;
async Task<Stream> IHttpUpgradeFeature.UpgradeAsync()
{
@ -477,14 +473,7 @@ namespace Microsoft.AspNetCore.Server.WebListener
ClaimsPrincipal IHttpAuthenticationFeature.User
{
get
{
if (_user == null)
{
_user = _requestContext.User;
}
return _user;
}
get { return _user; }
set { _user = value; }
}
@ -498,13 +487,18 @@ namespace Microsoft.AspNetCore.Server.WebListener
{
get
{
if (_requestId == null)
if (IsNotInitialized(Fields.TraceIdentifier))
{
_requestId = _requestContext.TraceIdentifier.ToString();
_traceIdentitfier = _requestContext.TraceIdentifier.ToString();
SetInitialized(Fields.TraceIdentifier);
}
return _requestId;
return _traceIdentitfier;
}
set
{
_traceIdentitfier = value;
SetInitialized(Fields.TraceIdentifier);
}
set { _requestId = value; }
}
internal async Task OnStart()

View File

@ -17,6 +17,7 @@
using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Text;
@ -34,7 +35,7 @@ namespace Microsoft.AspNetCore.Server.WebListener
public class RequestTests
{
[Fact]
public async Task Request_SimpleGet_Success()
public async Task Request_SimpleGet_ExpectedFieldsSet()
{
string root;
using (Utilities.CreateHttpServerReturnRoot("/basepath", out root, httpContext =>
@ -54,14 +55,12 @@ namespace Microsoft.AspNetCore.Server.WebListener
Assert.Equal("/basepath/SomePath?SomeQuery", requestInfo.RawTarget);
Assert.Equal("HTTP/1.1", requestInfo.Protocol);
// Server Keys
// TODO: Assert.NotNull(httpContext.Get<IDictionary<string, object>>("server.Capabilities"));
var connectionInfo = httpContext.Features.Get<IHttpConnectionFeature>();
Assert.Equal("::1", connectionInfo.RemoteIpAddress.ToString());
Assert.NotEqual(0, connectionInfo.RemotePort);
Assert.Equal("::1", connectionInfo.LocalIpAddress.ToString());
Assert.NotEqual(0, connectionInfo.LocalPort);
Assert.NotNull(connectionInfo.ConnectionId);
// Trace identifier
var requestIdentifierFeature = httpContext.Features.Get<IHttpRequestIdentifierFeature>();
@ -83,6 +82,133 @@ namespace Microsoft.AspNetCore.Server.WebListener
}
}
[Fact]
public async Task Request_FieldsCanBeSet_Set()
{
string root;
using (Utilities.CreateHttpServerReturnRoot("/basepath", out root, httpContext =>
{
try
{
var requestInfo = httpContext.Features.Get<IHttpRequestFeature>();
// Request Keys
requestInfo.Method = "TEST";
Assert.Equal("TEST", requestInfo.Method);
requestInfo.Body = new MemoryStream();
Assert.IsType<MemoryStream>(requestInfo.Body);
var customHeaders = new HeaderDictionary(new HeaderCollection());
requestInfo.Headers = customHeaders;
Assert.Same(customHeaders, requestInfo.Headers);
requestInfo.Scheme = "abcd";
Assert.Equal("abcd", requestInfo.Scheme);
requestInfo.PathBase = "/customized/Base";
Assert.Equal("/customized/Base", requestInfo.PathBase);
requestInfo.Path = "/customized/Path";
Assert.Equal("/customized/Path", requestInfo.Path);
requestInfo.QueryString = "?customizedQuery";
Assert.Equal("?customizedQuery", requestInfo.QueryString);
requestInfo.RawTarget = "/customized/raw?Target";
Assert.Equal("/customized/raw?Target", requestInfo.RawTarget);
requestInfo.Protocol = "Custom/2.0";
Assert.Equal("Custom/2.0", requestInfo.Protocol);
var connectionInfo = httpContext.Features.Get<IHttpConnectionFeature>();
connectionInfo.RemoteIpAddress = IPAddress.Broadcast;
Assert.Equal(IPAddress.Broadcast, connectionInfo.RemoteIpAddress);
connectionInfo.RemotePort = 12345;
Assert.Equal(12345, connectionInfo.RemotePort);
connectionInfo.LocalIpAddress = IPAddress.Any;
Assert.Equal(IPAddress.Any, connectionInfo.LocalIpAddress);
connectionInfo.LocalPort = 54321;
Assert.Equal(54321, connectionInfo.LocalPort);
connectionInfo.ConnectionId = "CustomId";
Assert.Equal("CustomId", connectionInfo.ConnectionId);
// Trace identifier
var requestIdentifierFeature = httpContext.Features.Get<IHttpRequestIdentifierFeature>();
Assert.NotNull(requestIdentifierFeature);
requestIdentifierFeature.TraceIdentifier = "customTrace";
Assert.Equal("customTrace", requestIdentifierFeature.TraceIdentifier);
// Note: Response keys are validated in the ResponseTests
}
catch (Exception ex)
{
byte[] body = Encoding.ASCII.GetBytes(ex.ToString());
httpContext.Response.Body.Write(body, 0, body.Length);
}
return Task.FromResult(0);
}))
{
string response = await SendRequestAsync(root + "/basepath/SomePath?SomeQuery");
Assert.Equal(string.Empty, response);
}
}
[Fact]
public async Task Request_FieldsCanBeSetToNull_Set()
{
string root;
using (Utilities.CreateHttpServerReturnRoot("/basepath", out root, httpContext =>
{
try
{
var requestInfo = httpContext.Features.Get<IHttpRequestFeature>();
// Request Keys
requestInfo.Method = null;
Assert.Null(requestInfo.Method);
requestInfo.Body = null;
Assert.Null(requestInfo.Body);
requestInfo.Headers = null;
Assert.Null(requestInfo.Headers);
requestInfo.Scheme = null;
Assert.Null(requestInfo.Scheme);
requestInfo.PathBase = null;
Assert.Null(requestInfo.PathBase);
requestInfo.Path = null;
Assert.Null(requestInfo.Path);
requestInfo.QueryString = null;
Assert.Null(requestInfo.QueryString);
requestInfo.RawTarget = null;
Assert.Null(requestInfo.RawTarget);
requestInfo.Protocol = null;
Assert.Null(requestInfo.Protocol);
var connectionInfo = httpContext.Features.Get<IHttpConnectionFeature>();
connectionInfo.RemoteIpAddress = null;
Assert.Null(connectionInfo.RemoteIpAddress);
connectionInfo.RemotePort = -1;
Assert.Equal(-1, connectionInfo.RemotePort);
connectionInfo.LocalIpAddress = null;
Assert.Null(connectionInfo.LocalIpAddress);
connectionInfo.LocalPort = -1;
Assert.Equal(-1, connectionInfo.LocalPort);
connectionInfo.ConnectionId = null;
Assert.Null(connectionInfo.ConnectionId);
// Trace identifier
var requestIdentifierFeature = httpContext.Features.Get<IHttpRequestIdentifierFeature>();
Assert.NotNull(requestIdentifierFeature);
requestIdentifierFeature.TraceIdentifier = null;
Assert.Null(requestIdentifierFeature.TraceIdentifier);
// Note: Response keys are validated in the ResponseTests
}
catch (Exception ex)
{
byte[] body = Encoding.ASCII.GetBytes(ex.ToString());
httpContext.Response.Body.Write(body, 0, body.Length);
}
return Task.FromResult(0);
}))
{
string response = await SendRequestAsync(root + "/basepath/SomePath?SomeQuery");
Assert.Equal(string.Empty, response);
}
}
[Theory]
[InlineData("/", "/", "", "/")]
[InlineData("/basepath/", "/basepath", "/basepath", "")]