Remove custom routing and path matching from HttpConnectionDispatcher (#471)

- We're now using the routing system in a very vanilla way now that
we're not using the URL space as part of the protocol.
- Removed the path argument from the HttpConnectionDispatcher (simplifies code and removes duplication from tests)
This commit is contained in:
David Fowler 2017-05-20 00:34:29 -07:00 committed by GitHub
parent 240a88f7af
commit 2aabce48b4
11 changed files with 57 additions and 151 deletions

View File

@ -25,10 +25,10 @@ namespace Microsoft.AspNetCore.SignalR.Test.Server
}
app.UseFileServer();
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("/echo"));
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("echo"));
app.UseSignalR(routes =>
{
routes.MapHub<TestHub>("/testhub");
routes.MapHub<TestHub>("testhub");
});
}
}

View File

@ -92,7 +92,7 @@ namespace ChatSample
app.UseSignalR(routes =>
{
routes.MapHub<Chat>("/chat");
routes.MapHub<Chat>("chat");
});
app.UseMvc(routes =>

View File

@ -32,7 +32,7 @@ namespace SocialWeather
app.UseDeveloperExceptionPage();
}
app.UseSockets(o => { o.MapEndpoint<SocialWeatherEndPoint>("/weather"); });
app.UseSockets(o => { o.MapEndpoint<SocialWeatherEndPoint>("weather"); });
app.UseFileServer();
var formatterResolver = app.ApplicationServices.GetRequiredService<FormatterResolver>();

View File

@ -48,12 +48,12 @@ namespace SocketsSample
app.UseSignalR(routes =>
{
routes.MapHub<Chat>("/hubs");
routes.MapHub<Chat>("hubs");
});
app.UseSockets(routes =>
{
routes.MapEndpoint<MessagesEndPoint>("/chat");
routes.MapEndpoint<MessagesEndPoint>("chat");
});
}
}

View File

@ -34,7 +34,7 @@ namespace Microsoft.AspNetCore.Sockets
_logger = _loggerFactory.CreateLogger<HttpConnectionDispatcher>();
}
public async Task ExecuteAsync<TEndPoint>(string path, HttpContext context) where TEndPoint : EndPoint
public async Task ExecuteAsync<TEndPoint>(HttpContext context) where TEndPoint : EndPoint
{
var options = context.RequestServices.GetRequiredService<IOptions<EndPointOptions<TEndPoint>>>().Value;
// TODO: Authorize attribute on EndPoint
@ -43,38 +43,31 @@ namespace Microsoft.AspNetCore.Sockets
return;
}
if (context.Request.Path.Equals(path, StringComparison.OrdinalIgnoreCase))
if (HttpMethods.IsOptions(context.Request.Method))
{
if (HttpMethods.IsOptions(context.Request.Method))
{
// OPTIONS /{path}
await ProcessNegotiate(context, options);
}
else if (HttpMethods.IsPost(context.Request.Method))
{
// POST /{path}
await ProcessSend(context);
}
else if (HttpMethods.IsGet(context.Request.Method))
{
// GET /{path}
// OPTIONS /{path}
await ProcessNegotiate(context, options);
}
else if (HttpMethods.IsPost(context.Request.Method))
{
// POST /{path}
await ProcessSend(context);
}
else if (HttpMethods.IsGet(context.Request.Method))
{
// GET /{path}
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
await ExecuteEndpointAsync(path, context, endpoint, options);
}
else
{
context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
}
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
await ExecuteEndpointAsync(context, endpoint, options);
}
else
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
}
}
private async Task ExecuteEndpointAsync<TEndPoint>(string path, HttpContext context, EndPoint endpoint, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
private async Task ExecuteEndpointAsync<TEndPoint>(HttpContext context, EndPoint endpoint, EndPointOptions<TEndPoint> options) where TEndPoint : EndPoint
{
var supportedTransports = options.Transports;

View File

@ -2,9 +2,9 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Routing;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Builder
@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Builder
public void MapEndpoint<TEndPoint>(string path) where TEndPoint : EndPoint
{
_routes.AddPrefixRoute(path, new RouteHandler(c => _dispatcher.ExecuteAsync<TEndPoint>(path, c)));
_routes.MapRoute(path, _dispatcher.ExecuteAsync<TEndPoint>);
}
}
}

View File

@ -1,64 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.Sockets.Routing
{
internal class PrefixRoute : IRouter
{
private readonly IRouteHandler _target;
private readonly string _prefix;
public PrefixRoute(IRouteHandler target, string prefix)
{
_target = target;
if (prefix == null)
{
prefix = "/";
}
else if (prefix.Length > 0 && prefix[0] != '/')
{
// owin.RequestPath starts with a /
prefix = "/" + prefix;
}
if (prefix.Length > 1 && prefix[prefix.Length - 1] == '/')
{
prefix = prefix.Substring(0, prefix.Length - 1);
}
_prefix = prefix;
}
public Task RouteAsync(RouteContext context)
{
var requestPath = context.HttpContext.Request.Path.Value ?? string.Empty;
if (requestPath.StartsWith(_prefix, StringComparison.OrdinalIgnoreCase))
{
if (requestPath.Length > _prefix.Length)
{
var lastCharacter = requestPath[_prefix.Length];
if (lastCharacter != '/' && lastCharacter != '#' && lastCharacter != '?')
{
return Task.FromResult(0);
}
}
context.Handler = _target.GetRequestHandler(context.HttpContext, context.RouteData);
}
return Task.FromResult(0);
}
public VirtualPathData GetVirtualPath(VirtualPathContext context)
{
return null;
}
}
}

