Use MatcherPolicy for Consumes

This commit is contained in:
Ryan Nowak 2018-07-31 20:58:26 -07:00 committed by Ryan Nowak
parent 44f5b54f5f
commit 2b289d2f2c
9 changed files with 532 additions and 1 deletions

View File

@ -175,6 +175,7 @@ namespace Microsoft.Extensions.DependencyInjection
services.TryAddEnumerable(ServiceDescriptor.Transient<IActionConstraintProvider, DefaultActionConstraintProvider>());
// Policies for Endpoints
services.TryAddEnumerable(ServiceDescriptor.Singleton<MatcherPolicy, ConsumesMatcherPolicy>());
services.TryAddEnumerable(ServiceDescriptor.Singleton<MatcherPolicy, ActionConstraintMatcherPolicy>());
//

View File

@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc.Abstractions;
using Microsoft.AspNetCore.Mvc.ActionConstraints;
using Microsoft.AspNetCore.Mvc.Infrastructure;
using Microsoft.AspNetCore.Mvc.Routing;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Routing.Matching;
using Microsoft.AspNetCore.Routing.Metadata;
@ -348,6 +349,11 @@ namespace Microsoft.AspNetCore.Mvc.Internal
{
metadata.Add(new HttpMethodMetadata(httpMethodActionConstraint.HttpMethods));
}
else if (actionConstraint is ConsumesAttribute consumesAttribute &&
!metadata.OfType<ConsumesMetadata>().Any())
{
metadata.Add(new ConsumesMetadata(consumesAttribute.ContentTypes.ToArray()));
}
else if (!metadata.Contains(actionConstraint))
{
// The constraint might have been added earlier, e.g. it is also a filter descriptor

View File

@ -0,0 +1,244 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc.Formatters;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Routing.Matching;
using Microsoft.AspNetCore.Routing.Patterns;
namespace Microsoft.AspNetCore.Mvc.Routing
{
internal class ConsumesMatcherPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy
{
internal const string Http415EndpointDisplayName = "415 HTTP Unsupported Media Type";
internal const string AnyContentType = "*/*";
// Run after HTTP methods, but before 'default'.
public override int Order { get; } = -100;
public IComparer<Endpoint> Comparer { get; } = new ConsumesMetadataEndpointComparer();
public bool AppliesToNode(IReadOnlyList<Endpoint> endpoints)
{
if (endpoints == null)
{
throw new ArgumentNullException(nameof(endpoints));
}
return endpoints.Any(e => e.Metadata.GetMetadata<IConsumesMetadata>()?.ContentTypes.Count > 0);
}
public IReadOnlyList<PolicyNodeEdge> GetEdges(IReadOnlyList<Endpoint> endpoints)
{
if (endpoints == null)
{
throw new ArgumentNullException(nameof(endpoints));
}
// The algorithm here is designed to be preserve the order of the endpoints
// while also being relatively simple. Preserving order is important.
// First, build a dictionary of all of the content-type patterns that are included
// at this node.
//
// For now we're just building up the set of keys. We don't add any endpoints
// to lists now because we don't want ordering problems.
var edges = new Dictionary<string, List<Endpoint>>(StringComparer.OrdinalIgnoreCase);
for (var i = 0; i < endpoints.Count; i++)
{
var endpoint = endpoints[i];
var contentTypes = endpoint.Metadata.GetMetadata<IConsumesMetadata>()?.ContentTypes;
if (contentTypes == null || contentTypes.Count == 0)
{
contentTypes = new string[] { AnyContentType, };
}
for (var j = 0; j < contentTypes.Count; j++)
{
var contentType = contentTypes[j];
if (!edges.ContainsKey(contentType))
{
edges.Add(contentType, new List<Endpoint>());
}
}
}
// Now in a second loop, add endpoints to these lists. We've enumerated all of
// the states, so we want to see which states this endpoint matches.
for (var i = 0; i < endpoints.Count; i++)
{
var endpoint = endpoints[i];
var contentTypes = endpoint.Metadata.GetMetadata<IConsumesMetadata>()?.ContentTypes ?? Array.Empty<string>();
if (contentTypes.Count == 0)
{
// OK this means that this endpoint matches *all* content methods.
// So, loop and add it to all states.
foreach (var kvp in edges)
{
kvp.Value.Add(endpoint);
}
}
else
{
// OK this endpoint matches specific content types -- we have to loop through edges here
// because content types could either be exact (like 'application/json') or they
// could have wildcards (like 'text/*'). We don't expect wildcards to be especially common
// with consumes, but we need to support it.
foreach (var kvp in edges)
{
// The edgeKey maps to a possible request header value
var edgeKey = new MediaType(kvp.Key);
for (var j = 0; j < contentTypes.Count; j++)
{
var contentType = contentTypes[j];
var mediaType = new MediaType(contentType);
// Example: 'application/json' is subset of 'application/*'
//
// This means that when the request has content-type 'application/json' an endpoint
// what consumes 'application/*' should match.
if (edgeKey.IsSubsetOf(mediaType))
{
kvp.Value.Add(endpoint);
// It's possible that a ConsumesMetadata defines overlapping wildcards. Don't add an endpoint
// to any edge twice
break;
}
}
}
}
}
// If after we're done there isn't any endpoint that accepts */*, then we'll synthesize an
// endpoint that always returns a 415.
if (!edges.ContainsKey(AnyContentType))
{
edges.Add(AnyContentType, new List<Endpoint>()
{
CreateRejectionEndpoint(),
});
}
return edges
.Select(kvp => new PolicyNodeEdge(kvp.Key, kvp.Value))
.ToArray();
}
private Endpoint CreateRejectionEndpoint()
{
return new MatcherEndpoint(
(next) => (context) =>
{
context.Response.StatusCode = StatusCodes.Status415UnsupportedMediaType;
return Task.CompletedTask;
},
RoutePatternFactory.Parse("/"),
new RouteValueDictionary(),
0,
EndpointMetadataCollection.Empty,
Http415EndpointDisplayName);
}
public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList<PolicyJumpTableEdge> edges)
{
if (edges == null)
{
throw new ArgumentNullException(nameof(edges));
}
// Since our 'edges' can have wildcards, we do a sort based on how wildcard-ey they
// are then then execute them in linear order.
var ordered = edges
.Select(e => (mediaType: new MediaType((string)e.State), destination: e.Destination))
.OrderBy(e => GetScore(e.mediaType))
.ToArray();
// If any edge matches all content types, then treat that as the 'exit'. This will
// always happen because we insert a 415 endpoint.
for (var i = 0; i < ordered.Length; i++)
{
if (ordered[i].mediaType.MatchesAllTypes)
{
exitDestination = ordered[i].destination;
break;
}
}
return new ConsumesPolicyJumpTable(exitDestination, ordered);
}
private int GetScore(MediaType mediaType)
{
// Higher score == lower priority - see comments on MediaType.
if (mediaType.MatchesAllTypes)
{
return 4;
}
else if (mediaType.MatchesAllSubTypes)
{
return 3;
}
else if (mediaType.MatchesAllSubTypesWithoutSuffix)
{
return 2;
}
else
{
return 1;
}
}
private class ConsumesMetadataEndpointComparer : EndpointMetadataComparer<IConsumesMetadata>
{
protected override int CompareMetadata(IConsumesMetadata x, IConsumesMetadata y)
{
// Ignore the metadata if it has an empty list of content types.
return base.CompareMetadata(
x?.ContentTypes.Count > 0 ? x : null,
y?.ContentTypes.Count > 0 ? y : null);
}
}
private class ConsumesPolicyJumpTable : PolicyJumpTable
{
private (MediaType mediaType, int destination)[] _destinations;
private int _exitDestination;
public ConsumesPolicyJumpTable(int exitDestination, (MediaType mediaType, int destination)[] destinations)
{
_exitDestination = exitDestination;
_destinations = destinations;
}
public override int GetDestination(HttpContext httpContext)
{
var contentType = httpContext.Request.ContentType;
if (string.IsNullOrEmpty(contentType))
{
return _exitDestination;
}
var requestMediaType = new MediaType(contentType);
var destinations = _destinations;
for (var i = 0; i < destinations.Length; i++)
{
if (requestMediaType.IsSubsetOf(destinations[i].mediaType))
{
return destinations[i].destination;
}
}
return _exitDestination;
}
}
}
}

