Use ReferenceEquals in HttpMethodMatcherPolicy (#21277)

This commit is contained in:
Kahbazi 2020-05-17 20:58:36 +04:30 committed by GitHub
parent 21e1020b88
commit 5cfebf260f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 5529 additions and 5429 deletions

View File

@ -248,6 +248,8 @@ namespace Microsoft.AspNetCore.Http
public static readonly string Post; public static readonly string Post;
public static readonly string Put; public static readonly string Put;
public static readonly string Trace; public static readonly string Trace;
public static bool Equals(string methodA, string methodB) { throw null; }
public static string GetCanonicalizedValue(string method) { throw null; }
public static bool IsConnect(string method) { throw null; } public static bool IsConnect(string method) { throw null; }
public static bool IsDelete(string method) { throw null; } public static bool IsDelete(string method) { throw null; }
public static bool IsGet(string method) { throw null; } public static bool IsGet(string method) { throw null; }

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Http
// Using .'static readonly' means that all consumers get these exact same // Using .'static readonly' means that all consumers get these exact same
// 'string' instance, which means the 'ReferenceEquals' checks below work // 'string' instance, which means the 'ReferenceEquals' checks below work
// and allow us to optimize comparisons when these constants are used. // and allow us to optimize comparisons when these constants are used.
// Please do NOT change these to 'const' // Please do NOT change these to 'const'
public static readonly string Connect = "CONNECT"; public static readonly string Connect = "CONNECT";
public static readonly string Delete = "DELETE"; public static readonly string Delete = "DELETE";
@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.Http
/// </returns> /// </returns>
public static bool IsConnect(string method) public static bool IsConnect(string method)
{ {
return object.ReferenceEquals(Connect, method) || StringComparer.OrdinalIgnoreCase.Equals(Connect, method); return Equals(Connect, method);
} }
/// <summary> /// <summary>
@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.Http
/// </returns> /// </returns>
public static bool IsDelete(string method) public static bool IsDelete(string method)
{ {
return object.ReferenceEquals(Delete, method) || StringComparer.OrdinalIgnoreCase.Equals(Delete, method); return Equals(Delete, method);
} }
/// <summary> /// <summary>
@ -57,11 +57,11 @@ namespace Microsoft.AspNetCore.Http
/// </summary> /// </summary>
/// <param name="method">The HTTP request method.</param> /// <param name="method">The HTTP request method.</param>
/// <returns> /// <returns>
/// <see langword="true" /> if the method is GET; otherwise, <see langword="false" />. /// <see langword="true" /> if the method is GET; otherwise, <see langword="false" />.
/// </returns> /// </returns>
public static bool IsGet(string method) public static bool IsGet(string method)
{ {
return object.ReferenceEquals(Get, method) || StringComparer.OrdinalIgnoreCase.Equals(Get, method); return Equals(Get, method);
} }
/// <summary> /// <summary>
@ -69,11 +69,11 @@ namespace Microsoft.AspNetCore.Http
/// </summary> /// </summary>
/// <param name="method">The HTTP request method.</param> /// <param name="method">The HTTP request method.</param>
/// <returns> /// <returns>
/// <see langword="true" /> if the method is HEAD; otherwise, <see langword="false" />. /// <see langword="true" /> if the method is HEAD; otherwise, <see langword="false" />.
/// </returns> /// </returns>
public static bool IsHead(string method) public static bool IsHead(string method)
{ {
return object.ReferenceEquals(Head, method) || StringComparer.OrdinalIgnoreCase.Equals(Head, method); return Equals(Head, method);
} }
/// <summary> /// <summary>
@ -81,11 +81,11 @@ namespace Microsoft.AspNetCore.Http
/// </summary> /// </summary>
/// <param name="method">The HTTP request method.</param> /// <param name="method">The HTTP request method.</param>
/// <returns> /// <returns>
/// <see langword="true" /> if the method is OPTIONS; otherwise, <see langword="false" />. /// <see langword="true" /> if the method is OPTIONS; otherwise, <see langword="false" />.
/// </returns> /// </returns>
public static bool IsOptions(string method) public static bool IsOptions(string method)
{ {
return object.ReferenceEquals(Options, method) || StringComparer.OrdinalIgnoreCase.Equals(Options, method); return Equals(Options, method);
} }
/// <summary> /// <summary>
@ -97,7 +97,7 @@ namespace Microsoft.AspNetCore.Http
/// </returns> /// </returns>
public static bool IsPatch(string method) public static bool IsPatch(string method)
{ {
return object.ReferenceEquals(Patch, method) || StringComparer.OrdinalIgnoreCase.Equals(Patch, method); return Equals(Patch, method);
} }
/// <summary> /// <summary>
@ -109,7 +109,7 @@ namespace Microsoft.AspNetCore.Http
/// </returns> /// </returns>
public static bool IsPost(string method) public static bool IsPost(string method)
{ {
return object.ReferenceEquals(Post, method) || StringComparer.OrdinalIgnoreCase.Equals(Post, method); return Equals(Post, method);
} }
/// <summary> /// <summary>
@ -121,7 +121,7 @@ namespace Microsoft.AspNetCore.Http
/// </returns> /// </returns>
public static bool IsPut(string method) public static bool IsPut(string method)
{ {
return object.ReferenceEquals(Put, method) || StringComparer.OrdinalIgnoreCase.Equals(Put, method); return Equals(Put, method);
} }
/// <summary> /// <summary>
@ -133,7 +133,39 @@ namespace Microsoft.AspNetCore.Http
/// </returns> /// </returns>
public static bool IsTrace(string method) public static bool IsTrace(string method)
{ {
return object.ReferenceEquals(Trace, method) || StringComparer.OrdinalIgnoreCase.Equals(Trace, method); return Equals(Trace, method);
}
/// <summary>
/// Returns the equivalent static instance, or the original instance if none match. This conversion is optional but allows for performance optimizations when comparing method values elsewhere.
/// </summary>
/// <param name="method"></param>
/// <returns></returns>
public static string GetCanonicalizedValue(string method) => method switch
{
string _ when IsGet(method) => Get,
string _ when IsPost(method) => Post,
string _ when IsPut(method) => Put,
string _ when IsDelete(method) => Delete,
string _ when IsOptions(method) => Options,
string _ when IsHead(method) => Head,
string _ when IsPatch(method) => Patch,
string _ when IsTrace(method) => Trace,
string _ when IsConnect(method) => Connect,
string _ => method
};
/// <summary>
/// Returns a value that indicates if the HTTP methods are the same.
/// </summary>
/// <param name="methodA">The first HTTP request method to compare.</param>
/// <param name="methodB">The second HTTP request method to compare.</param>
/// <returns>
/// <see langword="true" /> if the methods are the same; otherwise, <see langword="false" />.
/// </returns>
public static bool Equals(string methodA, string methodB)
{
return object.ReferenceEquals(methodA, methodB) || StringComparer.OrdinalIgnoreCase.Equals(methodA, methodB);
} }
} }
} }