View File

@ -1,23 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.Sockets.Routing
{
internal static class RouteBuilderExtensions
{
public static IRouteBuilder AddPrefixRoute(
this IRouteBuilder routeBuilder,
string prefix,
IRouteHandler handler)
{
routeBuilder.Routes.Add(new PrefixRoute(handler, prefix));
return routeBuilder;
}
}
}

View File

@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
app.UseSignalR(routes =>
{
routes.MapHub<TestHub>("/hubs");
routes.MapHub<TestHub>("hubs");
});
});
_testServer = new TestServer(webHostBuilder);

View File

@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public void Configure(IApplicationBuilder app, IHostingEnvironment env)
{
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("/echo"));
app.UseSockets(options => options.MapEndpoint<EchoEndPoint>("echo"));
}
}

View File

@ -42,7 +42,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Path = "/foo";
context.Request.Method = "OPTIONS";
context.Response.Body = ms;
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
var id = Encoding.UTF8.GetString(ms.ToArray());
@ -77,7 +77,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Query = qs;
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
await strm.FlushAsync();
@ -108,7 +108,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var qs = new QueryCollection(values);
context.Request.Query = qs;
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
await strm.FlushAsync();
@ -136,7 +136,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
@ -160,7 +160,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.Path = "/foo";
context.Request.Method = "POST";
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
@ -187,7 +187,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Request.ContentType = "text/plain";
context.Response.Body = strm;
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
await strm.FlushAsync();
@ -240,7 +240,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
@ -260,7 +260,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<SynchronusExceptionEndPoint>("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>(context);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
@ -279,7 +279,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<SynchronusExceptionEndPoint>("/foo", state);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<SynchronusExceptionEndPoint>(context);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
@ -298,7 +298,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
@ -318,7 +318,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<ImmediatelyCompleteEndPoint>("/foo", state);
SetTransport(context, TransportType.WebSockets);
var task = dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("/foo", context);
var task = dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
await task.OrTimeout();
}
@ -339,9 +339,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
SetTransport(context1, transportType);
SetTransport(context2, transportType);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context1);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context2);
await dispatcher.ExecuteAsync<TestEndPoint>(context2);
Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode);
@ -372,8 +372,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context1 = MakeRequest<TestEndPoint>("/foo", state);
var context2 = MakeRequest<TestEndPoint>("/foo", state);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context1);
var request2 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context2);
var request1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
var request2 = dispatcher.ExecuteAsync<TestEndPoint>(context2);
await request1;
@ -401,7 +401,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<TestEndPoint>("/foo", state);
SetTransport(context, transportType);
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<TestEndPoint>(context);
Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
}
@ -416,7 +416,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<TestEndPoint>("/foo", state);
var task = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
var task = dispatcher.ExecuteAsync<TestEndPoint>(context);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -442,7 +442,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<BlockingEndPoint>("/foo", state);
SetTransport(context, TransportType.ServerSentEvents);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("/foo", context);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>(context);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -467,7 +467,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var context = MakeRequest<BlockingEndPoint>("/foo", state);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>("/foo", context);
var task = dispatcher.ExecuteAsync<BlockingEndPoint>(context);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -491,9 +491,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context1 = MakeRequest<TestEndPoint>("/foo", state);
var task1 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context1);
var task1 = dispatcher.ExecuteAsync<TestEndPoint>(context1);
var context2 = MakeRequest<TestEndPoint>("/foo", state);
var task2 = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context2);
var task2 = dispatcher.ExecuteAsync<TestEndPoint>(context2);
// Task 1 should finish when request 2 arrives
await task1.OrTimeout();
@ -577,7 +577,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
// would hang if EndPoint was running
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context).OrTimeout();
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
@ -619,7 +619,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -665,14 +665,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would hang if EndPoint was running
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context).OrTimeout();
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
// fully "authorize" user
context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") }));
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -718,7 +718,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>("/foo", context);
var endPointTask = dispatcher.ExecuteAsync<TestEndPoint>(context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
@ -766,7 +766,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would block if EndPoint was executed
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context).OrTimeout();
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
@ -842,7 +842,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>("/foo", context);
await dispatcher.ExecuteAsync<ImmediatelyCompleteEndPoint>(context);
Assert.Equal(status, context.Response.StatusCode);
await strm.FlushAsync();
@ -872,7 +872,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var messages = new List<Message>();
using (context.Request.Body = new MemoryStream(buffer, writable: false))
{
await dispatcher.ExecuteAsync<TestEndPoint>("/foo", context).OrTimeout();
await dispatcher.ExecuteAsync<TestEndPoint>(context).OrTimeout();
}
while (state.Connection.Transport.Input.TryRead(out var message))