[Blazor] Graceful disconnection

* Adds a new API endpoint to trigger graceful disconnection from blazor clients.
* Uses the sendBeacon API on the Blazor client to trigger graceful disconnections on the client when the document gets unloaded, which happens when closing the window, navigating away from the page or refreshing the page.
This commit is contained in:
Javier Calvarro Nelson 2019-08-07 17:40:02 +02:00 committed by GitHub
parent 92869c677f
commit 25c240bef5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 404 additions and 46 deletions

View File

@ -10,11 +10,13 @@ namespace Microsoft.AspNetCore.Builder
/// </summary>
public sealed class ComponentEndpointConventionBuilder : IHubEndpointConventionBuilder
{
private readonly IEndpointConventionBuilder _endpointConventionBuilder;
private readonly IEndpointConventionBuilder _hubEndpoint;
private readonly IEndpointConventionBuilder _disconnectEndpoint;
internal ComponentEndpointConventionBuilder(IEndpointConventionBuilder endpointConventionBuilder)
internal ComponentEndpointConventionBuilder(IEndpointConventionBuilder hubEndpoint, IEndpointConventionBuilder disconnectEndpoint)
{
_endpointConventionBuilder = endpointConventionBuilder;
_hubEndpoint = hubEndpoint;
_disconnectEndpoint = disconnectEndpoint;
}
/// <summary>
@ -23,7 +25,8 @@ namespace Microsoft.AspNetCore.Builder
/// <param name="convention">The convention to add to the builder.</param>
public void Add(Action<EndpointBuilder> convention)
{
_endpointConventionBuilder.Add(convention);
_hubEndpoint.Add(convention);
_disconnectEndpoint.Add(convention);
}
}
}

View File

@ -292,7 +292,17 @@ namespace Microsoft.AspNetCore.Builder
throw new ArgumentNullException(nameof(configureOptions));
}
return new ComponentEndpointConventionBuilder(endpoints.MapHub<ComponentHub>(path, configureOptions)).AddComponent(componentType, selector);
var hubEndpoint = endpoints.MapHub<ComponentHub>(path, configureOptions);
var disconnectEndpoint = endpoints.Map(
(path.EndsWith("/") ? path : path + "/") + "disconnect/",
endpoints.CreateApplicationBuilder().UseMiddleware<CircuitDisconnectMiddleware>().Build())
.WithDisplayName("Blazor disconnect");
return new ComponentEndpointConventionBuilder(
hubEndpoint,
disconnectEndpoint)
.AddComponent(componentType, selector);
}
}
}

View File

@ -0,0 +1,103 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Components.Server.Circuits;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Components.Server
{
// We use a middlware so that we can use DI.
internal class CircuitDisconnectMiddleware
{
private const string CircuitIdKey = "circuitId";
public CircuitDisconnectMiddleware(
ILogger<CircuitDisconnectMiddleware> logger,
CircuitRegistry registry,
CircuitIdFactory circuitIdFactory,
RequestDelegate next)
{
Logger = logger;
Registry = registry;
CircuitIdFactory = circuitIdFactory;
Next = next;
}
public ILogger<CircuitDisconnectMiddleware> Logger { get; }
public CircuitRegistry Registry { get; }
public CircuitIdFactory CircuitIdFactory { get; }
public RequestDelegate Next { get; }
public async Task Invoke(HttpContext context)
{
if (!HttpMethods.IsPost(context.Request.Method))
{
context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed;
return;
}
var (hasCircuitId, circuitId) = await TryGetCircuitIdAsync(context);
if (!hasCircuitId)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
return;
}
await TerminateCircuitGracefully(circuitId);
context.Response.StatusCode = StatusCodes.Status200OK;
}
private async Task<(bool, string)> TryGetCircuitIdAsync(HttpContext context)
{
try
{
if (!context.Request.HasFormContentType)
{
return (false, null);
}
var form = await context.Request.ReadFormAsync();
if (!form.TryGetValue(CircuitIdKey, out var circuitId) || !CircuitIdFactory.ValidateCircuitId(circuitId))
{
return (false, null);
}
return (true, circuitId);
}
catch
{
return (false, null);
}
}
private async Task TerminateCircuitGracefully(string circuitId)
{
try
{
await Registry.Terminate(circuitId);
Log.CircuitTerminatedGracefully(Logger, circuitId);
}
catch (Exception e)
{
Log.UnhandledExceptionInCircuit(Logger, circuitId, e);
}
}
private class Log
{
private static readonly Action<ILogger, string, Exception> _circuitTerminatedGracefully =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(1, "CircuitTerminatedGracefully"), "Circuit '{CircuitId}' terminated gracefully");
private static readonly Action<ILogger, string, Exception> _unhandledExceptionInCircuit =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(2, "UnhandledExceptionInCircuit"), "Unhandled exception in circuit {CircuitId} while terminating gracefully.");
public static void CircuitTerminatedGracefully(ILogger logger, string circuitId) => _circuitTerminatedGracefully(logger, circuitId, null);
public static void UnhandledExceptionInCircuit(ILogger logger, string circuitId, Exception exception) => _unhandledExceptionInCircuit(logger, circuitId, exception);
}
}
}

