diff --git a/src/Microsoft.AspNetCore.Server.WebListener/FeatureContext.cs b/src/Microsoft.AspNetCore.Server.WebListener/FeatureContext.cs index b443d1496a..223cf3a9ac 100644 --- a/src/Microsoft.AspNetCore.Server.WebListener/FeatureContext.cs +++ b/src/Microsoft.AspNetCore.Server.WebListener/FeatureContext.cs @@ -19,7 +19,6 @@ using System; using System.Collections.Generic; using System.Globalization; using System.IO; -using System.Linq; using System.Net; using System.Net.WebSockets; using System.Security.Claims; @@ -63,17 +62,19 @@ namespace Microsoft.AspNetCore.Server.WebListener private string _rawTarget; private IPAddress _remoteIpAddress; private IPAddress _localIpAddress; - private int? _remotePort; - private int? _localPort; + private int _remotePort; + private int _localPort; private string _connectionId; - private string _requestId; + private string _traceIdentitfier; private X509Certificate2 _clientCert; private ClaimsPrincipal _user; private IAuthenticationHandler _authHandler; - private CancellationToken? _disconnectToken; + private CancellationToken _disconnectToken; private Stream _responseStream; private IHeaderDictionary _responseHeaders; + private Fields _initializedFields; + private List, object>> _onStartingActions = new List, object>>(); private List, object>> _onCompletedActions = new List, object>>(); private bool _responseStarted; @@ -85,91 +86,95 @@ namespace Microsoft.AspNetCore.Server.WebListener _features = new FeatureCollection(new StandardFeatureCollection(this)); _authHandler = new AuthenticationHandler(requestContext); _enableResponseCaching = enableResponseCaching; + + // Pre-initialize any fields that are not lazy at the lower level. + _requestHeaders = new HeaderDictionary(Request.Headers); + _httpMethod = Request.Method; + _path = Request.Path; + _pathBase = Request.PathBase; + _query = Request.QueryString; + _rawTarget = Request.RawUrl; + _scheme = Request.Scheme; + _user = _requestContext.User; + _responseStream = new ResponseStream(requestContext.Response.Body, OnStart); + _responseHeaders = new HeaderDictionary(Response.Headers); } - internal IFeatureCollection Features + internal IFeatureCollection Features => _features; + + internal object RequestContext => _requestContext; + + private Request Request => _requestContext.Request; + + private Response Response => _requestContext.Response; + + [Flags] + // Fields that may be lazy-initialized + private enum Fields { - get { return _features; } + None = 0x0, + Protocol = 0x1, + RequestBody = 0x2, + RequestAborted = 0x4, + LocalIpAddress = 0x8, + RemoteIpAddress = 0x10, + LocalPort = 0x20, + RemotePort = 0x40, + ConnectionId = 0x80, + ClientCertificate = 0x100, + TraceIdentifier = 0x200, } - internal object RequestContext + private bool IsNotInitialized(Fields field) { - get { return _requestContext; } + return (_initializedFields & field) != field; } - private Request Request + private void SetInitialized(Fields field) { - get { return _requestContext.Request; } - } - - private Response Response - { - get { return _requestContext.Response; } + _initializedFields |= field; } Stream IHttpRequestFeature.Body { get { - if (_requestBody == null) + if (IsNotInitialized(Fields.RequestBody)) { _requestBody = Request.Body; + SetInitialized(Fields.RequestBody); } return _requestBody; } - set { _requestBody = value; } + set + { + _requestBody = value; + SetInitialized(Fields.RequestBody); + } } IHeaderDictionary IHttpRequestFeature.Headers { - get - { - if (_requestHeaders == null) - { - _requestHeaders = new HeaderDictionary(Request.Headers); - } - return _requestHeaders; - } + get { return _requestHeaders; } set { _requestHeaders = value; } } string IHttpRequestFeature.Method { - get - { - if (_httpMethod == null) - { - _httpMethod = Request.Method; - } - return _httpMethod; - } + get { return _httpMethod; } set { _httpMethod = value; } } string IHttpRequestFeature.Path { - get - { - if (_path == null) - { - _path = Request.Path; - } - return _path; - } + get { return _path; } set { _path = value; } } string IHttpRequestFeature.PathBase { - get - { - if (_pathBase == null) - { - _pathBase = Request.PathBase; - } - return _pathBase; - } + get { return _pathBase; } set { _pathBase = value; } } @@ -177,63 +182,47 @@ namespace Microsoft.AspNetCore.Server.WebListener { get { - if (_httpProtocolVersion == null) + if (IsNotInitialized(Fields.Protocol)) { - if (Request.ProtocolVersion.Major == 1) + var protocol = Request.ProtocolVersion; + if (protocol.Major == 1 && protocol.Minor == 1) { - if (Request.ProtocolVersion.Minor == 1) - { - _httpProtocolVersion = "HTTP/1.1"; - } - else if (Request.ProtocolVersion.Minor == 0) - { - _httpProtocolVersion = "HTTP/1.0"; - } + _httpProtocolVersion = "HTTP/1.1"; } - - _httpProtocolVersion = "HTTP/" + Request.ProtocolVersion.ToString(2); + else if (protocol.Major == 1 && protocol.Minor == 0) + { + _httpProtocolVersion = "HTTP/1.0"; + } + else + { + _httpProtocolVersion = "HTTP/" + protocol.ToString(2); + } + SetInitialized(Fields.Protocol); } return _httpProtocolVersion; } - set { _httpProtocolVersion = value; } + set + { + _httpProtocolVersion = value; + SetInitialized(Fields.Protocol); + } } string IHttpRequestFeature.QueryString { - get - { - if (_query == null) - { - _query = Request.QueryString; - } - return _query; - } + get { return _query; } set { _query = value; } } string IHttpRequestFeature.RawTarget { - get - { - if (_rawTarget == null) - { - _rawTarget = Request.RawUrl; - } - return _rawTarget; - } + get { return _rawTarget; } set { _rawTarget = value; } } string IHttpRequestFeature.Scheme { - get - { - if (_scheme == null) - { - _scheme = Request.Scheme; - } - return _scheme; - } + get { return _scheme; } set { _scheme = value; } } @@ -241,85 +230,116 @@ namespace Microsoft.AspNetCore.Server.WebListener { get { - if (_localIpAddress == null) + if (IsNotInitialized(Fields.LocalIpAddress)) { _localIpAddress = Request.LocalIpAddress; + SetInitialized(Fields.LocalIpAddress); } return _localIpAddress; } - set { _localIpAddress = value; } + set + { + _localIpAddress = value; + SetInitialized(Fields.LocalIpAddress); + } } IPAddress IHttpConnectionFeature.RemoteIpAddress { get { - if (_remoteIpAddress == null) + if (IsNotInitialized(Fields.RemoteIpAddress)) { _remoteIpAddress = Request.RemoteIpAddress; + SetInitialized(Fields.RemoteIpAddress); } return _remoteIpAddress; } - set { _remoteIpAddress = value; } + set + { + _remoteIpAddress = value; + SetInitialized(Fields.RemoteIpAddress); + } } int IHttpConnectionFeature.LocalPort { get { - if (_localPort == null) + if (IsNotInitialized(Fields.LocalPort)) { _localPort = Request.LocalPort; + SetInitialized(Fields.LocalPort); } - return _localPort.Value; + return _localPort; + } + set + { + _localPort = value; + SetInitialized(Fields.LocalPort); } - set { _localPort = value; } } int IHttpConnectionFeature.RemotePort { get { - if (_remotePort == null) + if (IsNotInitialized(Fields.RemotePort)) { _remotePort = Request.RemotePort; + SetInitialized(Fields.RemotePort); } - return _remotePort.Value; + return _remotePort; + } + set + { + _remotePort = value; + SetInitialized(Fields.RemotePort); } - set { _remotePort = value; } } string IHttpConnectionFeature.ConnectionId { get { - if (_connectionId == null) + if (IsNotInitialized(Fields.ConnectionId)) { _connectionId = Request.ConnectionId.ToString(CultureInfo.InvariantCulture); + SetInitialized(Fields.ConnectionId); } return _connectionId; } - set { _connectionId = value; } + set + { + _connectionId = value; + SetInitialized(Fields.ConnectionId); + } } X509Certificate2 ITlsConnectionFeature.ClientCertificate { get { - if (_clientCert == null) + if (IsNotInitialized(Fields.ClientCertificate)) { _clientCert = Request.GetClientCertificateAsync().Result; // TODO: Sync; + SetInitialized(Fields.ClientCertificate); } return _clientCert; } - set { _clientCert = value; } + set + { + _clientCert = value; + SetInitialized(Fields.ClientCertificate); + } } async Task ITlsConnectionFeature.GetClientCertificateAsync(CancellationToken cancellationToken) { - if (_clientCert == null) + if (IsNotInitialized(Fields.ClientCertificate)) { _clientCert = await Request.GetClientCertificateAsync(cancellationToken); + SetInitialized(Fields.ClientCertificate); } return _clientCert; } @@ -329,15 +349,9 @@ namespace Microsoft.AspNetCore.Server.WebListener return Request.IsHttps ? this : null; } /* TODO: https://github.com/aspnet/WebListener/issues/231 - byte[] ITlsTokenBindingFeature.GetProvidedTokenBindingId() - { - return Request.GetProvidedTokenBindingId(); - } + byte[] ITlsTokenBindingFeature.GetProvidedTokenBindingId() => Request.GetProvidedTokenBindingId(); - byte[] ITlsTokenBindingFeature.GetReferredTokenBindingId() - { - return Request.GetReferredTokenBindingId(); - } + byte[] ITlsTokenBindingFeature.GetReferredTokenBindingId() => Request.GetReferredTokenBindingId(); internal ITlsTokenBindingFeature GetTlsTokenBindingFeature() { @@ -356,34 +370,17 @@ namespace Microsoft.AspNetCore.Server.WebListener Stream IHttpResponseFeature.Body { - get - { - if (_responseStream == null) - { - _responseStream = Response.Body; - } - return _responseStream; - } + get { return _responseStream; } set { _responseStream = value; } } IHeaderDictionary IHttpResponseFeature.Headers { - get - { - if (_responseHeaders == null) - { - _responseHeaders = new HeaderDictionary(Response.Headers); - } - return _responseHeaders; - } + get { return _responseHeaders; } set { _responseHeaders = value; } } - bool IHttpResponseFeature.HasStarted - { - get { return Response.HasStarted; } - } + bool IHttpResponseFeature.HasStarted => Response.HasStarted; void IHttpResponseFeature.OnStarting(Func callback, object state) { @@ -435,24 +432,23 @@ namespace Microsoft.AspNetCore.Server.WebListener { get { - if (!_disconnectToken.HasValue) + if (IsNotInitialized(Fields.RequestAborted)) { _disconnectToken = _requestContext.DisconnectToken; + SetInitialized(Fields.RequestAborted); } - return _disconnectToken.Value; + return _disconnectToken; + } + set + { + _disconnectToken = value; + SetInitialized(Fields.RequestAborted); } - set { _disconnectToken = value; } } - void IHttpRequestLifetimeFeature.Abort() - { - _requestContext.Abort(); - } + void IHttpRequestLifetimeFeature.Abort() => _requestContext.Abort(); - bool IHttpUpgradeFeature.IsUpgradableRequest - { - get { return _requestContext.IsUpgradableRequest; } - } + bool IHttpUpgradeFeature.IsUpgradableRequest => _requestContext.IsUpgradableRequest; async Task IHttpUpgradeFeature.UpgradeAsync() { @@ -477,14 +473,7 @@ namespace Microsoft.AspNetCore.Server.WebListener ClaimsPrincipal IHttpAuthenticationFeature.User { - get - { - if (_user == null) - { - _user = _requestContext.User; - } - return _user; - } + get { return _user; } set { _user = value; } } @@ -498,13 +487,18 @@ namespace Microsoft.AspNetCore.Server.WebListener { get { - if (_requestId == null) + if (IsNotInitialized(Fields.TraceIdentifier)) { - _requestId = _requestContext.TraceIdentifier.ToString(); + _traceIdentitfier = _requestContext.TraceIdentifier.ToString(); + SetInitialized(Fields.TraceIdentifier); } - return _requestId; + return _traceIdentitfier; + } + set + { + _traceIdentitfier = value; + SetInitialized(Fields.TraceIdentifier); } - set { _requestId = value; } } internal async Task OnStart() diff --git a/test/Microsoft.AspNetCore.Server.WebListener.FunctionalTests/RequestTests.cs b/test/Microsoft.AspNetCore.Server.WebListener.FunctionalTests/RequestTests.cs index 62b1f18d4f..ce7db01208 100644 --- a/test/Microsoft.AspNetCore.Server.WebListener.FunctionalTests/RequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.WebListener.FunctionalTests/RequestTests.cs @@ -17,6 +17,7 @@ using System; using System.IO; +using System.Net; using System.Net.Http; using System.Net.Sockets; using System.Text; @@ -34,7 +35,7 @@ namespace Microsoft.AspNetCore.Server.WebListener public class RequestTests { [Fact] - public async Task Request_SimpleGet_Success() + public async Task Request_SimpleGet_ExpectedFieldsSet() { string root; using (Utilities.CreateHttpServerReturnRoot("/basepath", out root, httpContext => @@ -54,14 +55,12 @@ namespace Microsoft.AspNetCore.Server.WebListener Assert.Equal("/basepath/SomePath?SomeQuery", requestInfo.RawTarget); Assert.Equal("HTTP/1.1", requestInfo.Protocol); - // Server Keys - // TODO: Assert.NotNull(httpContext.Get>("server.Capabilities")); - var connectionInfo = httpContext.Features.Get(); Assert.Equal("::1", connectionInfo.RemoteIpAddress.ToString()); Assert.NotEqual(0, connectionInfo.RemotePort); Assert.Equal("::1", connectionInfo.LocalIpAddress.ToString()); Assert.NotEqual(0, connectionInfo.LocalPort); + Assert.NotNull(connectionInfo.ConnectionId); // Trace identifier var requestIdentifierFeature = httpContext.Features.Get(); @@ -83,6 +82,133 @@ namespace Microsoft.AspNetCore.Server.WebListener } } + [Fact] + public async Task Request_FieldsCanBeSet_Set() + { + string root; + using (Utilities.CreateHttpServerReturnRoot("/basepath", out root, httpContext => + { + try + { + var requestInfo = httpContext.Features.Get(); + + // Request Keys + requestInfo.Method = "TEST"; + Assert.Equal("TEST", requestInfo.Method); + requestInfo.Body = new MemoryStream(); + Assert.IsType(requestInfo.Body); + var customHeaders = new HeaderDictionary(new HeaderCollection()); + requestInfo.Headers = customHeaders; + Assert.Same(customHeaders, requestInfo.Headers); + requestInfo.Scheme = "abcd"; + Assert.Equal("abcd", requestInfo.Scheme); + requestInfo.PathBase = "/customized/Base"; + Assert.Equal("/customized/Base", requestInfo.PathBase); + requestInfo.Path = "/customized/Path"; + Assert.Equal("/customized/Path", requestInfo.Path); + requestInfo.QueryString = "?customizedQuery"; + Assert.Equal("?customizedQuery", requestInfo.QueryString); + requestInfo.RawTarget = "/customized/raw?Target"; + Assert.Equal("/customized/raw?Target", requestInfo.RawTarget); + requestInfo.Protocol = "Custom/2.0"; + Assert.Equal("Custom/2.0", requestInfo.Protocol); + + var connectionInfo = httpContext.Features.Get(); + connectionInfo.RemoteIpAddress = IPAddress.Broadcast; + Assert.Equal(IPAddress.Broadcast, connectionInfo.RemoteIpAddress); + connectionInfo.RemotePort = 12345; + Assert.Equal(12345, connectionInfo.RemotePort); + connectionInfo.LocalIpAddress = IPAddress.Any; + Assert.Equal(IPAddress.Any, connectionInfo.LocalIpAddress); + connectionInfo.LocalPort = 54321; + Assert.Equal(54321, connectionInfo.LocalPort); + connectionInfo.ConnectionId = "CustomId"; + Assert.Equal("CustomId", connectionInfo.ConnectionId); + + // Trace identifier + var requestIdentifierFeature = httpContext.Features.Get(); + Assert.NotNull(requestIdentifierFeature); + requestIdentifierFeature.TraceIdentifier = "customTrace"; + Assert.Equal("customTrace", requestIdentifierFeature.TraceIdentifier); + + // Note: Response keys are validated in the ResponseTests + } + catch (Exception ex) + { + byte[] body = Encoding.ASCII.GetBytes(ex.ToString()); + httpContext.Response.Body.Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(root + "/basepath/SomePath?SomeQuery"); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task Request_FieldsCanBeSetToNull_Set() + { + string root; + using (Utilities.CreateHttpServerReturnRoot("/basepath", out root, httpContext => + { + try + { + var requestInfo = httpContext.Features.Get(); + + // Request Keys + requestInfo.Method = null; + Assert.Null(requestInfo.Method); + requestInfo.Body = null; + Assert.Null(requestInfo.Body); + requestInfo.Headers = null; + Assert.Null(requestInfo.Headers); + requestInfo.Scheme = null; + Assert.Null(requestInfo.Scheme); + requestInfo.PathBase = null; + Assert.Null(requestInfo.PathBase); + requestInfo.Path = null; + Assert.Null(requestInfo.Path); + requestInfo.QueryString = null; + Assert.Null(requestInfo.QueryString); + requestInfo.RawTarget = null; + Assert.Null(requestInfo.RawTarget); + requestInfo.Protocol = null; + Assert.Null(requestInfo.Protocol); + + var connectionInfo = httpContext.Features.Get(); + connectionInfo.RemoteIpAddress = null; + Assert.Null(connectionInfo.RemoteIpAddress); + connectionInfo.RemotePort = -1; + Assert.Equal(-1, connectionInfo.RemotePort); + connectionInfo.LocalIpAddress = null; + Assert.Null(connectionInfo.LocalIpAddress); + connectionInfo.LocalPort = -1; + Assert.Equal(-1, connectionInfo.LocalPort); + connectionInfo.ConnectionId = null; + Assert.Null(connectionInfo.ConnectionId); + + // Trace identifier + var requestIdentifierFeature = httpContext.Features.Get(); + Assert.NotNull(requestIdentifierFeature); + requestIdentifierFeature.TraceIdentifier = null; + Assert.Null(requestIdentifierFeature.TraceIdentifier); + + // Note: Response keys are validated in the ResponseTests + } + catch (Exception ex) + { + byte[] body = Encoding.ASCII.GetBytes(ex.ToString()); + httpContext.Response.Body.Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(root + "/basepath/SomePath?SomeQuery"); + Assert.Equal(string.Empty, response); + } + } + [Theory] [InlineData("/", "/", "", "/")] [InlineData("/basepath/", "/basepath", "/basepath", "")]