diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs index d7ac287ec6..c06bd4b236 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketOutput.cs @@ -169,7 +169,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _tasksPending.Enqueue(new WaitingTask() { CancellationToken = cancellationToken, - CancellationRegistration = cancellationToken.Register(_connectionCancellation, this), + CancellationRegistration = cancellationToken.SafeRegister(_connectionCancellation, this), BytesToWrite = buffer.Count, CompletionSource = tcs }); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/CancellationTokenExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/CancellationTokenExtensions.cs new file mode 100644 index 0000000000..1d4275a314 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/CancellationTokenExtensions.cs @@ -0,0 +1,76 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure +{ + internal static class CancellationTokenExtensions + { + public static IDisposable SafeRegister(this CancellationToken cancellationToken, Action callback, object state) + { + var callbackWrapper = new CancellationCallbackWrapper(callback, state); + var registration = cancellationToken.Register(s => InvokeCallback(s), callbackWrapper); + var disposeCancellationState = new DisposeCancellationState(callbackWrapper, registration); + + return new DisposableAction(s => Dispose(s), disposeCancellationState); + } + + private static void InvokeCallback(object state) + { + ((CancellationCallbackWrapper)state).TryInvoke(); + } + + private static void Dispose(object state) + { + ((DisposeCancellationState)state).TryDispose(); + } + + private class DisposeCancellationState + { + private readonly CancellationCallbackWrapper _callbackWrapper; + private readonly CancellationTokenRegistration _registration; + + public DisposeCancellationState(CancellationCallbackWrapper callbackWrapper, CancellationTokenRegistration registration) + { + _callbackWrapper = callbackWrapper; + _registration = registration; + } + + public void TryDispose() + { + if (_callbackWrapper.TrySetInvoked()) + { + _registration.Dispose(); + } + } + } + + private class CancellationCallbackWrapper + { + private readonly Action _callback; + private readonly object _state; + private int _callbackInvoked; + + public CancellationCallbackWrapper(Action callback, object state) + { + _callback = callback; + _state = state; + } + + public bool TrySetInvoked() + { + return Interlocked.Exchange(ref _callbackInvoked, 1) == 0; + } + + public void TryInvoke() + { + if (TrySetInvoked()) + { + _callback(_state); + } + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/DisposableAction.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/DisposableAction.cs new file mode 100644 index 0000000000..4336437ae9 --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/DisposableAction.cs @@ -0,0 +1,40 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure +{ + internal class DisposableAction : IDisposable + { + public static readonly DisposableAction Empty = new DisposableAction(() => { }); + + private Action _action; + private readonly object _state; + + public DisposableAction(Action action) + : this(state => ((Action)state).Invoke(), state: action) + { + } + + public DisposableAction(Action action, object state) + { + _action = action; + _state = state; + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + Interlocked.Exchange(ref _action, (state) => { }).Invoke(_state); + } + } + + public void Dispose() + { + Dispose(true); + } + } +} diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs index bcd15a643c..f28b1ba8c5 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs @@ -753,7 +753,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Action triggerNextCompleted; Assert.True(completeQueue.TryDequeue(out triggerNextCompleted)); triggerNextCompleted(0); - await mockLibuv.OnPostTask; + await mockLibuv.OnPostTask; Assert.True(writeCalled); // Cleanup @@ -862,4 +862,4 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } } -} +} \ No newline at end of file