View File

@ -81,15 +81,6 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
}
}
public void PermanentDisconnect(CircuitHost circuitHost)
{
if (ConnectedCircuits.TryRemove(circuitHost.CircuitId, out _))
{
Log.CircuitDisconnectedPermanently(_logger, circuitHost.CircuitId);
circuitHost.Client.SetDisconnected();
}
}
public virtual Task DisconnectAsync(CircuitHost circuitHost, string connectionId)
{
Log.CircuitDisconnectStarted(_logger, circuitHost.CircuitId, connectionId);
@ -297,6 +288,29 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
}
}
public ValueTask Terminate(string circuitId)
{
CircuitHost circuitHost;
DisconnectedCircuitEntry entry = default;
lock (CircuitRegistryLock)
{
if (ConnectedCircuits.TryGetValue(circuitId, out circuitHost) || DisconnectedCircuits.TryGetValue(circuitId, out entry))
{
circuitHost ??= entry.CircuitHost;
DisconnectedCircuits.Remove(circuitHost.CircuitId);
ConnectedCircuits.TryRemove(circuitHost.CircuitId, out _);
Log.CircuitDisconnectedPermanently(_logger, circuitHost.CircuitId);
circuitHost.Client.SetDisconnected();
}
else
{
return default;
}
}
return circuitHost?.DisposeAsync() ?? default;
}
private readonly struct DisconnectedCircuitEntry
{
public DisconnectedCircuitEntry(CircuitHost circuitHost, CancellationTokenSource tokenSource)

View File

@ -68,34 +68,7 @@ namespace Microsoft.AspNetCore.Components.Server
return Task.CompletedTask;
}
if (exception != null)
{
return _circuitRegistry.DisconnectAsync(circuitHost, Context.ConnectionId);
}
else
{
// The client will gracefully disconnect when using websockets by correctly closing the TCP connection.
// This happens when the user closes a tab, navigates away from the page or reloads the page.
// In these situations we know the user is done with the circuit, so we can get rid of it at that point.
// This is important to be able to more efficiently manage resources, specially memory.
return TerminateCircuitGracefully(circuitHost);
}
}
private async Task TerminateCircuitGracefully(CircuitHost circuitHost)
{
try
{
Log.CircuitTerminatedGracefully(_logger, circuitHost.CircuitId);
_circuitRegistry.PermanentDisconnect(circuitHost);
await circuitHost.DisposeAsync();
}
catch (Exception e)
{
Log.UnhandledExceptionInCircuit(_logger, circuitHost.CircuitId, e);
}
await _circuitRegistry.DisconnectAsync(circuitHost, Context.ConnectionId);
return _circuitRegistry.DisconnectAsync(circuitHost, Context.ConnectionId);
}
/// <summary>

View File

