Support Authorize attribute on EndPoints (#606)

This commit is contained in:
BrennanConroy 2017-06-28 14:52:52 -07:00 committed by GitHub
parent ccf6cd415e
commit dc29e98032
5 changed files with 139 additions and 17 deletions

View File

@ -10,7 +10,7 @@
<PackageTags>aspnetcore;signalr</PackageTags>
<EnableApiCheck>false</EnableApiCheck>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />

View File

@ -1,6 +1,6 @@
using System;
using System.Collections.Generic;
using System.Text;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
@ -11,10 +11,7 @@ namespace Microsoft.AspNetCore.SignalR
public static ISocketBuilder UseHub<THub>(this ISocketBuilder socketBuilder) where THub : Hub<IClientProxy>
{
var endpoint = socketBuilder.ApplicationServices.GetRequiredService<HubEndPoint<THub>>();
return socketBuilder.Run(connection =>
{
return endpoint.OnConnectedAsync(connection);
});
return socketBuilder.Run(connection => endpoint.OnConnectedAsync(connection));
}
}
}

View File

@ -2,6 +2,8 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Reflection;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.Sockets
@ -30,7 +32,20 @@ namespace Microsoft.AspNetCore.Sockets
public void MapEndPoint<TEndPoint>(string path) where TEndPoint : EndPoint
{
MapSocket(path, builder =>
MapEndPoint<TEndPoint>(path, socketOptions: null);
}
public void MapEndPoint<TEndPoint>(string path, Action<HttpSocketOptions> socketOptions) where TEndPoint : EndPoint
{
var authorizeAttributes = typeof(TEndPoint).GetCustomAttributes<AuthorizeAttribute>(inherit: true);
var options = new HttpSocketOptions();
foreach (var attribute in authorizeAttributes)
{
options.AuthorizationData.Add(attribute);
}
socketOptions?.Invoke(options);
MapSocket(path, options, builder =>
{
builder.UseEndPoint<TEndPoint>();
});

View File

@ -1,9 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Sockets
@ -12,12 +9,9 @@ namespace Microsoft.AspNetCore.Sockets
{
public static ISocketBuilder UseEndPoint<TEndPoint>(this ISocketBuilder socketBuilder) where TEndPoint : EndPoint
{
var endpoint = socketBuilder.ApplicationServices.GetRequiredService<TEndPoint>();
// This is a terminal middleware, so there's no need to use the 'next' parameter
return socketBuilder.Run(connection =>
{
var endpoint = socketBuilder.ApplicationServices.GetRequiredService<TEndPoint>();
return endpoint.OnConnectedAsync(connection);
});
return socketBuilder.Run(connection => endpoint.OnConnectedAsync(connection));
}
}
}

View File

@ -0,0 +1,116 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
{
public class MapEndPointTests
{
[Fact]
public void MapEndPointFindsAuthAttributeOnEndPoint()
{
var authCount = 0;
var builder = new WebHostBuilder()
.UseKestrel()
.ConfigureServices(services =>
{
services.AddSockets();
services.AddEndPoint<AuthEndPoint>();
})
.Configure(app =>
{
app.UseSockets(routes =>
{
routes.MapEndPoint<AuthEndPoint>("auth", httpSocketOptions =>
{
authCount += httpSocketOptions.AuthorizationData.Count;
});
});
})
.Build();
Assert.Equal(1, authCount);
}
[Fact]
public void MapEndPointFindsAuthAttributeOnInheritedEndPoint()
{
var authCount = 0;
var builder = new WebHostBuilder()
.UseKestrel()
.ConfigureServices(services =>
{
services.AddSockets();
services.AddEndPoint<InheritedAuthEndPoint>();
})
.Configure(app =>
{
app.UseSockets(routes =>
{
routes.MapEndPoint<InheritedAuthEndPoint>("auth", httpSocketOptions =>
{
authCount += httpSocketOptions.AuthorizationData.Count;
});
});
})
.Build();
Assert.Equal(1, authCount);
}
[Fact]
public void MapEndPointFindsAuthAttributesOnDoubleAuthEndPoint()
{
var authCount = 0;
var builder = new WebHostBuilder()
.UseKestrel()
.ConfigureServices(services =>
{
services.AddSockets();
services.AddEndPoint<DoubleAuthEndPoint>();
})
.Configure(app =>
{
app.UseSockets(routes =>
{
routes.MapEndPoint<DoubleAuthEndPoint>("auth", httpSocketOptions =>
{
authCount += httpSocketOptions.AuthorizationData.Count;
});
});
})
.Build();
Assert.Equal(2, authCount);
}
private class InheritedAuthEndPoint : AuthEndPoint
{
public override Task OnConnectedAsync(ConnectionContext connection)
{
throw new NotImplementedException();
}
}
[Authorize]
private class DoubleAuthEndPoint : AuthEndPoint
{
}
[Authorize]
private class AuthEndPoint : EndPoint
{
public override Task OnConnectedAsync(ConnectionContext connection)
{
throw new NotImplementedException();
}
}
}
}