View File

@ -0,0 +1,23 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Mvc.Routing
{
internal class ConsumesMetadata : IConsumesMetadata
{
public ConsumesMetadata(string[] contentTypes)
{
if (contentTypes == null)
{
throw new ArgumentNullException(nameof(contentTypes));
}
ContentTypes = contentTypes;
}
public IReadOnlyList<string> ContentTypes { get; }
}
}

View File

@ -325,6 +325,7 @@ namespace Microsoft.AspNetCore.Mvc
typeof(MatcherPolicy),
new Type[]
{
typeof(ConsumesMatcherPolicy),
typeof(ActionConstraintMatcherPolicy),
}
},

View File

@ -0,0 +1,237 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Routing.Matching;
using Microsoft.AspNetCore.Routing.Patterns;
using Xunit;
namespace Microsoft.AspNetCore.Mvc.Routing
{
public class ConsumesMatcherPolicyTest
{
[Fact]
public void AppliesToNode_EndpointWithoutMetadata_ReturnsFalse()
{
// Arrange
var endpoints = new[] { CreateEndpoint("/", null), };
var policy = CreatePolicy();
// Act
var result = policy.AppliesToNode(endpoints);
// Assert
Assert.False(result);
}
[Fact]
public void AppliesToNode_EndpointWithoutContentTypes_ReturnsFalse()
{
// Arrange
var endpoints = new[]
{
CreateEndpoint("/", new ConsumesMetadata(Array.Empty<string>())),
};
var policy = CreatePolicy();
// Act
var result = policy.AppliesToNode(endpoints);
// Assert
Assert.False(result);
}
[Fact]
public void AppliesToNode_EndpointHasContentTypes_ReturnsTrue()
{
// Arrange
var endpoints = new[]
{
CreateEndpoint("/", new ConsumesMetadata(Array.Empty<string>())),
CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", })),
};
var policy = CreatePolicy();
// Act
var result = policy.AppliesToNode(endpoints);
// Assert
Assert.True(result);
}
[Fact]
public void GetEdges_GroupsByContentType()
{
// Arrange
var endpoints = new[]
{
// These are arrange in an order that we won't actually see in a product scenario. It's done
// this way so we can verify that ordering is preserved by GetEdges.
CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", "application/*+json", })),
CreateEndpoint("/", new ConsumesMetadata(Array.Empty<string>())),
CreateEndpoint("/", new ConsumesMetadata(new[] { "application/xml", "application/*+xml", })),
CreateEndpoint("/", new ConsumesMetadata(new[] { "application/*", })),
CreateEndpoint("/", new ConsumesMetadata(new[]{ "*/*", })),
};
var policy = CreatePolicy();
// Act
var edges = policy.GetEdges(endpoints);
// Assert
Assert.Collection(
edges.OrderBy(e => e.State),
e =>
{
Assert.Equal("*/*", e.State);
Assert.Equal(new[] { endpoints[1], endpoints[4], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/*", e.State);
Assert.Equal(new[] { endpoints[1], endpoints[3], endpoints[4], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/*+json", e.State);
Assert.Equal(new[] { endpoints[0], endpoints[1], endpoints[3], endpoints[4], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/*+xml", e.State);
Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[3], endpoints[4], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/json", e.State);
Assert.Equal(new[] { endpoints[0], endpoints[1], endpoints[3], endpoints[4], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/xml", e.State);
Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[3], endpoints[4], }, e.Endpoints.ToArray());
});
}
[Fact] // See explanation in GetEdges for how this case is different
public void GetEdges_GroupsByContentType_CreatesHttp405Endpoint()
{
// Arrange
var endpoints = new[]
{
// These are arrange in an order that we won't actually see in a product scenario. It's done
// this way so we can verify that ordering is preserved by GetEdges.
CreateEndpoint("/", new ConsumesMetadata(new[] { "application/json", "application/*+json", })),
CreateEndpoint("/", new ConsumesMetadata(new[] { "application/xml", "application/*+xml", })),
CreateEndpoint("/", new ConsumesMetadata(new[] { "application/*", })),
};
var policy = CreatePolicy();
// Act
var edges = policy.GetEdges(endpoints);
// Assert
Assert.Collection(
edges.OrderBy(e => e.State),
e =>
{
Assert.Equal("*/*", e.State);
Assert.Equal(ConsumesMatcherPolicy.Http415EndpointDisplayName, Assert.Single(e.Endpoints).DisplayName);
},
e =>
{
Assert.Equal("application/*", e.State);
Assert.Equal(new[] { endpoints[2], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/*+json", e.State);
Assert.Equal(new[] { endpoints[0], endpoints[2], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/*+xml", e.State);
Assert.Equal(new[] { endpoints[1], endpoints[2], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/json", e.State);
Assert.Equal(new[] { endpoints[0], endpoints[2], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("application/xml", e.State);
Assert.Equal(new[] { endpoints[1], endpoints[2], }, e.Endpoints.ToArray());
});
}
[Theory]
[InlineData("image/png", 1)]
[InlineData("application/foo", 2)]
[InlineData("text/xml", 3)]
[InlineData("application/product+json", 6)] // application/json will match this
[InlineData("application/product+xml", 7)] // application/xml will match this
[InlineData("application/json", 6)]
[InlineData("application/xml", 7)]
public void BuildJumpTable_SortsEdgesByPriority(string contentType, int expected)
{
// Arrange
var edges = new PolicyJumpTableEdge[]
{
// In reverse order of how they should be processed
new PolicyJumpTableEdge("*/*", 1),
new PolicyJumpTableEdge("application/*", 2),
new PolicyJumpTableEdge("text/*", 3),
new PolicyJumpTableEdge("application/*+xml", 4),
new PolicyJumpTableEdge("application/*+json", 5),
new PolicyJumpTableEdge("application/json", 6),
new PolicyJumpTableEdge("application/xml", 7),
};
var policy = CreatePolicy();
var jumpTable = policy.BuildJumpTable(-1, edges);
var httpContext = new DefaultHttpContext();
httpContext.Request.ContentType = contentType;
// Act
var actual = jumpTable.GetDestination(httpContext);
// Assert
Assert.Equal(expected, actual);
}
private static MatcherEndpoint CreateEndpoint(string template, ConsumesMetadata consumesMetadata)
{
var metadata = new List<object>();
if (consumesMetadata != null)
{
metadata.Add(consumesMetadata);
}
return new MatcherEndpoint(
(next) => null,
RoutePatternFactory.Parse(template),
new RouteValueDictionary(),
0,
new EndpointMetadataCollection(metadata),
$"test: {template} - {string.Join(", ", consumesMetadata?.ContentTypes ?? Array.Empty<string>())}");
}
private static ConsumesMatcherPolicy CreatePolicy()
{
return new ConsumesMatcherPolicy();
}
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Xunit;
@ -29,5 +30,22 @@ namespace Microsoft.AspNetCore.Mvc.FunctionalTests
Assert.True(result);
}
// The endpoint routing version of this feature has fixed https://github.com/aspnet/Mvc/issues/8174
[Fact]
public override async Task NoRequestContentType_Selects_IfASingleActionWithConstraintIsPresent()
{
// Arrange
var request = new HttpRequestMessage(
HttpMethod.Post,
"http://localhost/ConsumesAttribute_PassThrough/CreateProduct");
// Act
var response = await Client.SendAsync(request);
var body = await response.Content.ReadAsStringAsync();
// Assert
Assert.Equal(HttpStatusCode.UnsupportedMediaType, response.StatusCode);
}
}
}

View File

@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.Mvc.FunctionalTests
}
[Fact]
public async Task NoRequestContentType_Selects_IfASingleActionWithConstraintIsPresent()
public virtual async Task NoRequestContentType_Selects_IfASingleActionWithConstraintIsPresent()
{
// Arrange
var request = new HttpRequestMessage(

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Xunit;