diff --git a/src/Microsoft.AspNet.Http.Extensions/project.json b/src/Microsoft.AspNet.Http.Extensions/project.json index 0254711dd8..cbc45a6ddc 100644 --- a/src/Microsoft.AspNet.Http.Extensions/project.json +++ b/src/Microsoft.AspNet.Http.Extensions/project.json @@ -8,7 +8,7 @@ "frameworks" : { "aspnet50" : { }, - "aspnetcore50" : { + "aspnetcore50" : { "dependencies": { "System.Reflection.TypeExtensions": "4.0.0-beta-*", "System.Runtime": "4.0.20-beta-*" diff --git a/src/Microsoft.AspNet.Http/FragmentString.cs b/src/Microsoft.AspNet.Http/FragmentString.cs index 1d725e742c..f40ba626a1 100644 --- a/src/Microsoft.AspNet.Http/FragmentString.cs +++ b/src/Microsoft.AspNet.Http/FragmentString.cs @@ -90,12 +90,8 @@ namespace Microsoft.AspNet.Http /// /// The Uri object /// The resulting FragmentString - public static FragmentString FromUriComponent(Uri uri) + public static FragmentString FromUriComponent([NotNull] Uri uri) { - if (uri == null) - { - throw new ArgumentNullException("uri"); - } string fragmentValue = uri.GetComponents(UriComponents.Fragment, UriFormat.UriEscaped); if (!string.IsNullOrEmpty(fragmentValue)) { diff --git a/src/Microsoft.AspNet.Http/HostString.cs b/src/Microsoft.AspNet.Http/HostString.cs index ae72f819cc..c91ce65a5b 100644 --- a/src/Microsoft.AspNet.Http/HostString.cs +++ b/src/Microsoft.AspNet.Http/HostString.cs @@ -134,12 +134,8 @@ namespace Microsoft.AspNet.Http /// /// /// - public static HostString FromUriComponent(Uri uri) + public static HostString FromUriComponent([NotNull] Uri uri) { - if (uri == null) - { - throw new ArgumentNullException("uri"); - } return new HostString(uri.GetComponents( UriComponents.NormalizedHost | // Always convert punycode to Unicode. UriComponents.HostAndPort, UriFormat.Unescaped)); diff --git a/src/Microsoft.AspNet.Http/HttpRequest.cs b/src/Microsoft.AspNet.Http/HttpRequest.cs index be5e1e14b1..8048152a08 100644 --- a/src/Microsoft.AspNet.Http/HttpRequest.cs +++ b/src/Microsoft.AspNet.Http/HttpRequest.cs @@ -64,12 +64,6 @@ namespace Microsoft.AspNet.Http /// The query value collection parsed from owin.RequestQueryString. public abstract IReadableStringCollection Query { get; } - /// - /// Gets the form collection. - /// - /// The form collection parsed from the request body. - public abstract Task GetFormAsync(CancellationToken cancellationToken = default(CancellationToken)); - /// /// Gets or set the owin.RequestProtocol. /// @@ -128,5 +122,21 @@ namespace Microsoft.AspNet.Http /// /// The owin.RequestBody Stream. public abstract Stream Body { get; set; } + + /// + /// Checks the content-type header for form types. + /// + public abstract bool HasFormContentType { get; } + + /// + /// Gets or sets the request body as a form. + /// + public abstract IFormCollection Form { get; set; } + + /// + /// Reads the request body if it is a form. + /// + /// + public abstract Task ReadFormAsync(CancellationToken cancellationToken = new CancellationToken()); } } diff --git a/src/Microsoft.AspNet.Http/IFormCollection.cs b/src/Microsoft.AspNet.Http/IFormCollection.cs index a69162fa4d..4ec7437877 100644 --- a/src/Microsoft.AspNet.Http/IFormCollection.cs +++ b/src/Microsoft.AspNet.Http/IFormCollection.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; + namespace Microsoft.AspNet.Http { /// @@ -8,5 +10,6 @@ namespace Microsoft.AspNet.Http /// public interface IFormCollection : IReadableStringCollection { + IFormFileCollection Files { get; } } } diff --git a/src/Microsoft.AspNet.Http/IFormFile.cs b/src/Microsoft.AspNet.Http/IFormFile.cs new file mode 100644 index 0000000000..a77a495d5e --- /dev/null +++ b/src/Microsoft.AspNet.Http/IFormFile.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; + +namespace Microsoft.AspNet.Http +{ + public interface IFormFile + { + string ContentType { get; } + + string ContentDisposition { get; } + + IHeaderDictionary Headers { get; } + + long Length { get; } + + Stream OpenReadStream(); + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.Http/IFormFileCollection.cs b/src/Microsoft.AspNet.Http/IFormFileCollection.cs new file mode 100644 index 0000000000..56f1d5879d --- /dev/null +++ b/src/Microsoft.AspNet.Http/IFormFileCollection.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace Microsoft.AspNet.Http +{ + public interface IFormFileCollection : IList + { + IFormFile this[string name] { get; } + + IFormFile GetFile(string name); + + IList GetFiles(string name); + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.Http/PathString.cs b/src/Microsoft.AspNet.Http/PathString.cs index 05dfa6d816..2963c64d8f 100644 --- a/src/Microsoft.AspNet.Http/PathString.cs +++ b/src/Microsoft.AspNet.Http/PathString.cs @@ -86,12 +86,8 @@ namespace Microsoft.AspNet.Http /// /// The Uri object /// The resulting PathString - public static PathString FromUriComponent(Uri uri) + public static PathString FromUriComponent([NotNull] Uri uri) { - if (uri == null) - { - throw new ArgumentNullException("uri"); - } // REVIEW: what is the exactly correct thing to do? return new PathString("/" + uri.GetComponents(UriComponents.Path, UriFormat.Unescaped)); } diff --git a/src/Microsoft.AspNet.Http/QueryString.cs b/src/Microsoft.AspNet.Http/QueryString.cs index 5ba16151c7..c32e3bb037 100644 --- a/src/Microsoft.AspNet.Http/QueryString.cs +++ b/src/Microsoft.AspNet.Http/QueryString.cs @@ -102,12 +102,8 @@ namespace Microsoft.AspNet.Http /// /// The Uri object /// The resulting QueryString - public static QueryString FromUriComponent(Uri uri) + public static QueryString FromUriComponent([NotNull] Uri uri) { - if (uri == null) - { - throw new ArgumentNullException("uri"); - } string queryValue = uri.GetComponents(UriComponents.Query, UriFormat.UriEscaped); if (!string.IsNullOrEmpty(queryValue)) { diff --git a/src/Microsoft.AspNet.Http/Security/AuthenticateResult.cs b/src/Microsoft.AspNet.Http/Security/AuthenticateResult.cs index 5a1998f336..38b48dcf0e 100644 --- a/src/Microsoft.AspNet.Http/Security/AuthenticateResult.cs +++ b/src/Microsoft.AspNet.Http/Security/AuthenticateResult.cs @@ -18,16 +18,8 @@ namespace Microsoft.AspNet.Http.Security /// Assigned to Identity. May be null. /// Assigned to Properties. Contains extra information carried along with the identity. /// Assigned to Description. Contains information describing the authentication provider. - public AuthenticationResult(IIdentity identity, AuthenticationProperties properties, AuthenticationDescription description) + public AuthenticationResult(IIdentity identity, [NotNull] AuthenticationProperties properties, [NotNull] AuthenticationDescription description) { - if (properties == null) - { - throw new ArgumentNullException("properties"); - } - if (description == null) - { - throw new ArgumentNullException("description"); - } if (identity != null) { Identity = identity as ClaimsIdentity ?? new ClaimsIdentity(identity); diff --git a/src/Microsoft.AspNet.Http/Security/AuthenticationDescription.cs b/src/Microsoft.AspNet.Http/Security/AuthenticationDescription.cs index d6423f34fd..2d7d4f0c73 100644 --- a/src/Microsoft.AspNet.Http/Security/AuthenticationDescription.cs +++ b/src/Microsoft.AspNet.Http/Security/AuthenticationDescription.cs @@ -27,12 +27,8 @@ namespace Microsoft.AspNet.Http.Security /// Initializes a new instance of the class /// /// - public AuthenticationDescription(IDictionary properties) + public AuthenticationDescription([NotNull] IDictionary properties) { - if (properties == null) - { - throw new ArgumentNullException("properties"); - } Dictionary = properties; } diff --git a/src/Microsoft.AspNet.Owin/NotNullAttribute.cs b/src/Microsoft.AspNet.Owin/NotNullAttribute.cs new file mode 100644 index 0000000000..a42aa58d4a --- /dev/null +++ b/src/Microsoft.AspNet.Owin/NotNullAttribute.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNet.Owin +{ + [AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false)] + internal sealed class NotNullAttribute : Attribute + { + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.Owin/OwinFeatureCollection.cs b/src/Microsoft.AspNet.Owin/OwinFeatureCollection.cs index 23ca6df7bf..434817129d 100644 --- a/src/Microsoft.AspNet.Owin/OwinFeatureCollection.cs +++ b/src/Microsoft.AspNet.Owin/OwinFeatureCollection.cs @@ -399,12 +399,8 @@ namespace Microsoft.AspNet.Owin return TryGetValue(item.Key, out result) && result.Equals(item.Value); } - public void CopyTo(KeyValuePair[] array, int arrayIndex) + public void CopyTo([NotNull] KeyValuePair[] array, int arrayIndex) { - if (array == null) - { - throw new ArgumentNullException("array"); - } if (arrayIndex < 0 || arrayIndex > array.Length) { throw new ArgumentOutOfRangeException("arrayIndex", arrayIndex, string.Empty); diff --git a/src/Microsoft.AspNet.Owin/WebSockets/WebSocketAdapter.cs b/src/Microsoft.AspNet.Owin/WebSockets/WebSocketAdapter.cs index 79f8c0bf2e..9577785639 100644 --- a/src/Microsoft.AspNet.Owin/WebSockets/WebSocketAdapter.cs +++ b/src/Microsoft.AspNet.Owin/WebSockets/WebSocketAdapter.cs @@ -113,7 +113,7 @@ namespace Microsoft.AspNet.Owin } else { - throw new ArgumentOutOfRangeException("buffer"); + throw new ArgumentOutOfRangeException(nameof(buffer)); } } @@ -149,7 +149,7 @@ namespace Microsoft.AspNet.Owin case 0x8: return WebSocketMessageType.Close; default: - throw new ArgumentOutOfRangeException("messageType", messageType, string.Empty); + throw new ArgumentOutOfRangeException(nameof(messageType), messageType, string.Empty); } } @@ -164,7 +164,7 @@ namespace Microsoft.AspNet.Owin case WebSocketMessageType.Close: return 0x8; default: - throw new ArgumentOutOfRangeException("webSocketMessageType", webSocketMessageType, string.Empty); + throw new ArgumentOutOfRangeException(nameof(webSocketMessageType), webSocketMessageType, string.Empty); } } } diff --git a/src/Microsoft.AspNet.PipelineCore/BufferingHelper.cs b/src/Microsoft.AspNet.PipelineCore/BufferingHelper.cs new file mode 100644 index 0000000000..b10c017c36 --- /dev/null +++ b/src/Microsoft.AspNet.PipelineCore/BufferingHelper.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using Microsoft.AspNet.Http; +using Microsoft.AspNet.WebUtilities; + +namespace Microsoft.AspNet.PipelineCore +{ + public static class BufferingHelper + { + internal const int DefaultBufferThreshold = 1024 * 30; + + public static string TempDirectory + { + get + { + var temp = Environment.GetEnvironmentVariable("ASPNET_TEMP"); + if (string.IsNullOrEmpty(temp)) + { + temp = Environment.GetEnvironmentVariable("TEMP"); + } + + if (!Directory.Exists(temp)) + { + // TODO: ??? + throw new DirectoryNotFoundException(temp); + } + + return temp; + } + } + + public static HttpRequest EnableRewind([NotNull] this HttpRequest request, int bufferThreshold = DefaultBufferThreshold) + { + var body = request.Body; + if (!body.CanSeek) + { + // TODO: Register this buffer for disposal at the end of the request to ensure the temp file is deleted. + // Otherwise it won't get deleted until GC closes the stream. + request.Body = new FileBufferingReadStream(body, bufferThreshold, TempDirectory); + } + return request; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/Collections/FormCollection.cs b/src/Microsoft.AspNet.PipelineCore/Collections/FormCollection.cs similarity index 53% rename from src/Microsoft.AspNet.WebUtilities/Collections/FormCollection.cs rename to src/Microsoft.AspNet.PipelineCore/Collections/FormCollection.cs index a9d1df8529..a80fea12d1 100644 --- a/src/Microsoft.AspNet.WebUtilities/Collections/FormCollection.cs +++ b/src/Microsoft.AspNet.PipelineCore/Collections/FormCollection.cs @@ -1,23 +1,27 @@ // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.AspNet.Http; using System.Collections.Generic; +using Microsoft.AspNet.Http; -namespace Microsoft.AspNet.WebUtilities.Collections +namespace Microsoft.AspNet.PipelineCore.Collections { /// /// Contains the parsed form values. /// public class FormCollection : ReadableStringCollection, IFormCollection { - /// - /// Initializes a new instance of the class. - /// - /// The store for the form. - public FormCollection(IDictionary store) - : base(store) + public FormCollection([NotNull] IDictionary store) + : this(store, new FormFileCollection()) { } + + public FormCollection([NotNull] IDictionary store, [NotNull] IFormFileCollection files) + : base(store) + { + Files = files; + } + + public IFormFileCollection Files { get; private set; } } } diff --git a/src/Microsoft.AspNet.PipelineCore/Collections/FormFileCollection.cs b/src/Microsoft.AspNet.PipelineCore/Collections/FormFileCollection.cs new file mode 100644 index 0000000000..10aad2bfa2 --- /dev/null +++ b/src/Microsoft.AspNet.PipelineCore/Collections/FormFileCollection.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNet.Http; + +namespace Microsoft.AspNet.PipelineCore.Collections +{ + public class FormFileCollection : List, IFormFileCollection + { + public IFormFile this[string name] + { + get { return GetFile(name); } + } + + public IFormFile GetFile(string name) + { + return Find(file => string.Equals(name, GetName(file.ContentDisposition))); + } + + public IList GetFiles(string name) + { + return FindAll(file => string.Equals(name, GetName(file.ContentDisposition))); + } + + private static string GetName(string contentDisposition) + { + // TODO: Strongly typed headers will take care of this + // Content-Disposition: form-data; name="myfile1"; filename="Misc 002.jpg" + var offset = contentDisposition.IndexOf("name=\"") + "name=\"".Length; + var key = contentDisposition.Substring(offset, contentDisposition.IndexOf("\"", offset) - offset); // Remove quotes + return key; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.PipelineCore/Collections/HeaderDictionary.cs b/src/Microsoft.AspNet.PipelineCore/Collections/HeaderDictionary.cs index 22d62c5989..dfd54393cc 100644 --- a/src/Microsoft.AspNet.PipelineCore/Collections/HeaderDictionary.cs +++ b/src/Microsoft.AspNet.PipelineCore/Collections/HeaderDictionary.cs @@ -20,13 +20,8 @@ namespace Microsoft.AspNet.PipelineCore.Collections /// Initializes a new instance of the class. /// /// The underlying data store. - public HeaderDictionary(IDictionary store) + public HeaderDictionary([NotNull] IDictionary store) { - if (store == null) - { - throw new ArgumentNullException("store"); - } - Store = store; } diff --git a/src/Microsoft.AspNet.PipelineCore/Collections/ItemsDictionary.cs b/src/Microsoft.AspNet.PipelineCore/Collections/ItemsDictionary.cs index e9c21c835d..dc5216d117 100644 --- a/src/Microsoft.AspNet.PipelineCore/Collections/ItemsDictionary.cs +++ b/src/Microsoft.AspNet.PipelineCore/Collections/ItemsDictionary.cs @@ -3,7 +3,6 @@ using System.Collections; using System.Collections.Generic; -using Microsoft.AspNet.Http; namespace Microsoft.AspNet.PipelineCore { diff --git a/src/Microsoft.AspNet.WebUtilities/Collections/ReadableStringCollection.cs b/src/Microsoft.AspNet.PipelineCore/Collections/ReadableStringCollection.cs similarity index 86% rename from src/Microsoft.AspNet.WebUtilities/Collections/ReadableStringCollection.cs rename to src/Microsoft.AspNet.PipelineCore/Collections/ReadableStringCollection.cs index 97c03b0df2..1e12bb6dc1 100644 --- a/src/Microsoft.AspNet.WebUtilities/Collections/ReadableStringCollection.cs +++ b/src/Microsoft.AspNet.PipelineCore/Collections/ReadableStringCollection.cs @@ -6,7 +6,7 @@ using System.Collections; using System.Collections.Generic; using Microsoft.AspNet.Http; -namespace Microsoft.AspNet.WebUtilities.Collections +namespace Microsoft.AspNet.PipelineCore.Collections { /// /// Accessors for query, forms, etc. @@ -17,13 +17,8 @@ namespace Microsoft.AspNet.WebUtilities.Collections /// Create a new wrapper /// /// - public ReadableStringCollection(IDictionary store) + public ReadableStringCollection([NotNull] IDictionary store) { - if (store == null) - { - throw new ArgumentNullException("store"); - } - Store = store; } @@ -75,7 +70,7 @@ namespace Microsoft.AspNet.WebUtilities.Collections /// public string Get(string key) { - return ParsingHelpers.GetJoinedValue(Store, key); + return GetJoinedValue(Store, key); } /// @@ -108,5 +103,15 @@ namespace Microsoft.AspNet.WebUtilities.Collections { return GetEnumerator(); } + + private static string GetJoinedValue(IDictionary store, string key) + { + string[] values; + if (store.TryGetValue(key, out values)) + { + return string.Join(",", values); + } + return null; + } } } diff --git a/src/Microsoft.AspNet.PipelineCore/Collections/ResponseCookies.cs b/src/Microsoft.AspNet.PipelineCore/Collections/ResponseCookies.cs index 27bfbbd38e..2412b54a83 100644 --- a/src/Microsoft.AspNet.PipelineCore/Collections/ResponseCookies.cs +++ b/src/Microsoft.AspNet.PipelineCore/Collections/ResponseCookies.cs @@ -19,13 +19,8 @@ namespace Microsoft.AspNet.PipelineCore.Collections /// Create a new wrapper /// /// - public ResponseCookies(IHeaderDictionary headers) + public ResponseCookies([NotNull] IHeaderDictionary headers) { - if (headers == null) - { - throw new ArgumentNullException("headers"); - } - Headers = headers; } @@ -47,13 +42,8 @@ namespace Microsoft.AspNet.PipelineCore.Collections /// /// /// - public void Append(string key, string value, CookieOptions options) + public void Append(string key, string value, [NotNull] CookieOptions options) { - if (options == null) - { - throw new ArgumentNullException("options"); - } - bool domainHasValue = !string.IsNullOrEmpty(options.Domain); bool pathHasValue = !string.IsNullOrEmpty(options.Path); bool expiresHasValue = options.Expires.HasValue; @@ -98,13 +88,8 @@ namespace Microsoft.AspNet.PipelineCore.Collections /// /// /// - public void Delete(string key, CookieOptions options) + public void Delete(string key, [NotNull] CookieOptions options) { - if (options == null) - { - throw new ArgumentNullException("options"); - } - bool domainHasValue = !string.IsNullOrEmpty(options.Domain); bool pathHasValue = !string.IsNullOrEmpty(options.Path); diff --git a/src/Microsoft.AspNet.PipelineCore/DefaultHttpContext.cs b/src/Microsoft.AspNet.PipelineCore/DefaultHttpContext.cs index 308d896564..3b445fefbe 100644 --- a/src/Microsoft.AspNet.PipelineCore/DefaultHttpContext.cs +++ b/src/Microsoft.AspNet.PipelineCore/DefaultHttpContext.cs @@ -214,12 +214,8 @@ namespace Microsoft.AspNet.PipelineCore return authTypeContext.Results; } - public override IEnumerable Authenticate(IEnumerable authenticationTypes) + public override IEnumerable Authenticate([NotNull] IEnumerable authenticationTypes) { - if (authenticationTypes == null) - { - throw new ArgumentNullException(); - } var handler = HttpAuthenticationFeature.Handler; var authenticateContext = new AuthenticateContext(authenticationTypes); @@ -238,12 +234,8 @@ namespace Microsoft.AspNet.PipelineCore return authenticateContext.Results; } - public override async Task> AuthenticateAsync(IEnumerable authenticationTypes) + public override async Task> AuthenticateAsync([NotNull] IEnumerable authenticationTypes) { - if (authenticationTypes == null) - { - throw new ArgumentNullException(); - } var handler = HttpAuthenticationFeature.Handler; var authenticateContext = new AuthenticateContext(authenticationTypes); diff --git a/src/Microsoft.AspNet.PipelineCore/DefaultHttpRequest.cs b/src/Microsoft.AspNet.PipelineCore/DefaultHttpRequest.cs index eac94536d6..da839c1f16 100644 --- a/src/Microsoft.AspNet.PipelineCore/DefaultHttpRequest.cs +++ b/src/Microsoft.AspNet.PipelineCore/DefaultHttpRequest.cs @@ -2,13 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Globalization; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNet.FeatureModel; using Microsoft.AspNet.Http; using Microsoft.AspNet.Http.Infrastructure; -using Microsoft.AspNet.FeatureModel; using Microsoft.AspNet.HttpFeature; using Microsoft.AspNet.PipelineCore.Collections; using Microsoft.AspNet.PipelineCore.Infrastructure; @@ -55,7 +54,7 @@ namespace Microsoft.AspNet.PipelineCore private IFormFeature FormFeature { - get { return _form.Fetch(_features) ?? _form.Update(_features, new FormFeature(_features)); } + get { return _form.Fetch(_features) ?? _form.Update(_features, new FormFeature(this)); } } private IRequestCookiesFeature RequestCookiesFeature @@ -83,7 +82,7 @@ namespace Microsoft.AspNet.PipelineCore set { HttpRequestFeature.QueryString = value.Value; } } - public override long? ContentLength + public override long? ContentLength { get { @@ -129,11 +128,6 @@ namespace Microsoft.AspNet.PipelineCore get { return QueryFeature.Query; } } - public override Task GetFormAsync(CancellationToken cancellationToken = default(CancellationToken)) - { - return FormFeature.GetFormAsync(cancellationToken); - } - public override string Protocol { get { return HttpRequestFeature.Protocol; } @@ -167,5 +161,21 @@ namespace Microsoft.AspNet.PipelineCore get { return Headers[Constants.Headers.AcceptCharset]; } set { Headers[Constants.Headers.AcceptCharset] = value; } } + + public override bool HasFormContentType + { + get { return FormFeature.HasFormContentType; } + } + + public override IFormCollection Form + { + get { return FormFeature.ReadForm(); } + set { FormFeature.Form = value; } + } + + public override Task ReadFormAsync(CancellationToken cancellationToken) + { + return FormFeature.ReadFormAsync(cancellationToken); + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNet.PipelineCore/DefaultHttpResponse.cs b/src/Microsoft.AspNet.PipelineCore/DefaultHttpResponse.cs index 03841751ec..16d187450d 100644 --- a/src/Microsoft.AspNet.PipelineCore/DefaultHttpResponse.cs +++ b/src/Microsoft.AspNet.PipelineCore/DefaultHttpResponse.cs @@ -129,12 +129,8 @@ namespace Microsoft.AspNet.PipelineCore Headers.Set(Constants.Headers.Location, location); } - public override void Challenge(AuthenticationProperties properties, IEnumerable authenticationTypes) + public override void Challenge(AuthenticationProperties properties, [NotNull] IEnumerable authenticationTypes) { - if (authenticationTypes == null) - { - throw new ArgumentNullException(); - } HttpResponseFeature.StatusCode = 401; var handler = HttpAuthenticationFeature.Handler; @@ -152,13 +148,8 @@ namespace Microsoft.AspNet.PipelineCore } } - public override void SignIn(AuthenticationProperties properties, IEnumerable identities) + public override void SignIn(AuthenticationProperties properties, [NotNull] IEnumerable identities) { - if (identities == null) - { - throw new ArgumentNullException(); - } - var handler = HttpAuthenticationFeature.Handler; var signInContext = new SignInContext(identities, properties == null ? null : properties.Dictionary); @@ -175,12 +166,8 @@ namespace Microsoft.AspNet.PipelineCore } } - public override void SignOut(IEnumerable authenticationTypes) + public override void SignOut([NotNull] IEnumerable authenticationTypes) { - if (authenticationTypes == null) - { - throw new ArgumentNullException(); - } var handler = HttpAuthenticationFeature.Handler; var signOutContext = new SignOutContext(authenticationTypes); diff --git a/src/Microsoft.AspNet.PipelineCore/FormFeature.cs b/src/Microsoft.AspNet.PipelineCore/FormFeature.cs index 271264b870..8d20871688 100644 --- a/src/Microsoft.AspNet.PipelineCore/FormFeature.cs +++ b/src/Microsoft.AspNet.PipelineCore/FormFeature.cs @@ -1,71 +1,192 @@ // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNet.FeatureModel; using Microsoft.AspNet.Http; -using Microsoft.AspNet.HttpFeature; -using Microsoft.AspNet.PipelineCore.Infrastructure; +using Microsoft.AspNet.PipelineCore.Collections; using Microsoft.AspNet.WebUtilities; -using Microsoft.AspNet.WebUtilities.Collections; namespace Microsoft.AspNet.PipelineCore { public class FormFeature : IFormFeature { - private readonly IFeatureCollection _features; - private readonly FeatureReference _request = FeatureReference.Default; - private Stream _bodyStream; - private IReadableStringCollection _form; + private readonly HttpRequest _request; - public FormFeature([NotNull] IDictionary form) - : this (new ReadableStringCollection(form)) + public FormFeature([NotNull] IFormCollection form) { + Form = form; } - public FormFeature([NotNull] IReadableStringCollection form) + public FormFeature([NotNull] HttpRequest request) { - _form = form; + _request = request; } - public FormFeature([NotNull] IFeatureCollection features) + public bool HasFormContentType { - _features = features; - } - - public async Task GetFormAsync(CancellationToken cancellationToken) - { - if (_features == null) + get { - return _form; + // Set directly + if (Form != null) + { + return true; + } + + return HasApplicationFormContentType() || HasMultipartFormContentType(); + } + } + + public IFormCollection Form { get; set; } + + public IFormCollection ReadForm() + { + if (Form != null) + { + return Form; } - var body = _request.Fetch(_features).Body; - - if (_bodyStream == null || _bodyStream != body) + if (!HasFormContentType) { - _bodyStream = body; - if (!_bodyStream.CanSeek) + throw new InvalidOperationException("Incorrect Content-Type: " + _request.ContentType); + } + + // TODO: How do we prevent thread exhaustion? + return ReadFormAsync(CancellationToken.None).GetAwaiter().GetResult(); + } + + public async Task ReadFormAsync(CancellationToken cancellationToken) + { + if (Form != null) + { + return Form; + } + + if (!HasFormContentType) + { + throw new InvalidOperationException("Incorrect Content-Type: " + _request.ContentType); + } + + cancellationToken.ThrowIfCancellationRequested(); + + _request.EnableRewind(); + + IDictionary formFields = null; + var files = new FormFileCollection(); + + // Some of these code paths use StreamReader which does not support cancellation tokens. + using (cancellationToken.Register(_request.HttpContext.Abort)) + { + // Check the content-type + if (HasApplicationFormContentType()) { - var buffer = new MemoryStream(); - await _bodyStream.CopyToAsync(buffer, 4096, cancellationToken); - _bodyStream = buffer; - _request.Fetch(_features).Body = _bodyStream; - _bodyStream.Seek(0, SeekOrigin.Begin); + // TODO: Read the charset from the content-type header after we get strongly typed headers + formFields = await FormReader.ReadFormAsync(_request.Body, cancellationToken); } - using (var streamReader = new StreamReader(_bodyStream, Encoding.UTF8, - detectEncodingFromByteOrderMarks: true, - bufferSize: 1024, leaveOpen: true)) + else if (HasMultipartFormContentType()) { - string form = await streamReader.ReadToEndAsync(); - _form = FormHelpers.ParseForm(form); + var formAccumulator = new KeyValueAccumulator(StringComparer.OrdinalIgnoreCase); + + var boundary = GetBoundary(_request.ContentType); + var multipartReader = new MultipartReader(boundary, _request.Body); + var section = await multipartReader.ReadNextSectionAsync(cancellationToken); + while (section != null) + { + var headers = new HeaderDictionary(section.Headers); + var contentDisposition = headers["Content-Disposition"]; + if (HasFileContentDisposition(contentDisposition)) + { + // Find the end + await section.Body.DrainAsync(cancellationToken); + + var file = new FormFile(_request.Body, section.BaseStreamOffset.Value, section.Body.Length) + { + Headers = headers, + }; + files.Add(file); + } + else if (HasFormDataContentDisposition(contentDisposition)) + { + // Content-Disposition: form-data; name="key" + // + // value + + // TODO: Strongly typed headers will take care of this + var offset = contentDisposition.IndexOf("name=") + "name=".Length; + var key = contentDisposition.Substring(offset + 1, contentDisposition.Length - offset - 2); // Remove quotes + + // TODO: Read the charset from the content-disposition header after we get strongly typed headers + using (var reader = new StreamReader(section.Body, Encoding.UTF8, detectEncodingFromByteOrderMarks: true, bufferSize: 1024, leaveOpen: true)) + { + var value = await reader.ReadToEndAsync(); + formAccumulator.Append(key, value); + } + } + else + { + System.Diagnostics.Debug.Assert(false, "Unrecognized content-disposition for this section: " + contentDisposition); + } + + section = await multipartReader.ReadNextSectionAsync(cancellationToken); + } + + formFields = formAccumulator.GetResults(); } } - return _form; + + Form = new FormCollection(formFields, files); + return Form; + } + + private bool HasApplicationFormContentType() + { + // TODO: Strongly typed headers will take care of this for us + // Content-Type: application/x-www-form-urlencoded; charset=utf-8 + var contentType = _request.ContentType; + return !string.IsNullOrEmpty(contentType) && contentType.IndexOf("application/x-www-form-urlencoded", StringComparison.OrdinalIgnoreCase) >= 0; + } + + private bool HasMultipartFormContentType() + { + // TODO: Strongly typed headers will take care of this for us + // Content-Type: multipart/form-data; boundary=----WebKitFormBoundarymx2fSWqWSd0OxQqq + var contentType = _request.ContentType; + return !string.IsNullOrEmpty(contentType) && contentType.IndexOf("multipart/form-data", StringComparison.OrdinalIgnoreCase) >= 0; + } + + private bool HasFormDataContentDisposition(string contentDisposition) + { + // TODO: Strongly typed headers will take care of this for us + // Content-Disposition: form-data; name="key"; + return !string.IsNullOrEmpty(contentDisposition) && contentDisposition.Contains("form-data") && !contentDisposition.Contains("filename="); + } + + private bool HasFileContentDisposition(string contentDisposition) + { + // TODO: Strongly typed headers will take care of this for us + // Content-Disposition: form-data; name="myfile1"; filename="Misc 002.jpg" + return !string.IsNullOrEmpty(contentDisposition) && contentDisposition.Contains("form-data") && contentDisposition.Contains("filename="); + } + + // Content-Type: multipart/form-data; boundary=----WebKitFormBoundarymx2fSWqWSd0OxQqq + private static string GetBoundary(string contentType) + { + // TODO: Strongly typed headers will take care of this for us + // TODO: Limit the length of boundary we accept. The spec says ~70 chars. + var elements = contentType.Split(' '); + var element = elements.Where(entry => entry.StartsWith("boundary=")).First(); + var boundary = element.Substring("boundary=".Length); + // Remove quotes + if (boundary.Length >= 2 && boundary[0] == '"' && boundary[boundary.Length - 1] == '"') + { + boundary = boundary.Substring(1, boundary.Length - 2); + } + return boundary; } } } diff --git a/src/Microsoft.AspNet.PipelineCore/FormFile.cs b/src/Microsoft.AspNet.PipelineCore/FormFile.cs new file mode 100644 index 0000000000..19ea07466f --- /dev/null +++ b/src/Microsoft.AspNet.PipelineCore/FormFile.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using Microsoft.AspNet.Http; + +namespace Microsoft.AspNet.PipelineCore +{ + public class FormFile : IFormFile + { + private Stream _baseStream; + private long _baseStreamOffset; + private long _length; + + public FormFile(Stream baseStream, long baseStreamOffset, long length) + { + _baseStream = baseStream; + _baseStreamOffset = baseStreamOffset; + _length = length; + } + + public string ContentDisposition + { + get { return Headers["Content-Disposition"]; } + set { Headers["Content-Disposition"] = value; } + } + + public string ContentType + { + get { return Headers["Content-Type"]; } + set { Headers["Content-Type"] = value; } + } + + public IHeaderDictionary Headers { get; set; } + + public long Length + { + get { return _length; } + } + + public Stream OpenReadStream() + { + return new ReferenceReadStream(_baseStream, _baseStreamOffset, _length); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.PipelineCore/IFormFeature.cs b/src/Microsoft.AspNet.PipelineCore/IFormFeature.cs index 07debf7870..cc84d32a4b 100644 --- a/src/Microsoft.AspNet.PipelineCore/IFormFeature.cs +++ b/src/Microsoft.AspNet.PipelineCore/IFormFeature.cs @@ -4,11 +4,33 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNet.Http; +using Microsoft.Framework.Runtime; namespace Microsoft.AspNet.PipelineCore { public interface IFormFeature { - Task GetFormAsync(CancellationToken cancellationToken); + /// + /// Indicates if the request has a supported form content-type. + /// + bool HasFormContentType { get; } + + /// + /// The parsed form, if any. + /// + IFormCollection Form { get; set; } + + /// + /// Parses the request body as a form. + /// + /// + IFormCollection ReadForm(); + + /// + /// Parses the request body as a form. + /// + /// + /// + Task ReadFormAsync(CancellationToken cancellationToken); } } diff --git a/src/Microsoft.AspNet.PipelineCore/Infrastructure/ParsingHelpers.cs b/src/Microsoft.AspNet.PipelineCore/Infrastructure/ParsingHelpers.cs index 10c9ef3672..e0ed2c7333 100644 --- a/src/Microsoft.AspNet.PipelineCore/Infrastructure/ParsingHelpers.cs +++ b/src/Microsoft.AspNet.PipelineCore/Infrastructure/ParsingHelpers.cs @@ -445,12 +445,8 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure #endregion - public bool StartsWith(string text, StringComparison comparisonType) + public bool StartsWith([NotNull] string text, StringComparison comparisonType) { - if (text == null) - { - throw new ArgumentNullException("text"); - } int textLength = text.Length; if (!HasValue || _count < textLength) { @@ -460,12 +456,8 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure return string.Compare(_buffer, _offset, text, 0, textLength, comparisonType) == 0; } - public bool EndsWith(string text, StringComparison comparisonType) + public bool EndsWith([NotNull] string text, StringComparison comparisonType) { - if (text == null) - { - throw new ArgumentNullException("text"); - } int textLength = text.Length; if (!HasValue || _count < textLength) { @@ -475,12 +467,8 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure return string.Compare(_buffer, _offset + _count - textLength, text, 0, textLength, comparisonType) == 0; } - public bool Equals(string text, StringComparison comparisonType) + public bool Equals([NotNull] string text, StringComparison comparisonType) { - if (text == null) - { - throw new ArgumentNullException("text"); - } int textLength = text.Length; if (!HasValue || _count != textLength) { @@ -615,25 +603,17 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure } } - public static string[] GetHeaderUnmodified(IDictionary headers, string key) + public static string[] GetHeaderUnmodified([NotNull] IDictionary headers, string key) { - if (headers == null) - { - throw new ArgumentNullException("headers"); - } string[] values; return headers.TryGetValue(key, out values) ? values : null; } - public static void SetHeader(IDictionary headers, string key, string value) + public static void SetHeader([NotNull] IDictionary headers, [NotNull] string key, string value) { - if (headers == null) - { - throw new ArgumentNullException("headers"); - } if (string.IsNullOrWhiteSpace(key)) { - throw new ArgumentNullException("key"); + throw new ArgumentNullException(nameof(key)); } if (string.IsNullOrWhiteSpace(value)) { @@ -645,15 +625,11 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure } } - public static void SetHeaderJoined(IDictionary headers, string key, params string[] values) + public static void SetHeaderJoined([NotNull] IDictionary headers, [NotNull] string key, params string[] values) { - if (headers == null) - { - throw new ArgumentNullException("headers"); - } if (string.IsNullOrWhiteSpace(key)) { - throw new ArgumentNullException("key"); + throw new ArgumentNullException(nameof(key)); } if (values == null || values.Length == 0) { @@ -697,15 +673,11 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure return value; } - public static void SetHeaderUnmodified(IDictionary headers, string key, params string[] values) + public static void SetHeaderUnmodified([NotNull] IDictionary headers, [NotNull] string key, params string[] values) { - if (headers == null) - { - throw new ArgumentNullException("headers"); - } if (string.IsNullOrWhiteSpace(key)) { - throw new ArgumentNullException("key"); + throw new ArgumentNullException(nameof(key)); } if (values == null || values.Length == 0) { @@ -717,16 +689,12 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure } } - public static void SetHeaderUnmodified(IDictionary headers, string key, IEnumerable values) + public static void SetHeaderUnmodified([NotNull] IDictionary headers, [NotNull] string key, [NotNull] IEnumerable values) { - if (headers == null) - { - throw new ArgumentNullException("headers"); - } headers[key] = values.ToArray(); } - public static void AppendHeader(IDictionary headers, string key, string values) + public static void AppendHeader([NotNull] IDictionary headers, [NotNull] string key, string values) { if (string.IsNullOrWhiteSpace(values)) { @@ -744,7 +712,7 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure } } - public static void AppendHeaderJoined(IDictionary headers, string key, params string[] values) + public static void AppendHeaderJoined([NotNull] IDictionary headers, [NotNull] string key, params string[] values) { if (values == null || values.Length == 0) { @@ -762,7 +730,7 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure } } - public static void AppendHeaderUnmodified(IDictionary headers, string key, params string[] values) + public static void AppendHeaderUnmodified([NotNull] IDictionary headers, [NotNull] string key, params string[] values) { if (values == null || values.Length == 0) { @@ -801,12 +769,8 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure return values == null ? null : string.Join(",", values); } - internal static string[] GetUnmodifiedValues(IDictionary store, string key) + internal static string[] GetUnmodifiedValues([NotNull] IDictionary store, string key) { - if (store == null) - { - throw new ArgumentNullException("store"); - } string[] values; return store.TryGetValue(key, out values) ? values : null; } @@ -826,7 +790,7 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure // return string.IsNullOrWhiteSpace(localPort) ? localIpAddress : (localIpAddress + ":" + localPort); //} - public static long? GetContentLength(IHeaderDictionary headers) + public static long? GetContentLength([NotNull] IHeaderDictionary headers) { const NumberStyles styles = NumberStyles.AllowLeadingWhite | NumberStyles.AllowTrailingWhite; long value; @@ -840,7 +804,7 @@ namespace Microsoft.AspNet.PipelineCore.Infrastructure return null; } - public static void SetContentLength(IHeaderDictionary headers, long? value) + public static void SetContentLength([NotNull] IHeaderDictionary headers, long? value) { if (value.HasValue) { diff --git a/src/Microsoft.AspNet.PipelineCore/QueryFeature.cs b/src/Microsoft.AspNet.PipelineCore/QueryFeature.cs index 5106808d03..e57b406153 100644 --- a/src/Microsoft.AspNet.PipelineCore/QueryFeature.cs +++ b/src/Microsoft.AspNet.PipelineCore/QueryFeature.cs @@ -5,9 +5,9 @@ using System.Collections.Generic; using Microsoft.AspNet.FeatureModel; using Microsoft.AspNet.Http; using Microsoft.AspNet.HttpFeature; +using Microsoft.AspNet.PipelineCore.Collections; using Microsoft.AspNet.PipelineCore.Infrastructure; using Microsoft.AspNet.WebUtilities; -using Microsoft.AspNet.WebUtilities.Collections; namespace Microsoft.AspNet.PipelineCore { @@ -46,7 +46,7 @@ namespace Microsoft.AspNet.PipelineCore if (_query == null || _queryString != queryString) { _queryString = queryString; - _query = QueryHelpers.ParseQuery(queryString); + _query = new ReadableStringCollection(QueryHelpers.ParseQuery(queryString)); } return _query; } diff --git a/src/Microsoft.AspNet.PipelineCore/ReferenceReadStream.cs b/src/Microsoft.AspNet.PipelineCore/ReferenceReadStream.cs new file mode 100644 index 0000000000..aaad97ae92 --- /dev/null +++ b/src/Microsoft.AspNet.PipelineCore/ReferenceReadStream.cs @@ -0,0 +1,199 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.PipelineCore +{ + /// + /// A Stream that wraps another stream starting at a certain offset and reading for the given length. + /// + internal class ReferenceReadStream : Stream + { + private readonly Stream _inner; + private readonly long _innerOffset; + private readonly long _length; + private long _position; + + private bool _disposed; + + public ReferenceReadStream([NotNull] Stream inner, long offset, long length) + { + _inner = inner; + _innerOffset = offset; + _length = length; + _inner.Position = offset; + } + + public override bool CanRead + { + get { return true; } + } + + public override bool CanSeek + { + get { return _inner.CanSeek; } + } + + public override bool CanWrite + { + get { return false; } + } + + public override long Length + { + get { return _length; } + } + + public override long Position + { + get { return _position; } + set + { + ThrowIfDisposed(); + if (value < 0 || value > Length) + { + throw new ArgumentOutOfRangeException("value", value, "The Position must be within the length of the Stream: " + Length); + } + VerifyPosition(); + _position = value; + _inner.Position = _innerOffset + _position; + } + } + + // Throws if the position in the underlying stream has changed without our knowledge, indicating someone else is trying + // to use the stream at the same time which could lead to data corruption. + private void VerifyPosition() + { + if (_inner.Position != _innerOffset + _position) + { + throw new InvalidOperationException("The inner stream position has changed unexpectedly."); + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + if (origin == SeekOrigin.Begin) + { + Position = offset; + } + else if (origin == SeekOrigin.End) + { + Position = Length + offset; + } + else // if (origin == SeekOrigin.Current) + { + Position = Position + offset; + } + return Position; + } + + public override int Read(byte[] buffer, int offset, int count) + { + ThrowIfDisposed(); + VerifyPosition(); + var toRead = Math.Min(count, _length - _position); + var read = _inner.Read(buffer, offset, (int)toRead); + _position += read; + return read; + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ThrowIfDisposed(); + VerifyPosition(); + var toRead = Math.Min(count, _length - _position); + var read = await _inner.ReadAsync(buffer, offset, (int)toRead, cancellationToken); + _position += read; + return read; + } +#if ASPNET50 + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + ThrowIfDisposed(); + VerifyPosition(); + var tcs = new TaskCompletionSource(state); + BeginRead(buffer, offset, count, callback, tcs); + return tcs.Task; + } + + private async void BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, TaskCompletionSource tcs) + { + try + { + var read = await ReadAsync(buffer, offset, count); + tcs.TrySetResult(read); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + } + + if (callback != null) + { + try + { + callback(tcs.Task); + } + catch (Exception) + { + // Suppress exceptions on background threads. + } + } + } + + public override int EndRead(IAsyncResult asyncResult) + { + var task = (Task)asyncResult; + return task.GetAwaiter().GetResult(); + } +#endif + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } +#if ASPNET50 + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + throw new NotSupportedException(); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + throw new NotSupportedException(); + } +#endif + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Flush() + { + throw new NotSupportedException(); + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _disposed = true; + } + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(ReferenceReadStream)); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.PipelineCore/RequestCookiesFeature.cs b/src/Microsoft.AspNet.PipelineCore/RequestCookiesFeature.cs index 8a3ae99e72..952bfed0f6 100644 --- a/src/Microsoft.AspNet.PipelineCore/RequestCookiesFeature.cs +++ b/src/Microsoft.AspNet.PipelineCore/RequestCookiesFeature.cs @@ -9,7 +9,6 @@ using Microsoft.AspNet.Http.Infrastructure; using Microsoft.AspNet.HttpFeature; using Microsoft.AspNet.PipelineCore.Collections; using Microsoft.AspNet.PipelineCore.Infrastructure; -using Microsoft.AspNet.WebUtilities.Collections; namespace Microsoft.AspNet.PipelineCore { diff --git a/src/Microsoft.AspNet.PipelineCore/Security/AuthenticateContext.cs b/src/Microsoft.AspNet.PipelineCore/Security/AuthenticateContext.cs index d975333fdf..1e27a006c4 100644 --- a/src/Microsoft.AspNet.PipelineCore/Security/AuthenticateContext.cs +++ b/src/Microsoft.AspNet.PipelineCore/Security/AuthenticateContext.cs @@ -17,12 +17,8 @@ namespace Microsoft.AspNet.PipelineCore.Security private List _results; private List _accepted; - public AuthenticateContext(IEnumerable authenticationTypes) + public AuthenticateContext([NotNull] IEnumerable authenticationTypes) { - if (authenticationTypes == null) - { - throw new ArgumentNullException("authenticationType"); - } AuthenticationTypes = authenticationTypes; _results = new List(); _accepted = new List(); diff --git a/src/Microsoft.AspNet.PipelineCore/Security/ChallengeContext.cs b/src/Microsoft.AspNet.PipelineCore/Security/ChallengeContext.cs index 45afd34594..ea87d06599 100644 --- a/src/Microsoft.AspNet.PipelineCore/Security/ChallengeContext.cs +++ b/src/Microsoft.AspNet.PipelineCore/Security/ChallengeContext.cs @@ -14,12 +14,8 @@ namespace Microsoft.AspNet.PipelineCore.Security { private List _accepted; - public ChallengeContext(IEnumerable authenticationTypes, IDictionary properties) + public ChallengeContext([NotNull] IEnumerable authenticationTypes, IDictionary properties) { - if (authenticationTypes == null) - { - throw new ArgumentNullException(); - } AuthenticationTypes = authenticationTypes; Properties = properties ?? new Dictionary(StringComparer.Ordinal); _accepted = new List(); @@ -33,7 +29,7 @@ namespace Microsoft.AspNet.PipelineCore.Security { get { return _accepted; } } - + public void Accept(string authenticationType, IDictionary description) { _accepted.Add(authenticationType); diff --git a/src/Microsoft.AspNet.PipelineCore/Security/SignInContext.cs b/src/Microsoft.AspNet.PipelineCore/Security/SignInContext.cs index b24a14758a..43a88a82de 100644 --- a/src/Microsoft.AspNet.PipelineCore/Security/SignInContext.cs +++ b/src/Microsoft.AspNet.PipelineCore/Security/SignInContext.cs @@ -12,12 +12,8 @@ namespace Microsoft.AspNet.PipelineCore.Security { private List _accepted; - public SignInContext(IEnumerable identities, IDictionary dictionary) + public SignInContext([NotNull] IEnumerable identities, IDictionary dictionary) { - if (identities == null) - { - throw new ArgumentNullException("identities"); - } Identities = identities; Properties = dictionary ?? new Dictionary(StringComparer.Ordinal); _accepted = new List(); diff --git a/src/Microsoft.AspNet.PipelineCore/Security/SignOutContext.cs b/src/Microsoft.AspNet.PipelineCore/Security/SignOutContext.cs index 6ce2a64fa5..546c5bca83 100644 --- a/src/Microsoft.AspNet.PipelineCore/Security/SignOutContext.cs +++ b/src/Microsoft.AspNet.PipelineCore/Security/SignOutContext.cs @@ -11,12 +11,8 @@ namespace Microsoft.AspNet.PipelineCore.Security { private List _accepted; - public SignOutContext(IEnumerable authenticationTypes) + public SignOutContext([NotNull] IEnumerable authenticationTypes) { - if (authenticationTypes == null) - { - throw new ArgumentNullException("authenticationTypes"); - } AuthenticationTypes = authenticationTypes; _accepted = new List(); } diff --git a/src/Microsoft.AspNet.WebUtilities/BufferedReadStream.cs b/src/Microsoft.AspNet.WebUtilities/BufferedReadStream.cs new file mode 100644 index 0000000000..944f126b93 --- /dev/null +++ b/src/Microsoft.AspNet.WebUtilities/BufferedReadStream.cs @@ -0,0 +1,396 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebUtilities +{ + internal class BufferedReadStream : Stream + { + private const char CR = '\r'; + private const char LF = '\n'; + + private readonly Stream _inner; + private readonly byte[] _buffer; + private int _bufferOffset = 0; + private int _bufferCount = 0; + private bool _disposed; + + public BufferedReadStream([NotNull] Stream inner, int bufferSize) + { + _inner = inner; + _buffer = new byte[bufferSize]; + } + + public ArraySegment BufferedData + { + get { return new ArraySegment(_buffer, _bufferOffset, _bufferCount); } + } + + public override bool CanRead + { + get { return _inner.CanRead || _bufferCount > 0; } + } + + public override bool CanSeek + { + get { return _inner.CanSeek; } + } + + public override bool CanTimeout + { + get { return _inner.CanTimeout; } + } + + public override bool CanWrite + { + get { return _inner.CanWrite; } + } + + public override long Length + { + get { return _inner.Length; } + } + + public override long Position + { + get { return _inner.Position - _bufferCount; } + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException("value", value, "Position must be positive."); + } + if (value == Position) + { + return; + } + + // Backwards? + if (value <= _inner.Position) + { + // Forward within the buffer? + var innerOffset = (int)(_inner.Position - value); + if (innerOffset <= _bufferCount) + { + // Yes, just skip some of the buffered data + _bufferOffset += innerOffset; + _bufferCount -= innerOffset; + } + else + { + // No, reset the buffer + _bufferOffset = 0; + _bufferCount = 0; + _inner.Position = value; + } + } + else + { + // Forward, reset the buffer + _bufferOffset = 0; + _bufferCount = 0; + _inner.Position = value; + } + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + if (origin == SeekOrigin.Begin) + { + Position = offset; + } + else if (origin == SeekOrigin.Current) + { + Position = Position + offset; + } + else // if (origin == SeekOrigin.End) + { + Position = Length + offset; + } + return Position; + } + + public override void SetLength(long value) + { + _inner.SetLength(value); + } + + protected override void Dispose(bool disposing) + { + _disposed = true; + if (disposing) + { + _inner.Dispose(); + } + } + + public override void Flush() + { + _inner.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _inner.FlushAsync(cancellationToken); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _inner.Write(buffer, offset, count); + } +#if ASPNET50 + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _inner.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _inner.EndWrite(asyncResult); + } +#endif + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _inner.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override int Read(byte[] buffer, int offset, int count) + { + ValidateBuffer(buffer, offset, count); + + // Drain buffer + if (_bufferCount > 0) + { + int toCopy = Math.Min(_bufferCount, count); + Buffer.BlockCopy(_buffer, _bufferOffset, buffer, offset, toCopy); + _bufferOffset += toCopy; + _bufferCount -= toCopy; + return toCopy; + } + + return _inner.Read(buffer, offset, count); + } + + public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ValidateBuffer(buffer, offset, count); + + // Drain buffer + if (_bufferCount > 0) + { + int toCopy = Math.Min(_bufferCount, count); + Buffer.BlockCopy(_buffer, _bufferOffset, buffer, offset, toCopy); + _bufferOffset += toCopy; + _bufferCount -= toCopy; + return toCopy; + } + + return await _inner.ReadAsync(buffer, offset, count, cancellationToken); + } +#if ASPNET50 + // We only anticipate using ReadAsync + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + ValidateBuffer(buffer, offset, count); + + // Drain buffer + if (_bufferCount > 0) + { + int toCopy = Math.Min(_bufferCount, count); + Buffer.BlockCopy(_buffer, _bufferOffset, buffer, offset, toCopy); + _bufferOffset += toCopy; + _bufferCount -= toCopy; + + TaskCompletionSource tcs = new TaskCompletionSource(state); + tcs.TrySetResult(toCopy); + if (callback != null) + { + callback(tcs.Task); + } + return tcs.Task; + } + + return _inner.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + Task task = asyncResult as Task; + if (task != null) + { + return task.GetAwaiter().GetResult(); + } + return _inner.EndRead(asyncResult); + } +#endif + public bool EnsureBuffered() + { + if (_bufferCount > 0) + { + return true; + } + // Downshift to make room + _bufferOffset = 0; + _bufferCount = _inner.Read(_buffer, 0, _buffer.Length); + return _bufferCount > 0; + } + + public async Task EnsureBufferedAsync(CancellationToken cancellationToken) + { + if (_bufferCount > 0) + { + return true; + } + // Downshift to make room + _bufferOffset = 0; + _bufferCount = await _inner.ReadAsync(_buffer, 0, _buffer.Length, cancellationToken); + return _bufferCount > 0; + } + + public bool EnsureBuffered(int minCount) + { + if (minCount > _buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(minCount), minCount, "The value must be smaller than the buffer size: " + _buffer.Length); + } + while (_bufferCount < minCount) + { + // Downshift to make room + if (_bufferOffset > 0) + { + if (_bufferCount > 0) + { + Buffer.BlockCopy(_buffer, _bufferOffset, _buffer, 0, _bufferCount); + } + _bufferOffset = 0; + } + int read = _inner.Read(_buffer, _bufferOffset + _bufferCount, _buffer.Length - _bufferCount - _bufferOffset); + _bufferCount += read; + if (read == 0) + { + return false; + } + } + return true; + } + + public async Task EnsureBufferedAsync(int minCount, CancellationToken cancellationToken) + { + if (minCount > _buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(minCount), minCount, "The value must be smaller than the buffer size: " + _buffer.Length); + } + while (_bufferCount < minCount) + { + // Downshift to make room + if (_bufferOffset > 0) + { + if (_bufferCount > 0) + { + Buffer.BlockCopy(_buffer, _bufferOffset, _buffer, 0, _bufferCount); + } + _bufferOffset = 0; + } + int read = await _inner.ReadAsync(_buffer, _bufferOffset + _bufferCount, _buffer.Length - _bufferCount - _bufferOffset, cancellationToken); + _bufferCount += read; + if (read == 0) + { + return false; + } + } + return true; + } + + public string ReadLine(int lengthLimit) + { + CheckDisposed(); + StringBuilder builder = new StringBuilder(); + bool foundCR = false, foundCRLF = false; + while (!foundCRLF && EnsureBuffered()) + { + if (builder.Length > lengthLimit) + { + throw new InvalidOperationException("Line length limit exceeded: " + lengthLimit); + } + ProcessLineChar(builder, ref foundCR, ref foundCRLF); + } + + if (foundCRLF) + { + return builder.ToString(0, builder.Length - 2); // Drop the CRLF + } + // Stream ended with no CRLF. + return builder.ToString(); + } + + public async Task ReadLineAsync(int lengthLimit, CancellationToken cancellationToken) + { + CheckDisposed(); + StringBuilder builder = new StringBuilder(); + bool foundCR = false, foundCRLF = false; + while (!foundCRLF && await EnsureBufferedAsync(cancellationToken)) + { + if (builder.Length > lengthLimit) + { + throw new InvalidOperationException("Line length limit exceeded: " + lengthLimit); + } + + ProcessLineChar(builder, ref foundCR, ref foundCRLF); + } + + if (foundCRLF) + { + return builder.ToString(0, builder.Length - 2); // Drop the CRLF + } + // Stream ended with no CRLF. + return builder.ToString(); + } + + private void ProcessLineChar(StringBuilder builder, ref bool foundCR, ref bool foundCRLF) + { + char ch = (char)_buffer[_bufferOffset]; // TODO: Encoding enforcement + builder.Append(ch); + _bufferOffset++; + _bufferCount--; + if (ch == CR) + { + foundCR = true; + } + else if (ch == LF) + { + if (foundCR) + { + foundCRLF = true; + } + else + { + foundCR = false; + } + } + } + + private void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(BufferedReadStream)); + } + } + + private void ValidateBuffer(byte[] buffer, int offset, int count) + { + // Delegate most of our validation. + var ignored = new ArraySegment(buffer, offset, count); + if (count == 0) + { + throw new ArgumentOutOfRangeException(nameof(count), "The value must be greater than zero."); + } + } + } +} diff --git a/src/Microsoft.AspNet.WebUtilities/FileBufferingReadStream.cs b/src/Microsoft.AspNet.WebUtilities/FileBufferingReadStream.cs new file mode 100644 index 0000000000..2f23ab5247 --- /dev/null +++ b/src/Microsoft.AspNet.WebUtilities/FileBufferingReadStream.cs @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebUtilities +{ + /// + /// A Stream that wraps another stream and enables rewinding by buffering the content as it is read. + /// The content is buffered in memory up to a certain size and then spooled to a temp file on disk. + /// The temp file will be deleted on Dispose. + /// + public class FileBufferingReadStream : Stream + { + private readonly Stream _inner; + private readonly int _memoryThreshold; + private readonly string _tempFileDirectory; + + private Stream _buffer = new MemoryStream(); // TODO: We could have a more efficiently expanding buffer stream. + private bool _inMemory = true; + private bool _completelyBuffered; + + private bool _disposed; + + // TODO: allow for an optional buffer size limit to prevent filling hard disks. 1gb? + public FileBufferingReadStream([NotNull] Stream inner, int memoryThreshold, [NotNull] string tempFileDirectory) + { + _inner = inner; + _memoryThreshold = memoryThreshold; + _tempFileDirectory = tempFileDirectory; + } + + public override bool CanRead + { + get { return true; } + } + + public override bool CanSeek + { + get { return true; } + } + + public override bool CanWrite + { + get { return false; } + } + + public override long Length + { + get { return _buffer.Length; } + } + + public override long Position + { + get { return _buffer.Position; } + // Note this will not allow seeking forward beyond the end of the buffer. + set + { + ThrowIfDisposed(); + _buffer.Position = value; + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + ThrowIfDisposed(); + if (!_completelyBuffered && origin == SeekOrigin.End) + { + // Can't seek from the end until we've finished consuming the inner stream + throw new NotSupportedException("The content has not been fully buffered yet."); + } + else if (!_completelyBuffered && origin == SeekOrigin.Current && offset + Position > Length) + { + // Can't seek past the end of the buffer until we've finished consuming the inner stream + throw new NotSupportedException("The content has not been fully buffered yet."); + } + else if (!_completelyBuffered && origin == SeekOrigin.Begin && offset > Length) + { + // Can't seek past the end of the buffer until we've finished consuming the inner stream + throw new NotSupportedException("The content has not been fully buffered yet."); + } + return _buffer.Seek(offset, origin); + } + + private Stream CreateTempFile() + { + var fileName = Path.Combine(_tempFileDirectory, "ASPNET_" + Guid.NewGuid().ToString() + ".tmp"); + return new FileStream(fileName, FileMode.Create, FileAccess.ReadWrite, FileShare.Delete, 1024 * 16, + FileOptions.Asynchronous | FileOptions.DeleteOnClose | FileOptions.SequentialScan); + } + + public override int Read(byte[] buffer, int offset, int count) + { + ThrowIfDisposed(); + if (_buffer.Position < _buffer.Length || _completelyBuffered) + { + // Just read from the buffer + return _buffer.Read(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position)); + } + + int read = _inner.Read(buffer, offset, count); + + if (_inMemory && _buffer.Length + read > _memoryThreshold) + { + var oldBuffer = _buffer; + _buffer = CreateTempFile(); + _inMemory = false; + oldBuffer.Position = 0; + oldBuffer.CopyTo(_buffer, 1024 * 16); + } + + if (read > 0) + { + _buffer.Write(buffer, offset, read); + } + else + { + _completelyBuffered = true; + } + + return read; + } +#if ASPNET50 + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + ThrowIfDisposed(); + var tcs = new TaskCompletionSource(state); + BeginRead(buffer, offset, count, callback, tcs); + return tcs.Task; + } + + private async void BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, TaskCompletionSource tcs) + { + try + { + var read = await ReadAsync(buffer, offset, count); + tcs.TrySetResult(read); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + } + + if (callback != null) + { + try + { + callback(tcs.Task); + } + catch (Exception) + { + // Suppress exceptions on background threads. + } + } + } + + public override int EndRead(IAsyncResult asyncResult) + { + var task = (Task)asyncResult; + return task.GetAwaiter().GetResult(); + } +#endif + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + ThrowIfDisposed(); + if (_buffer.Position < _buffer.Length || _completelyBuffered) + { + // Just read from the buffer + return await _buffer.ReadAsync(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position), cancellationToken); + } + + int read = await _inner.ReadAsync(buffer, offset, count, cancellationToken); + + if (_inMemory && _buffer.Length + read > _memoryThreshold) + { + var oldBuffer = _buffer; + _buffer = CreateTempFile(); + _inMemory = false; + oldBuffer.Position = 0; + await oldBuffer.CopyToAsync(_buffer, 1024 * 16, cancellationToken); + } + + if (read > 0) + { + await _buffer.WriteAsync(buffer, offset, read, cancellationToken); + } + else + { + _completelyBuffered = true; + } + + return read; + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } +#if ASPNET50 + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + throw new NotSupportedException(); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + throw new NotSupportedException(); + } +#endif + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Flush() + { + throw new NotSupportedException(); + } + + protected override void Dispose(bool disposing) + { + if (!_disposed) + { + _disposed = true; + if (disposing) + { + _buffer.Dispose(); + } + } + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(FileBufferingReadStream)); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/FormHelpers.cs b/src/Microsoft.AspNet.WebUtilities/FormHelpers.cs deleted file mode 100644 index 4c16485eb0..0000000000 --- a/src/Microsoft.AspNet.WebUtilities/FormHelpers.cs +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using Microsoft.AspNet.Http; - -namespace Microsoft.AspNet.WebUtilities -{ - public static class FormHelpers - { - /// - /// Parses an HTTP form body. - /// - /// The HTTP form body to parse. - /// The object containing the parsed HTTP form body. - public static IFormCollection ParseForm(string text) - { - return ParsingHelpers.GetForm(text); - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/FormReader.cs b/src/Microsoft.AspNet.WebUtilities/FormReader.cs new file mode 100644 index 0000000000..7c6a034f78 --- /dev/null +++ b/src/Microsoft.AspNet.WebUtilities/FormReader.cs @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebUtilities +{ + /// + /// Used to read an 'application/x-www-form-urlencoded' form. + /// + public class FormReader + { + private readonly TextReader _reader; + private readonly char[] _buffer = new char[1024]; + private readonly StringBuilder _builder = new StringBuilder(); + private int _bufferOffset; + private int _bufferCount; + + public FormReader([NotNull] string data) + { + _reader = new StringReader(data); + } + + // TODO: Encoding + public FormReader([NotNull] Stream stream) + { + _reader = new StreamReader(stream, Encoding.UTF8, detectEncodingFromByteOrderMarks: true, bufferSize: 1024 * 2, leaveOpen: true); + } + + // Format: key1=value1&key2=value2 + /// + /// Reads the next key value pair from the form. + /// For unbuffered data use the async overload instead. + /// + /// The next key value pair, or null when the end of the form is reached. + public KeyValuePair? ReadNextPair() + { + var key = ReadWord('='); + if (string.IsNullOrEmpty(key) && _bufferCount == 0) + { + return null; + } + var value = ReadWord('&'); + return new KeyValuePair(key, value); + } + + // Format: key1=value1&key2=value2 + /// + /// Asynchronously reads the next key value pair from the form. + /// + /// + /// The next key value pair, or null when the end of the form is reached. + public async Task?> ReadNextPairAsync(CancellationToken cancellationToken) + { + var key = await ReadWordAsync('=', cancellationToken); + if (string.IsNullOrEmpty(key) && _bufferCount == 0) + { + return null; + } + var value = await ReadWordAsync('&', cancellationToken); + return new KeyValuePair(key, value); + } + + private string ReadWord(char seperator) + { + // TODO: Configurable value size limit + while (true) + { + // Empty + if (_bufferCount == 0) + { + Buffer(); + } + + // End + if (_bufferCount == 0) + { + return BuildWord(); + } + + var c = _buffer[_bufferOffset++]; + _bufferCount--; + + if (c == seperator) + { + return BuildWord(); + } + _builder.Append(c); + } + } + + private async Task ReadWordAsync(char seperator, CancellationToken cancellationToken) + { + // TODO: Configurable value size limit + while (true) + { + // Empty + if (_bufferCount == 0) + { + await BufferAsync(cancellationToken); + } + + // End + if (_bufferCount == 0) + { + return BuildWord(); + } + + var c = _buffer[_bufferOffset++]; + _bufferCount--; + + if (c == seperator) + { + return BuildWord(); + } + _builder.Append(c); + } + } + + // '+' un-escapes to ' ', %HH un-escapes as ASCII (or utf-8?) + private string BuildWord() + { + _builder.Replace('+', ' '); + var result = _builder.ToString(); + _builder.Clear(); + return Uri.UnescapeDataString(result); // TODO: Replace this, it's not completely accurate. + } + + private void Buffer() + { + _bufferOffset = 0; + _bufferCount = _reader.Read(_buffer, 0, _buffer.Length); + } + + private async Task BufferAsync(CancellationToken cancellationToken) + { + // TODO: StreamReader doesn't support cancellation? + cancellationToken.ThrowIfCancellationRequested(); + _bufferOffset = 0; + _bufferCount = await _reader.ReadAsync(_buffer, 0, _buffer.Length); + } + + /// + /// Parses text from an HTTP form body. + /// + /// The HTTP form body to parse. + /// The collection containing the parsed HTTP form body. + public static IDictionary ReadForm(string text) + { + var reader = new FormReader(text); + + var accumulator = new KeyValueAccumulator(StringComparer.OrdinalIgnoreCase); + var pair = reader.ReadNextPair(); + while (pair.HasValue) + { + accumulator.Append(pair.Value.Key, pair.Value.Value); + pair = reader.ReadNextPair(); + } + + return accumulator.GetResults(); + } + + /// + /// Parses an HTTP form body. + /// + /// The HTTP form body to parse. + /// The collection containing the parsed HTTP form body. + public static async Task> ReadFormAsync(Stream stream, CancellationToken cancellationToken = new CancellationToken()) + { + var reader = new FormReader(stream); + + var accumulator = new KeyValueAccumulator(StringComparer.OrdinalIgnoreCase); + var pair = await reader.ReadNextPairAsync(cancellationToken); + while (pair.HasValue) + { + accumulator.Append(pair.Value.Key, pair.Value.Value); + pair = await reader.ReadNextPairAsync(cancellationToken); + } + + return accumulator.GetResults(); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/KeyValueAccumulator.cs b/src/Microsoft.AspNet.WebUtilities/KeyValueAccumulator.cs new file mode 100644 index 0000000000..4c9d629a47 --- /dev/null +++ b/src/Microsoft.AspNet.WebUtilities/KeyValueAccumulator.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace Microsoft.AspNet.WebUtilities +{ + public class KeyValueAccumulator + { + private Dictionary> _accumulator; + IEqualityComparer _comparer; + + public KeyValueAccumulator([NotNull] IEqualityComparer comparer) + { + _comparer = comparer; + _accumulator = new Dictionary>(comparer); + } + + public void Append(TKey key, TValue value) + { + List values; + if (_accumulator.TryGetValue(key, out values)) + { + values.Add(value); + } + else + { + _accumulator[key] = new List(1) { value }; + } + } + + public IDictionary GetResults() + { + var results = new Dictionary(_comparer); + foreach (var kv in _accumulator) + { + results.Add(kv.Key, kv.Value.ToArray()); + } + return results; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/MultipartReader.cs b/src/Microsoft.AspNet.WebUtilities/MultipartReader.cs new file mode 100644 index 0000000000..766e5361dc --- /dev/null +++ b/src/Microsoft.AspNet.WebUtilities/MultipartReader.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebUtilities +{ + // https://www.ietf.org/rfc/rfc2046.txt + public class MultipartReader + { + private const int DefaultBufferSize = 1024 * 4; + + private readonly BufferedReadStream _stream; + private readonly string _boundary; + private MultipartReaderStream _currentStream; + + public MultipartReader([NotNull] string boundary, [NotNull] Stream stream) + : this(boundary, stream, DefaultBufferSize) + { + } + + public MultipartReader([NotNull] string boundary, [NotNull] Stream stream, int bufferSize) + { + if (bufferSize < boundary.Length + 8) // Size of the boundary + leading and trailing CRLF + leading and trailing '--' markers. + { + throw new ArgumentOutOfRangeException(nameof(bufferSize), bufferSize, "Insufficient buffer space, the buffer must be larger than the boundary: " + boundary); + } + _stream = new BufferedReadStream(stream, bufferSize); + _boundary = boundary; + // This stream will drain any preamble data and remove the first boundary marker. + _currentStream = new MultipartReaderStream(_stream, _boundary, expectLeadingCrlf: false); + } + + /// + /// The limit for individual header lines inside a multipart section. + /// + public int HeaderLengthLimit { get; set; } = 1024 * 4; + + /// + /// The combined size limit for headers per multipart section. + /// + public int TotalHeaderSizeLimit { get; set; } = 1024 * 16; + + public async Task ReadNextSectionAsync(CancellationToken cancellationToken = new CancellationToken()) + { + // Drain the prior section. + await _currentStream.DrainAsync(cancellationToken); + // If we're at the end return null + if (_currentStream.FinalBoundaryFound) + { + // There may be trailer data after the last boundary. + await _stream.DrainAsync(cancellationToken); + return null; + } + var headers = await ReadHeadersAsync(cancellationToken); + _currentStream = new MultipartReaderStream(_stream, _boundary); + long? baseStreamOffset = _stream.CanSeek ? (long?)_stream.Position : null; + return new MultipartSection() { Headers = headers, Body = _currentStream, BaseStreamOffset = baseStreamOffset }; + } + + private async Task> ReadHeadersAsync(CancellationToken cancellationToken) + { + int totalSize = 0; + var accumulator = new KeyValueAccumulator(StringComparer.OrdinalIgnoreCase); + var line = await _stream.ReadLineAsync(HeaderLengthLimit, cancellationToken); + while (!string.IsNullOrEmpty(line)) + { + totalSize += line.Length; + if (totalSize > TotalHeaderSizeLimit) + { + throw new InvalidOperationException("Total header size limit exceeded: " + TotalHeaderSizeLimit); + } + int splitIndex = line.IndexOf(':'); + Debug.Assert(splitIndex > 0, "Invalid header line: " + line); + if (splitIndex >= 0) + { + var name = line.Substring(0, splitIndex); + var value = line.Substring(splitIndex + 1, line.Length - splitIndex - 1).Trim(); + accumulator.Append(name, value); + } + line = await _stream.ReadLineAsync(HeaderLengthLimit, cancellationToken); + } + + return accumulator.GetResults(); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/MultipartReaderStream.cs b/src/Microsoft.AspNet.WebUtilities/MultipartReaderStream.cs new file mode 100644 index 0000000000..d72809ef81 --- /dev/null +++ b/src/Microsoft.AspNet.WebUtilities/MultipartReaderStream.cs @@ -0,0 +1,320 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebUtilities +{ + internal class MultipartReaderStream : Stream + { + private readonly BufferedReadStream _innerStream; + private readonly byte[] _boundaryBytes; + private readonly int _finalBoundaryLength; + private readonly long _innerOffset; + private long _position; + private long _observedLength; + private bool _finished; + + /// + /// Creates a stream that reads until it reaches the given boundary pattern. + /// + /// + /// + public MultipartReaderStream([NotNull] BufferedReadStream stream, [NotNull] string boundary, bool expectLeadingCrlf = true) + { + _innerStream = stream; + _innerOffset = _innerStream.CanSeek ? _innerStream.Position : 0; + if (expectLeadingCrlf) + { + _boundaryBytes = Encoding.UTF8.GetBytes("\r\n--" + boundary); + } + else + { + _boundaryBytes = Encoding.UTF8.GetBytes("--" + boundary); + } + _finalBoundaryLength = _boundaryBytes.Length + 2; // Include the final '--' terminator. + } + + public bool FinalBoundaryFound { get; private set; } + + public override bool CanRead + { + get { return true; } + } + + public override bool CanSeek + { + get { return _innerStream.CanSeek; } + } + + public override bool CanWrite + { + get { return false; } + } + + public override long Length + { + get { return _observedLength; } + } + + public override long Position + { + get { return _position; } + set + { + if (value < 0) + { + throw new ArgumentOutOfRangeException("value", value, "The Position must be positive."); + } + if (value > _observedLength) + { + throw new ArgumentOutOfRangeException("value", value, "The Position must be less than length."); + } + _position = value; + if (_position < _observedLength) + { + _finished = false; + } + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + if (origin == SeekOrigin.Begin) + { + Position = offset; + } + else if (origin == SeekOrigin.Current) + { + Position = Position + offset; + } + else // if (origin == SeekOrigin.End) + { + Position = Length + offset; + } + return Position; + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } +#if ASPNET50 + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) + { + throw new NotSupportedException(); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + throw new NotSupportedException(); + } +#endif + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public override void Flush() + { + throw new NotSupportedException(); + } + + private void PositionInnerStream() + { + if (_innerStream.CanSeek && _innerStream.Position != (_innerOffset + _position)) + { + _innerStream.Position = _innerOffset + _position; + } + } + + private int UpdatePosition(int read) + { + _position += read; + if (_observedLength < _position) + { + _observedLength = _position; + } + return read; + } +#if ASPNET50 + public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) + { + var tcs = new TaskCompletionSource(state); + InternalReadAsync(buffer, offset, size, callback, tcs); + return tcs.Task; + } + + private async void InternalReadAsync(byte[] buffer, int offset, int size, AsyncCallback callback, TaskCompletionSource tcs) + { + try + { + int read = await ReadAsync(buffer, offset, size); + tcs.TrySetResult(read); + } + catch (Exception ex) + { + tcs.TrySetException(ex); + } + + if (callback != null) + { + try + { + callback(tcs.Task); + } + catch (Exception) + { + // Suppress exceptions on background threads. + } + } + } + + public override int EndRead(IAsyncResult asyncResult) + { + var task = (Task)asyncResult; + return task.GetAwaiter().GetResult(); + } +#endif + public override int Read(byte[] buffer, int offset, int count) + { + if (_finished) + { + return 0; + } + + PositionInnerStream(); + if (!_innerStream.EnsureBuffered(_finalBoundaryLength)) + { + throw new IOException("Unexpected end of stream."); + } + var bufferedData = _innerStream.BufferedData; + + // scan for a boundary match, full or partial. + int matchOffset; + int matchCount; + int read; + if (SubMatch(bufferedData, _boundaryBytes, out matchOffset, out matchCount)) + { + // We found a possible match, return any data before it. + if (matchOffset > bufferedData.Offset) + { + read = _innerStream.Read(buffer, offset, Math.Min(count, matchOffset - bufferedData.Offset)); + return UpdatePosition(read); + } + Debug.Assert(matchCount == _boundaryBytes.Length); + + // "The boundary may be followed by zero or more characters of + // linear whitespace. It is then terminated by either another CRLF" + // or -- for the final boundary. + byte[] boundary = new byte[_boundaryBytes.Length]; + read = _innerStream.Read(boundary, 0, boundary.Length); + Debug.Assert(read == boundary.Length); // It should have all been buffered + var remainder = _innerStream.ReadLine(lengthLimit: 100); // Whitespace may exceed the buffer. + remainder = remainder.Trim(); + if (string.Equals("--", remainder, StringComparison.Ordinal)) + { + FinalBoundaryFound = true; + } + Debug.Assert(FinalBoundaryFound || string.Equals(string.Empty, remainder, StringComparison.Ordinal), "Un-expected data found on the boundary line: " + remainder); + + _finished = true; + return 0; + } + + // No possible boundary match within the buffered data, return the data from the buffer. + read = _innerStream.Read(buffer, offset, Math.Min(count, bufferedData.Count)); + return UpdatePosition(read); + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (_finished) + { + return 0; + } + + PositionInnerStream(); + if (!await _innerStream.EnsureBufferedAsync(_finalBoundaryLength, cancellationToken)) + { + throw new IOException("Unexpected end of stream."); + } + var bufferedData = _innerStream.BufferedData; + + // scan for a boundary match, full or partial. + int matchOffset; + int matchCount; + int read; + if (SubMatch(bufferedData, _boundaryBytes, out matchOffset, out matchCount)) + { + // We found a possible match, return any data before it. + if (matchOffset > bufferedData.Offset) + { + // Sync, it's already buffered + read = _innerStream.Read(buffer, offset, Math.Min(count, matchOffset - bufferedData.Offset)); + return UpdatePosition(read); + } + Debug.Assert(matchCount == _boundaryBytes.Length); + + // "The boundary may be followed by zero or more characters of + // linear whitespace. It is then terminated by either another CRLF" + // or -- for the final boundary. + byte[] boundary = new byte[_boundaryBytes.Length]; + read = _innerStream.Read(boundary, 0, boundary.Length); + Debug.Assert(read == boundary.Length); // It should have all been buffered + var remainder = await _innerStream.ReadLineAsync(lengthLimit: 100, cancellationToken: cancellationToken); // Whitespace may exceed the buffer. + remainder = remainder.Trim(); + if (string.Equals("--", remainder, StringComparison.Ordinal)) + { + FinalBoundaryFound = true; + } + Debug.Assert(FinalBoundaryFound || string.Equals(string.Empty, remainder, StringComparison.Ordinal), "Un-expected data found on the boundary line: " + remainder); + + _finished = true; + return 0; + } + + // No possible boundary match within the buffered data, return the data from the buffer. + read = _innerStream.Read(buffer, offset, Math.Min(count, bufferedData.Count)); + return UpdatePosition(read); + } + + // Does Segment1 contain all of segment2, or does it end with the start of segment2? + // 1: AAAAABBBBBCCCCC + // 2: BBBBB + // Or: + // 1: AAAAABBB + // 2: BBBBB + private static bool SubMatch(ArraySegment segment1, byte[] matchBytes, out int matchOffset, out int matchCount) + { + matchCount = 0; + for (matchOffset = segment1.Offset; matchOffset < segment1.Offset + segment1.Count; matchOffset++) + { + int countLimit = segment1.Offset - matchOffset + segment1.Count; + for (matchCount = 0; matchCount < matchBytes.Length && matchCount < countLimit; matchCount++) + { + if (matchBytes[matchCount] != segment1.Array[matchOffset + matchCount]) + { + matchCount = 0; + break; + } + } + if (matchCount > 0) + { + break; + } + } + return matchCount > 0; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/MultipartSection.cs b/src/Microsoft.AspNet.WebUtilities/MultipartSection.cs new file mode 100644 index 0000000000..b35bcfee2a --- /dev/null +++ b/src/Microsoft.AspNet.WebUtilities/MultipartSection.cs @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.IO; + +namespace Microsoft.AspNet.WebUtilities +{ + public class MultipartSection + { + public string ContentType + { + get + { + string[] values; + if (Headers.TryGetValue("Content-Type", out values)) + { + return string.Join(", ", values); + } + return null; + } + } + + public string ContentDisposition + { + get + { + string[] values; + if (Headers.TryGetValue("Content-Disposition", out values)) + { + return string.Join(", ", values); + } + return null; + } + } + + public IDictionary Headers { get; set; } + + public Stream Body { get; set; } + + /// + /// The position where the body starts in the total multipart body. + /// This may not be available if the total multipart body is not seekable. + /// + public long? BaseStreamOffset { get; set; } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/ParsingHelpers.cs b/src/Microsoft.AspNet.WebUtilities/ParsingHelpers.cs deleted file mode 100644 index 257b206c6b..0000000000 --- a/src/Microsoft.AspNet.WebUtilities/ParsingHelpers.cs +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.AspNet.Http; -using Microsoft.AspNet.WebUtilities.Collections; - -namespace Microsoft.AspNet.WebUtilities -{ - internal static class ParsingHelpers - { - internal static void ParseDelimited(string text, char[] delimiters, Action callback, object state) - { - int textLength = text.Length; - int equalIndex = text.IndexOf('='); - if (equalIndex == -1) - { - equalIndex = textLength; - } - int scanIndex = 0; - while (scanIndex < textLength) - { - int delimiterIndex = text.IndexOfAny(delimiters, scanIndex); - if (delimiterIndex == -1) - { - delimiterIndex = textLength; - } - if (equalIndex < delimiterIndex) - { - while (scanIndex != equalIndex && char.IsWhiteSpace(text[scanIndex])) - { - ++scanIndex; - } - string name = text.Substring(scanIndex, equalIndex - scanIndex); - string value = text.Substring(equalIndex + 1, delimiterIndex - equalIndex - 1); - callback( - Uri.UnescapeDataString(name.Replace('+', ' ')), - Uri.UnescapeDataString(value.Replace('+', ' ')), - state); - equalIndex = text.IndexOf('=', delimiterIndex); - if (equalIndex == -1) - { - equalIndex = textLength; - } - } - scanIndex = delimiterIndex + 1; - } - } - - private static readonly Action AppendItemCallback = (name, value, state) => - { - var dictionary = (IDictionary>)state; - - List existing; - if (!dictionary.TryGetValue(name, out existing)) - { - dictionary.Add(name, new List(1) { value }); - } - else - { - existing.Add(value); - } - }; - - internal static IFormCollection GetForm(string text) - { - IDictionary form = new Dictionary(StringComparer.OrdinalIgnoreCase); - var accumulator = new Dictionary>(StringComparer.OrdinalIgnoreCase); - ParseDelimited(text, Ampersand, AppendItemCallback, accumulator); - foreach (var kv in accumulator) - { - form.Add(kv.Key, kv.Value.ToArray()); - } - return new FormCollection(form); - } - - internal static string GetJoinedValue(IDictionary store, string key) - { - string[] values = GetUnmodifiedValues(store, key); - return values == null ? null : string.Join(",", values); - } - - internal static string[] GetUnmodifiedValues(IDictionary store, string key) - { - if (store == null) - { - throw new ArgumentNullException("store"); - } - string[] values; - return store.TryGetValue(key, out values) ? values : null; - } - - private static readonly char[] Ampersand = new[] { '&' }; - - internal static IReadableStringCollection GetQuery(string queryString) - { - if (!string.IsNullOrEmpty(queryString) && queryString[0] == '?') - { - queryString = queryString.Substring(1); - } - var accumulator = new Dictionary>(StringComparer.OrdinalIgnoreCase); - ParseDelimited(queryString, Ampersand, AppendItemCallback, accumulator); - return new ReadableStringCollection(accumulator.ToDictionary( - item => item.Key, - item => item.Value.ToArray(), - StringComparer.OrdinalIgnoreCase)); - } - } -} diff --git a/src/Microsoft.AspNet.WebUtilities/QueryHelpers.cs b/src/Microsoft.AspNet.WebUtilities/QueryHelpers.cs index f54ff50dea..1414ef8f17 100644 --- a/src/Microsoft.AspNet.WebUtilities/QueryHelpers.cs +++ b/src/Microsoft.AspNet.WebUtilities/QueryHelpers.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Text; -using Microsoft.AspNet.Http; namespace Microsoft.AspNet.WebUtilities { @@ -50,9 +49,49 @@ namespace Microsoft.AspNet.WebUtilities /// /// The raw query string value, with or without the leading '?'. /// A collection of parsed keys and values. - public static IReadableStringCollection ParseQuery(string text) + public static IDictionary ParseQuery(string queryString) { - return ParsingHelpers.GetQuery(text); + if (!string.IsNullOrEmpty(queryString) && queryString[0] == '?') + { + queryString = queryString.Substring(1); + } + var accumulator = new KeyValueAccumulator(StringComparer.OrdinalIgnoreCase); + + int textLength = queryString.Length; + int equalIndex = queryString.IndexOf('='); + if (equalIndex == -1) + { + equalIndex = textLength; + } + int scanIndex = 0; + while (scanIndex < textLength) + { + int delimiterIndex = queryString.IndexOf('&', scanIndex); + if (delimiterIndex == -1) + { + delimiterIndex = textLength; + } + if (equalIndex < delimiterIndex) + { + while (scanIndex != equalIndex && char.IsWhiteSpace(queryString[scanIndex])) + { + ++scanIndex; + } + string name = queryString.Substring(scanIndex, equalIndex - scanIndex); + string value = queryString.Substring(equalIndex + 1, delimiterIndex - equalIndex - 1); + accumulator.Append( + Uri.UnescapeDataString(name.Replace('+', ' ')), + Uri.UnescapeDataString(value.Replace('+', ' '))); + equalIndex = queryString.IndexOf('=', delimiterIndex); + if (equalIndex == -1) + { + equalIndex = textLength; + } + } + scanIndex = delimiterIndex + 1; + } + + return accumulator.GetResults(); } } } \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/StreamHelperExtensions.cs b/src/Microsoft.AspNet.WebUtilities/StreamHelperExtensions.cs new file mode 100644 index 0000000000..c5a2432db6 --- /dev/null +++ b/src/Microsoft.AspNet.WebUtilities/StreamHelperExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebUtilities +{ + public static class StreamHelperExtensions + { + public static async Task DrainAsync(this Stream stream, CancellationToken cancellationToken) + { + byte[] buffer = new byte[1024]; + cancellationToken.ThrowIfCancellationRequested(); + while (await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken) > 0) + { + // Not all streams support cancellation directly. + cancellationToken.ThrowIfCancellationRequested(); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebUtilities/project.json b/src/Microsoft.AspNet.WebUtilities/project.json index 0fb71d0a56..963fcc3178 100644 --- a/src/Microsoft.AspNet.WebUtilities/project.json +++ b/src/Microsoft.AspNet.WebUtilities/project.json @@ -9,6 +9,7 @@ "aspnetcore50": { "dependencies": { "System.Diagnostics.Debug": "4.0.10-beta-*", + "System.IO.FileSystem": "4.0.0-beta-*", "System.Runtime": "4.0.20-beta-*" } } diff --git a/test/Microsoft.AspNet.PipelineCore.Tests/FormFeatureTests.cs b/test/Microsoft.AspNet.PipelineCore.Tests/FormFeatureTests.cs index b7fdc8f563..b129adc0e1 100644 --- a/test/Microsoft.AspNet.PipelineCore.Tests/FormFeatureTests.cs +++ b/test/Microsoft.AspNet.PipelineCore.Tests/FormFeatureTests.cs @@ -1,72 +1,266 @@ // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.IO; +using System.Linq; using System.Text; -using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNet.FeatureModel; -using Microsoft.AspNet.HttpFeature; -using Moq; +using Microsoft.AspNet.WebUtilities; using Xunit; -namespace Microsoft.AspNet.PipelineCore.Tests +namespace Microsoft.AspNet.PipelineCore { public class FormFeatureTests { [Fact] - public async Task GetFormAsync_ReturnsParsedFormCollection() + public async Task ReadFormAsync_SimpleData_ReturnsParsedFormCollection() { // Arrange var formContent = Encoding.UTF8.GetBytes("foo=bar&baz=2"); - var features = new Mock(); - var request = new Mock(); - request.SetupGet(r => r.Body).Returns(new MemoryStream(formContent)); + var context = new DefaultHttpContext(); + context.Request.ContentType = "application/x-www-form-urlencoded; charset=utf-8"; + context.Request.Body = new MemoryStream(formContent); - object value = request.Object; - features.Setup(f => f.TryGetValue(typeof(IHttpRequestFeature), out value)) - .Returns(true); - - var provider = new FormFeature(features.Object); + // Not cached yet + var formFeature = context.GetFeature(); + Assert.Null(formFeature); // Act - var formCollection = await provider.GetFormAsync(CancellationToken.None); + var formCollection = await context.Request.ReadFormAsync(); // Assert Assert.Equal("bar", formCollection["foo"]); Assert.Equal("2", formCollection["baz"]); + + // Cached + formFeature = context.GetFeature(); + Assert.NotNull(formFeature); + Assert.NotNull(formFeature.Form); + Assert.Same(formFeature.Form, formCollection); } [Fact] - public async Task GetFormAsync_CachesFormCollectionPerBodyStream() + public async Task ReadFormAsync_EmptyKeyAtEndAllowed() { // Arrange - var formContent1 = Encoding.UTF8.GetBytes("foo=bar&baz=2"); - var formContent2 = Encoding.UTF8.GetBytes("collection2=value"); - var features = new Mock(); - var request = new Mock(); - request.SetupGet(r => r.Body).Returns(new MemoryStream(formContent1)); + var formContent = Encoding.UTF8.GetBytes("=bar"); + var body = new MemoryStream(formContent); - object value = request.Object; - features.Setup(f => f.TryGetValue(typeof(IHttpRequestFeature), out value)) - .Returns(true); + var formCollection = await FormReader.ReadFormAsync(body); - var provider = new FormFeature(features.Object); + // Assert + Assert.Equal("bar", formCollection[""].FirstOrDefault()); + } - // Act - 1 - var formCollection = await provider.GetFormAsync(CancellationToken.None); + [Fact] + public async Task ReadFormAsync_EmptyKeyWithAdditionalEntryAllowed() + { + // Arrange + var formContent = Encoding.UTF8.GetBytes("=bar&baz=2"); + var body = new MemoryStream(formContent); - // Assert - 1 - Assert.Equal("bar", formCollection["foo"]); - Assert.Equal("2", formCollection["baz"]); - Assert.Same(formCollection, await provider.GetFormAsync(CancellationToken.None)); + var formCollection = await FormReader.ReadFormAsync(body); - // Act - 2 - request.SetupGet(r => r.Body).Returns(new MemoryStream(formContent2)); - formCollection = await provider.GetFormAsync(CancellationToken.None); + // Assert + Assert.Equal("bar", formCollection[""].FirstOrDefault()); + Assert.Equal("2", formCollection["baz"].FirstOrDefault()); + } - // Assert - 2 - Assert.Equal("value", formCollection["collection2"]); + [Fact] + public async Task ReadFormAsync_EmptyValuedAtEndAllowed() + { + // Arrange + var formContent = Encoding.UTF8.GetBytes("foo="); + var body = new MemoryStream(formContent); + + var formCollection = await FormReader.ReadFormAsync(body); + + // Assert + Assert.Equal("", formCollection["foo"].FirstOrDefault()); + } + + [Fact] + public async Task ReadFormAsync_EmptyValuedWithAdditionalEntryAllowed() + { + // Arrange + var formContent = Encoding.UTF8.GetBytes("foo=&baz=2"); + var body = new MemoryStream(formContent); + + var formCollection = await FormReader.ReadFormAsync(body); + + // Assert + Assert.Equal("", formCollection["foo"].FirstOrDefault()); + Assert.Equal("2", formCollection["baz"].FirstOrDefault()); + } + + private const string MultipartContentType = "multipart/form-data; boundary=WebKitFormBoundary5pDRpGheQXaM8k3T"; + private const string EmptyMultipartForm = +@"--WebKitFormBoundary5pDRpGheQXaM8k3T--"; + private const string MultipartFormWithField = +@"--WebKitFormBoundary5pDRpGheQXaM8k3T +Content-Disposition: form-data; name=""description"" + +Foo +--WebKitFormBoundary5pDRpGheQXaM8k3T--"; + private const string MultipartFormWithFile = +@"--WebKitFormBoundary5pDRpGheQXaM8k3T +Content-Disposition: form-data; name=""myfile1""; filename=""temp.html"" +Content-Type: text/html + +Hello World +--WebKitFormBoundary5pDRpGheQXaM8k3T--"; + private const string MultipartFormWithFieldAndFile = +@"--WebKitFormBoundary5pDRpGheQXaM8k3T +Content-Disposition: form-data; name=""description"" + +Foo +--WebKitFormBoundary5pDRpGheQXaM8k3T +Content-Disposition: form-data; name=""myfile1""; filename=""temp.html"" +Content-Type: text/html + +Hello World +--WebKitFormBoundary5pDRpGheQXaM8k3T--"; + + [Fact] + public async Task ReadForm_EmptyMultipart_ReturnsParsedFormCollection() + { + var formContent = Encoding.UTF8.GetBytes(EmptyMultipartForm); + var context = new DefaultHttpContext(); + context.Request.ContentType = MultipartContentType; + context.Request.Body = new MemoryStream(formContent); + + // Not cached yet + var formFeature = context.GetFeature(); + Assert.Null(formFeature); + + var formCollection = context.Request.Form; + + Assert.NotNull(formCollection); + + // Cached + formFeature = context.GetFeature(); + Assert.NotNull(formFeature); + Assert.NotNull(formFeature.Form); + Assert.Same(formCollection, formFeature.Form); + Assert.Same(formCollection, await context.Request.ReadFormAsync()); + + // Content + Assert.Equal(0, formCollection.Count); + Assert.NotNull(formCollection.Files); + Assert.Equal(0, formCollection.Files.Count); + } + + [Fact] + public async Task ReadForm_MultipartWithField_ReturnsParsedFormCollection() + { + var formContent = Encoding.UTF8.GetBytes(MultipartFormWithField); + var context = new DefaultHttpContext(); + context.Request.ContentType = MultipartContentType; + context.Request.Body = new MemoryStream(formContent); + + // Not cached yet + var formFeature = context.GetFeature(); + Assert.Null(formFeature); + + var formCollection = context.Request.Form; + + Assert.NotNull(formCollection); + + // Cached + formFeature = context.GetFeature(); + Assert.NotNull(formFeature); + Assert.NotNull(formFeature.Form); + Assert.Same(formCollection, formFeature.Form); + Assert.Same(formCollection, await context.Request.ReadFormAsync()); + + // Content + Assert.Equal(1, formCollection.Count); + Assert.Equal("Foo", formCollection["description"]); + + Assert.NotNull(formCollection.Files); + Assert.Equal(0, formCollection.Files.Count); + } + + [Fact] + public async Task ReadFormAsync_MultipartWithFile_ReturnsParsedFormCollection() + { + var formContent = Encoding.UTF8.GetBytes(MultipartFormWithFile); + var context = new DefaultHttpContext(); + context.Request.ContentType = MultipartContentType; + context.Request.Body = new MemoryStream(formContent); + + // Not cached yet + var formFeature = context.GetFeature(); + Assert.Null(formFeature); + + var formCollection = await context.Request.ReadFormAsync(); + + Assert.NotNull(formCollection); + + // Cached + formFeature = context.GetFeature(); + Assert.NotNull(formFeature); + Assert.NotNull(formFeature.Form); + Assert.Same(formFeature.Form, formCollection); + Assert.Same(formCollection, context.Request.Form); + + // Content + Assert.Equal(0, formCollection.Count); + + Assert.NotNull(formCollection.Files); + Assert.Equal(1, formCollection.Files.Count); + + var file = formCollection.Files["myfile1"]; + Assert.Equal("text/html", file.ContentType); + Assert.Equal(@"form-data; name=""myfile1""; filename=""temp.html""", file.ContentDisposition); + var body = file.OpenReadStream(); + using (var reader = new StreamReader(body)) + { + var content = reader.ReadToEnd(); + Assert.Equal(content, "Hello World"); + } + } + + [Fact] + public async Task ReadFormAsync_MultipartWithFieldAndFile_ReturnsParsedFormCollection() + { + var formContent = Encoding.UTF8.GetBytes(MultipartFormWithFieldAndFile); + var context = new DefaultHttpContext(); + context.Request.ContentType = MultipartContentType; + context.Request.Body = new MemoryStream(formContent); + + // Not cached yet + var formFeature = context.GetFeature(); + Assert.Null(formFeature); + + var formCollection = await context.Request.ReadFormAsync(); + + Assert.NotNull(formCollection); + + // Cached + formFeature = context.GetFeature(); + Assert.NotNull(formFeature); + Assert.NotNull(formFeature.Form); + Assert.Same(formFeature.Form, formCollection); + Assert.Same(formCollection, context.Request.Form); + + // Content + Assert.Equal(1, formCollection.Count); + Assert.Equal("Foo", formCollection["description"]); + + Assert.NotNull(formCollection.Files); + Assert.Equal(1, formCollection.Files.Count); + + var file = formCollection.Files["myfile1"]; + Assert.Equal("text/html", file.ContentType); + Assert.Equal(@"form-data; name=""myfile1""; filename=""temp.html""", file.ContentDisposition); + var body = file.OpenReadStream(); + using (var reader = new StreamReader(body)) + { + var content = reader.ReadToEnd(); + Assert.Equal(content, "Hello World"); + } } } } diff --git a/test/Microsoft.AspNet.WebUtilities.Tests/MultipartReaderTests.cs b/test/Microsoft.AspNet.WebUtilities.Tests/MultipartReaderTests.cs new file mode 100644 index 0000000000..a642511a86 --- /dev/null +++ b/test/Microsoft.AspNet.WebUtilities.Tests/MultipartReaderTests.cs @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.WebUtilities +{ + public class MultipartReaderTests + { + private const string Boundary = "9051914041544843365972754266"; + private const string OnePartBody = +@"--9051914041544843365972754266 +Content-Disposition: form-data; name=""text"" + +text default +--9051914041544843365972754266-- +"; + private const string OnePartBodyWithTrailingWhitespace = +@"--9051914041544843365972754266 +Content-Disposition: form-data; name=""text"" + +text default +--9051914041544843365972754266-- +"; + // It's non-compliant but common to leave off the last CRLF. + private const string OnePartBodyWithoutFinalCRLF = +@"--9051914041544843365972754266 +Content-Disposition: form-data; name=""text"" + +text default +--9051914041544843365972754266--"; + private const string TwoPartBody = +@"--9051914041544843365972754266 +Content-Disposition: form-data; name=""text"" + +text default +--9051914041544843365972754266 +Content-Disposition: form-data; name=""file1""; filename=""a.txt"" +Content-Type: text/plain + +Content of a.txt. + +--9051914041544843365972754266-- +"; + private const string ThreePartBody = +@"--9051914041544843365972754266 +Content-Disposition: form-data; name=""text"" + +text default +--9051914041544843365972754266 +Content-Disposition: form-data; name=""file1""; filename=""a.txt"" +Content-Type: text/plain + +Content of a.txt. + +--9051914041544843365972754266 +Content-Disposition: form-data; name=""file2""; filename=""a.html"" +Content-Type: text/html + +Content of a.html. + +--9051914041544843365972754266-- +"; + + private static MemoryStream MakeStream(string text) + { + return new MemoryStream(Encoding.UTF8.GetBytes(text)); + } + + [Fact] + public async Task MutipartReader_ReadSinglePartBody_Success() + { + var stream = MakeStream(OnePartBody); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.Equal(1, section.Headers.Count); + Assert.Equal("form-data; name=\"text\"", section.Headers["Content-Disposition"][0]); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.Equal("text default", Encoding.ASCII.GetString(buffer.ToArray())); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Fact] + public async Task MutipartReader_ReadSinglePartBodyWithTrailingWhitespace_Success() + { + var stream = MakeStream(OnePartBodyWithTrailingWhitespace); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.Equal(1, section.Headers.Count); + Assert.Equal("form-data; name=\"text\"", section.Headers["Content-Disposition"][0]); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.Equal("text default", Encoding.ASCII.GetString(buffer.ToArray())); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Fact] + public async Task MutipartReader_ReadSinglePartBodyWithoutLastCRLF_Success() + { + var stream = MakeStream(OnePartBodyWithoutFinalCRLF); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.Equal(1, section.Headers.Count); + Assert.Equal("form-data; name=\"text\"", section.Headers["Content-Disposition"][0]); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.Equal("text default", Encoding.ASCII.GetString(buffer.ToArray())); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Fact] + public async Task MutipartReader_ReadTwoPartBody_Success() + { + var stream = MakeStream(TwoPartBody); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.Equal(1, section.Headers.Count); + Assert.Equal("form-data; name=\"text\"", section.Headers["Content-Disposition"][0]); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.Equal("text default", Encoding.ASCII.GetString(buffer.ToArray())); + + section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.Equal(2, section.Headers.Count); + Assert.Equal("form-data; name=\"file1\"; filename=\"a.txt\"", section.Headers["Content-Disposition"][0]); + Assert.Equal("text/plain", section.Headers["Content-Type"][0]); + buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.Equal("Content of a.txt.\r\n", Encoding.ASCII.GetString(buffer.ToArray())); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + + [Fact] + public async Task MutipartReader_ThreePartBody_Success() + { + var stream = MakeStream(ThreePartBody); + var reader = new MultipartReader(Boundary, stream); + + var section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.Equal(1, section.Headers.Count); + Assert.Equal("form-data; name=\"text\"", section.Headers["Content-Disposition"][0]); + var buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.Equal("text default", Encoding.ASCII.GetString(buffer.ToArray())); + + section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.Equal(2, section.Headers.Count); + Assert.Equal("form-data; name=\"file1\"; filename=\"a.txt\"", section.Headers["Content-Disposition"][0]); + Assert.Equal("text/plain", section.Headers["Content-Type"][0]); + buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.Equal("Content of a.txt.\r\n", Encoding.ASCII.GetString(buffer.ToArray())); + + section = await reader.ReadNextSectionAsync(); + Assert.NotNull(section); + Assert.Equal(2, section.Headers.Count); + Assert.Equal("form-data; name=\"file2\"; filename=\"a.html\"", section.Headers["Content-Disposition"][0]); + Assert.Equal("text/html", section.Headers["Content-Type"][0]); + buffer = new MemoryStream(); + await section.Body.CopyToAsync(buffer); + Assert.Equal("Content of a.html.\r\n", Encoding.ASCII.GetString(buffer.ToArray())); + + Assert.Null(await reader.ReadNextSectionAsync()); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNet.WebUtilities.Tests/QueryHelpersTests.cs b/test/Microsoft.AspNet.WebUtilities.Tests/QueryHelpersTests.cs index fea33d497a..e90c78305a 100644 --- a/test/Microsoft.AspNet.WebUtilities.Tests/QueryHelpersTests.cs +++ b/test/Microsoft.AspNet.WebUtilities.Tests/QueryHelpersTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Linq; using Xunit; namespace Microsoft.AspNet.WebUtilities @@ -13,8 +14,8 @@ namespace Microsoft.AspNet.WebUtilities { var collection = QueryHelpers.ParseQuery("?key1=value1&key2=value2"); Assert.Equal(2, collection.Count); - Assert.Equal("value1", collection["key1"]); - Assert.Equal("value2", collection["key2"]); + Assert.Equal("value1", collection["key1"].FirstOrDefault()); + Assert.Equal("value2", collection["key2"].FirstOrDefault()); } [Fact] @@ -22,8 +23,8 @@ namespace Microsoft.AspNet.WebUtilities { var collection = QueryHelpers.ParseQuery("key1=value1&key2=value2"); Assert.Equal(2, collection.Count); - Assert.Equal("value1", collection["key1"]); - Assert.Equal("value2", collection["key2"]); + Assert.Equal("value1", collection["key1"].FirstOrDefault()); + Assert.Equal("value2", collection["key2"].FirstOrDefault()); } [Fact] @@ -31,8 +32,8 @@ namespace Microsoft.AspNet.WebUtilities { var collection = QueryHelpers.ParseQuery("?key1=valueA&key2=valueB&key1=valueC"); Assert.Equal(2, collection.Count); - Assert.Equal("valueA,valueC", collection["key1"]); - Assert.Equal("valueB", collection["key2"]); + Assert.Equal(new[] { "valueA", "valueC" }, collection["key1"]); + Assert.Equal("valueB", collection["key2"].FirstOrDefault()); } [Fact] @@ -40,8 +41,8 @@ namespace Microsoft.AspNet.WebUtilities { var collection = QueryHelpers.ParseQuery("?key1=&key2="); Assert.Equal(2, collection.Count); - Assert.Equal(string.Empty, collection["key1"]); - Assert.Equal(string.Empty, collection["key2"]); + Assert.Equal(string.Empty, collection["key1"].FirstOrDefault()); + Assert.Equal(string.Empty, collection["key2"].FirstOrDefault()); } [Fact] @@ -49,7 +50,7 @@ namespace Microsoft.AspNet.WebUtilities { var collection = QueryHelpers.ParseQuery("?=value1&="); Assert.Equal(1, collection.Count); - Assert.Equal("value1,", collection[""]); + Assert.Equal(new[] { "value1", "" }, collection[""]); } } } \ No newline at end of file