View File

@ -0,0 +1,53 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using Xunit;
namespace Microsoft.AspNetCore.Http.Abstractions
{
public class HttpMethodslTests
{
[Fact]
public void CanonicalizedValue_Success()
{
var testCases = new List<(string[] methods, string expectedMethod)>
{
(new string[] { "GET", "Get", "get" }, HttpMethods.Get),
(new string[] { "POST", "Post", "post" }, HttpMethods.Post),
(new string[] { "PUT", "Put", "put" }, HttpMethods.Put),
(new string[] { "DELETE", "Delete", "delete" }, HttpMethods.Delete),
(new string[] { "HEAD", "Head", "head" }, HttpMethods.Head),
(new string[] { "CONNECT", "Connect", "connect" }, HttpMethods.Connect),
(new string[] { "OPTIONS", "Options", "options" }, HttpMethods.Options),
(new string[] { "PATCH", "Patch", "patch" }, HttpMethods.Patch),
(new string[] { "TRACE", "Trace", "trace" }, HttpMethods.Trace)
};
for (int i = 0; i < testCases.Count; i++)
{
var testCase = testCases[i];
for (int j = 0; j < testCase.methods.Length; j++)
{
CanonicalizedValueTest(testCase.methods[j], testCase.expectedMethod);
}
}
}
private void CanonicalizedValueTest(string method, string expectedMethod)
{
string inputMethod = CreateStringAtRuntime(method);
var canonicalizedValue = HttpMethods.GetCanonicalizedValue(inputMethod);
Assert.Same(expectedMethod, canonicalizedValue);
}
private string CreateStringAtRuntime(string input)
{
return new StringBuilder(input).ToString();
}
}
}

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using static Microsoft.AspNetCore.Http.HttpMethods;
namespace Microsoft.AspNetCore.Routing namespace Microsoft.AspNetCore.Routing
{ {
@ -41,7 +42,7 @@ namespace Microsoft.AspNetCore.Routing
throw new ArgumentNullException(nameof(httpMethods)); throw new ArgumentNullException(nameof(httpMethods));
} }
HttpMethods = httpMethods.ToArray(); HttpMethods = httpMethods.Select(GetCanonicalizedValue).ToArray();
AcceptCorsPreflight = acceptCorsPreflight; AcceptCorsPreflight = acceptCorsPreflight;
} }

View File