@ -0,0 +1,244 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.IO;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Components.Server.Circuits;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Xunit;
namespace Microsoft.AspNetCore.Components.Server
{
public class CircuitDisconnectMiddlewareTest
{
[Theory]
[InlineData("GET")]
[InlineData("PUT")]
[InlineData("DELETE")]
[InlineData("HEAD")]
public async Task DisconnectMiddleware_OnlyAccepts_PostRequests(string httpMethod)
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance,
circuitIdFactory);
var middleware = new CircuitDisconnectMiddleware(
NullLogger<CircuitDisconnectMiddleware>.Instance,
registry,
circuitIdFactory,
(ctx) => Task.CompletedTask);
var context = new DefaultHttpContext();
context.Request.Method = httpMethod;
// Act
await middleware.Invoke(context);
// Assert
Assert.Equal(StatusCodes.Status405MethodNotAllowed, context.Response.StatusCode);
}
[Theory]
[InlineData(null)]
[InlineData("application/json")]
public async Task Returns400BadRequest_ForInvalidContentTypes(string contentType)
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance,
circuitIdFactory);
var middleware = new CircuitDisconnectMiddleware(
NullLogger<CircuitDisconnectMiddleware>.Instance,
registry,
circuitIdFactory,
(ctx) => Task.CompletedTask);
var context = new DefaultHttpContext();
context.Request.Method = HttpMethods.Post;
context.Request.ContentType = contentType;
// Act
await middleware.Invoke(context);
// Assert
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
}
[Fact]
public async Task Returns400BadRequest_IfNoCircuitIdOnForm()
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance,
circuitIdFactory);
var middleware = new CircuitDisconnectMiddleware(
NullLogger<CircuitDisconnectMiddleware>.Instance,
registry,
circuitIdFactory,
(ctx) => Task.CompletedTask);
var context = new DefaultHttpContext();
context.Request.Method = HttpMethods.Post;
context.Request.ContentType = "application/x-www-form-urlencoded";
// Act
await middleware.Invoke(context);
// Assert
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
}
[Fact]
public async Task Returns400BadRequest_InvalidCircuitId()
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance,
circuitIdFactory);
var middleware = new CircuitDisconnectMiddleware(
NullLogger<CircuitDisconnectMiddleware>.Instance,
registry,
circuitIdFactory,
(ctx) => Task.CompletedTask);
using var memory = new MemoryStream();
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = "1234" }).CopyToAsync(memory);
memory.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
context.Request.Method = HttpMethods.Post;
context.Request.ContentType = "application/x-www-form-urlencoded";
context.Request.Body = memory;
// Act
await middleware.Invoke(context);
// Assert
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
}
[Fact]
public async Task Returns200OK_NonExistingCircuit()
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var id = circuitIdFactory.CreateCircuitId();
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance,
circuitIdFactory);
var middleware = new CircuitDisconnectMiddleware(
NullLogger<CircuitDisconnectMiddleware>.Instance,
registry,
circuitIdFactory,
(ctx) => Task.CompletedTask);
using var memory = new MemoryStream();
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = id }).CopyToAsync(memory);
memory.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
context.Request.Method = HttpMethods.Post;
context.Request.ContentType = "application/x-www-form-urlencoded";
context.Request.Body = memory;
// Act
await middleware.Invoke(context);
// Assert
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
[Fact]
public async Task GracefullyTerminates_ConnectedCircuit()
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var id = circuitIdFactory.CreateCircuitId();
var testCircuitHost = TestCircuitHost.Create(id);
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance,
circuitIdFactory);
registry.Register(testCircuitHost);
var middleware = new CircuitDisconnectMiddleware(
NullLogger<CircuitDisconnectMiddleware>.Instance,
registry,
circuitIdFactory,
(ctx) => Task.CompletedTask);
using var memory = new MemoryStream();
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = id }).CopyToAsync(memory);
memory.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
context.Request.Method = HttpMethods.Post;
context.Request.ContentType = "application/x-www-form-urlencoded";
context.Request.Body = memory;
// Act
await middleware.Invoke(context);
// Assert
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
[Fact]
public async Task GracefullyTerminates_DisconnectedCircuit()
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var id = circuitIdFactory.CreateCircuitId();
var circuitHost = TestCircuitHost.Create(id);
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance,
circuitIdFactory);
registry.Register(circuitHost);
await registry.DisconnectAsync(circuitHost, "1234");
var middleware = new CircuitDisconnectMiddleware(
NullLogger<CircuitDisconnectMiddleware>.Instance,
registry,
circuitIdFactory,
(ctx) => Task.CompletedTask);
using var memory = new MemoryStream();
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = id }).CopyToAsync(memory);
memory.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
context.Request.Method = HttpMethods.Post;
context.Request.ContentType = "application/x-www-form-urlencoded";
context.Request.Body = memory;
// Act
await middleware.Invoke(context);
// Assert
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
}
}

View File

@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Moq;
@ -63,6 +64,7 @@ namespace Microsoft.AspNetCore.Components.Server.Tests
services.AddRouting();
services.AddSignalR();
services.AddServerSideBlazor();
services.AddSingleton<IConfiguration>(new ConfigurationBuilder().Build());
var serviceProvder = services.BuildServiceProvider();

File diff suppressed because one or more lines are too long

View File

@ -47,6 +47,16 @@ async function boot(userOptions?: Partial<BlazorOptions>): Promise<void> {
return true;
};
window.addEventListener(
'unload',
() => {
const data = new FormData();
data.set('circuitId', circuit.circuitId);
navigator.sendBeacon('_blazor/disconnect', data);
},
false
);
window['Blazor'].reconnect = reconnect;
logger.log(LogLevel.Information, 'Blazor server-side application started.');

View File

@ -55,7 +55,7 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests
public async Task ReloadingThePage_GracefullyDisconnects_TheCurrentCircuit()
{
// Arrange & Act
_ = ((IJavaScriptExecutor)Browser).ExecuteScript("location.reload()");
Browser.Navigate().Refresh();
await Task.WhenAny(Task.Delay(10000), GracefulDisconnectCompletionSource.Task);
// Assert
@ -70,7 +70,6 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests
Browser.Close();
// Set to null so that other tests in this class can create a new browser if necessary so
// that tests don't fail when running together.
Browser = null;
await Task.WhenAny(Task.Delay(10000), GracefulDisconnectCompletionSource.Task);