@ -4,6 +4,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Internal; using Microsoft.Extensions.Internal;
@ -21,7 +22,7 @@ namespace Microsoft.AspNetCore.Routing.Matching
// Used in tests // Used in tests
internal static readonly string OriginHeader = "Origin"; internal static readonly string OriginHeader = "Origin";
internal static readonly string AccessControlRequestMethod = "Access-Control-Request-Method"; internal static readonly string AccessControlRequestMethod = "Access-Control-Request-Method";
internal static readonly string PreflightHttpMethod = "OPTIONS"; internal static readonly string PreflightHttpMethod = HttpMethods.Options;
// Used in tests // Used in tests
internal const string Http405EndpointDisplayName = "405 HTTP Method Not Supported"; internal const string Http405EndpointDisplayName = "405 HTTP Method Not Supported";
@ -133,7 +134,7 @@ namespace Microsoft.AspNetCore.Routing.Matching
var httpMethod = httpContext.Request.Method; var httpMethod = httpContext.Request.Method;
var headers = httpContext.Request.Headers; var headers = httpContext.Request.Headers;
if (metadata.AcceptCorsPreflight && if (metadata.AcceptCorsPreflight &&
string.Equals(httpMethod, PreflightHttpMethod, StringComparison.OrdinalIgnoreCase) && HttpMethods.Equals(httpMethod, PreflightHttpMethod) &&
headers.ContainsKey(HeaderNames.Origin) && headers.ContainsKey(HeaderNames.Origin) &&
headers.TryGetValue(HeaderNames.AccessControlRequestMethod, out var accessControlRequestMethod) && headers.TryGetValue(HeaderNames.AccessControlRequestMethod, out var accessControlRequestMethod) &&
!StringValues.IsNullOrEmpty(accessControlRequestMethod)) !StringValues.IsNullOrEmpty(accessControlRequestMethod))
@ -146,7 +147,7 @@ namespace Microsoft.AspNetCore.Routing.Matching
for (var j = 0; j < metadata.HttpMethods.Count; j++) for (var j = 0; j < metadata.HttpMethods.Count; j++)
{ {
var candidateMethod = metadata.HttpMethods[j]; var candidateMethod = metadata.HttpMethods[j];
if (!string.Equals(httpMethod, candidateMethod, StringComparison.OrdinalIgnoreCase)) if (!HttpMethods.Equals(httpMethod, candidateMethod))
{ {
methods = methods ?? new HashSet<string>(StringComparer.OrdinalIgnoreCase); methods = methods ?? new HashSet<string>(StringComparer.OrdinalIgnoreCase);
methods.Add(candidateMethod); methods.Add(candidateMethod);
@ -396,9 +397,19 @@ namespace Microsoft.AspNetCore.Routing.Matching
private static bool ContainsHttpMethod(List<string> httpMethods, string httpMethod) private static bool ContainsHttpMethod(List<string> httpMethods, string httpMethod)
{ {
for (var i = 0; i < httpMethods.Count; i++) var methods = CollectionsMarshal.AsSpan(httpMethods);
for (var i = 0; i < methods.Length; i++)
{ {
if (string.Equals(httpMethods[i], httpMethod, StringComparison.OrdinalIgnoreCase)) // This is a fast path for when everything is using static HttpMethods instances.
if (object.ReferenceEquals(methods[i], httpMethod))
{
return true;
}
}
for (var i = 0; i < methods.Length; i++)
{
if (HttpMethods.Equals(methods[i], httpMethod))
{ {
return true; return true;
} }
@ -437,7 +448,7 @@ namespace Microsoft.AspNetCore.Routing.Matching
var httpMethod = httpContext.Request.Method; var httpMethod = httpContext.Request.Method;
var headers = httpContext.Request.Headers; var headers = httpContext.Request.Headers;
if (_supportsCorsPreflight && if (_supportsCorsPreflight &&
string.Equals(httpMethod, PreflightHttpMethod, StringComparison.OrdinalIgnoreCase) && HttpMethods.Equals(httpMethod, PreflightHttpMethod) &&
headers.ContainsKey(HeaderNames.Origin) && headers.ContainsKey(HeaderNames.Origin) &&
headers.TryGetValue(HeaderNames.AccessControlRequestMethod, out var accessControlRequestMethod) && headers.TryGetValue(HeaderNames.AccessControlRequestMethod, out var accessControlRequestMethod) &&
!StringValues.IsNullOrEmpty(accessControlRequestMethod)) !StringValues.IsNullOrEmpty(accessControlRequestMethod))
@ -499,7 +510,7 @@ namespace Microsoft.AspNetCore.Routing.Matching
{ {
return return
IsCorsPreflightRequest == other.IsCorsPreflightRequest && IsCorsPreflightRequest == other.IsCorsPreflightRequest &&
string.Equals(HttpMethod, other.HttpMethod, StringComparison.OrdinalIgnoreCase); HttpMethods.Equals(HttpMethod, other.HttpMethod);
} }
public override bool Equals(object obj) public override bool Equals(object obj)

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
@ -49,7 +49,7 @@ namespace Swaggatherer
if (entry.Method != null) if (entry.Method != null)
{ {
setupRequestsLines.Add($" Requests[{i}].Request.Method = \"{entries[i].Method.ToUpperInvariant()}\";"); setupRequestsLines.Add($" Requests[{i}].Request.Method = HttpMethods.GetCanonicalizedValue({entries[i].Method});");
} }
setupRequestsLines.Add($" Requests[{i}].Request.Path = \"{entries[i].RequestUrl}\";"); setupRequestsLines.Add($" Requests[{i}].Request.Path = \"{entries[i].RequestUrl}\";");