SocketOutput Cancellation handling

This commit is contained in:
Ben Adams 2016-01-15 13:08:40 +00:00
parent b062f851dc
commit 73bb0ab5b8
16 changed files with 421 additions and 108 deletions

View File

@ -56,7 +56,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
private bool _requestProcessingStarted;
private Task _requestProcessingTask;
protected volatile bool _requestProcessingStopping; // volatile, see: https://msdn.microsoft.com/en-us/library/x13ttww7.aspx
protected volatile bool _requestAborted;
protected int _requestAborted;
protected CancellationTokenSource _abortedCts;
protected CancellationToken? _manuallySetRequestAbortToken;
@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
var cts = _abortedCts;
return
cts != null ? cts.Token :
_requestAborted ? new CancellationToken(true) :
(Volatile.Read(ref _requestAborted) == 1) ? new CancellationToken(true) :
RequestAbortedSource.Token;
}
set
@ -185,7 +185,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
// Get the abort token, lazily-initializing it if necessary.
// Make sure it's canceled if an abort request already came in.
var cts = LazyInitializer.EnsureInitialized(ref _abortedCts, () => new CancellationTokenSource());
if (_requestAborted)
if (Volatile.Read(ref _requestAborted) == 1)
{
cts.Cancel();
}
@ -288,24 +288,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
/// </summary>
public void Abort()
{
_requestProcessingStopping = true;
_requestAborted = true;
if (Interlocked.CompareExchange(ref _requestAborted, 1, 0) == 0)
{
_requestProcessingStopping = true;
_requestBody?.Abort();
_responseBody?.Abort();
_requestBody?.Abort();
_responseBody?.Abort();
try
{
ConnectionControl.End(ProduceEndType.SocketDisconnect);
SocketInput.AbortAwaiting();
RequestAbortedSource.Cancel();
}
catch (Exception ex)
{
Log.LogError("Abort", ex);
}
finally
{
try
{
ConnectionControl.End(ProduceEndType.SocketDisconnect);
SocketInput.AbortAwaiting();
}
catch (Exception ex)
{
Log.LogError("Abort", ex);
}
try
{
RequestAbortedSource.Cancel();
}
catch (Exception ex)
{
Log.LogError("Abort", ex);
}
_abortedCts = null;
}
}

View File

@ -3,6 +3,7 @@
using System;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Http.Features;
@ -111,7 +112,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
_application.DisposeContext(context, _applicationException);
// If _requestAbort is set, the connection has already been closed.
if (!_requestAborted)
if (Volatile.Read(ref _requestAborted) == 0)
{
_responseBody.ResumeAcceptingWrites();
await ProduceEnd();
@ -148,7 +149,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
_abortedCts = null;
// If _requestAborted is set, the connection has already been closed.
if (!_requestAborted)
if (Volatile.Read(ref _requestAborted) == 0)
{
// Inform client no more data will ever arrive
ConnectionControl.End(ProduceEndType.SocketShutdownSend);

View File

@ -5,6 +5,7 @@ using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.Server.Kestrel.Infrastructure;
namespace Microsoft.AspNetCore.Server.Kestrel.Http
{
@ -51,8 +52,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public override int Read(byte[] buffer, int offset, int count)
{
ValidateState();
// ValueTask uses .GetAwaiter().GetResult() if necessary
return ReadAsync(buffer, offset, count).Result;
}
@ -60,7 +59,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
#if NET451
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
ValidateState();
ValidateState(CancellationToken.None);
var task = ReadAsync(buffer, offset, count, CancellationToken.None, state);
if (callback != null)
@ -77,7 +76,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
private Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state)
{
ValidateState();
ValidateState(cancellationToken);
var tcs = new TaskCompletionSource<int>(state);
var task = _body.ReadAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken);
@ -103,10 +102,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValidateState();
// Needs .AsTask to match Stream's Async method return types
return _body.ReadAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken).AsTask();
var task = ValidateState(cancellationToken);
if (task == null)
{
// Needs .AsTask to match Stream's Async method return types
return _body.ReadAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken).AsTask();
}
return task;
}
public override void Write(byte[] buffer, int offset, int count)
@ -149,24 +151,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public void Abort()
{
// We don't want to throw an ODE until the app func actually completes.
// If the request is aborted, we throw an IOException instead.
// If the request is aborted, we throw an TaskCanceledException instead.
if (_state != FrameStreamState.Closed)
{
_state = FrameStreamState.Aborted;
}
}
private void ValidateState()
private Task<int> ValidateState(CancellationToken cancellationToken)
{
switch (_state)
{
case FrameStreamState.Open:
return;
if (cancellationToken.IsCancellationRequested)
{
return TaskUtilities.GetCancelledZeroTask();
}
break;
case FrameStreamState.Closed:
throw new ObjectDisposedException(nameof(FrameRequestStream));
case FrameStreamState.Aborted:
throw new IOException("The request has been aborted.");
return TaskUtilities.GetCancelledZeroTask();
}
return null;
}
}
}

View File

@ -5,6 +5,7 @@ using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.Server.Kestrel.Infrastructure;
namespace Microsoft.AspNetCore.Server.Kestrel.Http
{
@ -37,16 +38,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public override void Flush()
{
ValidateState();
ValidateState(CancellationToken.None);
_context.FrameControl.Flush();
}
public override Task FlushAsync(CancellationToken cancellationToken)
{
ValidateState();
return _context.FrameControl.FlushAsync(cancellationToken);
var task = ValidateState(cancellationToken);
if (task == null)
{
return _context.FrameControl.FlushAsync(cancellationToken);
}
return task;
}
public override long Seek(long offset, SeekOrigin origin)
@ -66,16 +70,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public override void Write(byte[] buffer, int offset, int count)
{
ValidateState();
ValidateState(CancellationToken.None);
_context.FrameControl.Write(new ArraySegment<byte>(buffer, offset, count));
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
ValidateState();
return _context.FrameControl.WriteAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken);
var task = ValidateState(cancellationToken);
if (task == null)
{
return _context.FrameControl.WriteAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken);
}
return task;
}
public Stream StartAcceptingWrites()
@ -112,24 +119,36 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public void Abort()
{
// We don't want to throw an ODE until the app func actually completes.
// If the request is aborted, we throw an IOException instead.
if (_state != FrameStreamState.Closed)
{
_state = FrameStreamState.Aborted;
}
}
private void ValidateState()
private Task ValidateState(CancellationToken cancellationToken)
{
switch (_state)
{
case FrameStreamState.Open:
return;
if (cancellationToken.IsCancellationRequested)
{
return TaskUtilities.GetCancelledTask(cancellationToken);
}
break;
case FrameStreamState.Closed:
throw new ObjectDisposedException(nameof(FrameResponseStream));
case FrameStreamState.Aborted:
throw new IOException("The request has been aborted.");
if (cancellationToken.CanBeCanceled)
{
// Aborted state only throws on write if cancellationToken requests it
return TaskUtilities.GetCancelledTask(
cancellationToken.IsCancellationRequested ?
cancellationToken :
new CancellationToken(true));
}
break;
}
return null;
}
}
}

View File

@ -2,11 +2,9 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Primitives;
namespace Microsoft.AspNetCore.Server.Kestrel.Http
{

View File

@ -5,6 +5,7 @@ using System;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Infrastructure;
namespace Microsoft.AspNetCore.Server.Kestrel.Http
@ -184,7 +185,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public void AbortAwaiting()
{
_awaitableError = new ObjectDisposedException(nameof(SocketInput), "The request was aborted");
_awaitableError = new TaskCanceledException("The request was aborted");
Complete();
}
@ -238,6 +239,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
var error = _awaitableError;
if (error != null)
{
if (error is TaskCanceledException || error is InvalidOperationException)
{
throw error;
}
throw new IOException(error.Message, error);
}
}

View File

@ -4,7 +4,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Infrastructure;
@ -22,6 +21,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
private const int _maxPooledWriteContexts = 32;
private static readonly WaitCallback _returnBlocks = (state) => ReturnBlocks((MemoryPoolBlock2)state);
private static readonly Action<object> _connectionCancellation = (state) => ((SocketOutput)state).CancellationTriggered();
private readonly KestrelThread _thread;
private readonly UvStreamHandle _socket;
@ -78,6 +78,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public Task WriteAsync(
ArraySegment<byte> buffer,
CancellationToken cancellationToken,
bool immediate = true,
bool chunk = false,
bool socketShutdownSend = false,
@ -89,9 +90,21 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
lock (_contextLock)
{
if (_lastWriteError != null || _socket.IsClosed)
{
_log.ConnectionDisconnectedWrite(_connectionId, buffer.Count, _lastWriteError);
return TaskUtilities.CompletedTask;
}
if (buffer.Count > 0)
{
var tail = ProducingStart();
if (tail.IsDefault)
{
return TaskUtilities.CompletedTask;
}
if (chunk)
{
_numBytesPreCompleted += ChunkWriter.WriteBeginChunkBytes(ref tail, buffer.Count);
@ -146,13 +159,36 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
}
else
{
// immediate write, which is not eligable for instant completion above
tcs = new TaskCompletionSource<object>(buffer.Count);
_tasksPending.Enqueue(new WaitingTask() {
CompletionSource = tcs,
BytesToWrite = buffer.Count,
IsSync = isSync
});
if (cancellationToken.CanBeCanceled)
{
if (cancellationToken.IsCancellationRequested)
{
_connection.Abort();
return TaskUtilities.GetCancelledTask(cancellationToken);
}
else
{
// immediate write, which is not eligable for instant completion above
tcs = new TaskCompletionSource<object>();
_tasksPending.Enqueue(new WaitingTask()
{
CancellationToken = cancellationToken,
CancellationRegistration = cancellationToken.Register(_connectionCancellation, this),
BytesToWrite = buffer.Count,
CompletionSource = tcs
});
}
}
else
{
tcs = new TaskCompletionSource<object>();
_tasksPending.Enqueue(new WaitingTask() {
IsSync = isSync,
BytesToWrite = buffer.Count,
CompletionSource = tcs
});
}
}
if (!_writePending && immediate)
@ -177,12 +213,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{
case ProduceEndType.SocketShutdownSend:
WriteAsync(default(ArraySegment<byte>),
default(CancellationToken),
immediate: true,
socketShutdownSend: true,
socketDisconnect: false);
break;
case ProduceEndType.SocketDisconnect:
WriteAsync(default(ArraySegment<byte>),
default(CancellationToken),
immediate: true,
socketShutdownSend: false,
socketDisconnect: true);
@ -198,7 +236,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
if (_tail == null)
{
throw new IOException("The socket has been closed.");
return default(MemoryPoolIterator2);
}
_lastStart = new MemoryPoolIterator2(_tail, _tail.End);
@ -251,6 +289,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
}
}
private void CancellationTriggered()
{
lock (_contextLock)
{
// Abort the connection for any failed write
// Queued on threadpool so get it in as first op.
_connection?.Abort();
CompleteAllWrites();
}
}
private static void ReturnBlocks(MemoryPoolBlock2 block)
{
while (block != null)
@ -305,10 +355,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
if (error != null)
{
_lastWriteError = new IOException(error.Message, error);
// Abort the connection for any failed write.
// Abort the connection for any failed write
// Queued on threadpool so get it in as first op.
_connection.Abort();
_lastWriteError = error;
}
PoolWriteContext(writeContext);
@ -317,43 +367,78 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
// completed writes that we haven't triggered callbacks for yet.
_numBytesPreCompleted -= bytesWritten;
CompleteFinishedWrites(status);
if (error != null)
{
_log.ConnectionError(_connectionId, error);
}
else
{
_log.ConnectionWriteCallback(_connectionId, status);
}
}
private void CompleteNextWrite(ref int bytesLeftToBuffer)
{
var waitingTask = _tasksPending.Dequeue();
var bytesToWrite = waitingTask.BytesToWrite;
_numBytesPreCompleted += bytesToWrite;
bytesLeftToBuffer -= bytesToWrite;
// Dispose registration if there is one
waitingTask.CancellationRegistration?.Dispose();
if (waitingTask.CancellationToken.IsCancellationRequested)
{
if (waitingTask.IsSync)
{
waitingTask.CompletionSource.TrySetCanceled();
}
else
{
_threadPool.Cancel(waitingTask.CompletionSource);
}
}
else
{
if (waitingTask.IsSync)
{
waitingTask.CompletionSource.TrySetResult(null);
}
else
{
_threadPool.Complete(waitingTask.CompletionSource);
}
}
}
private void CompleteFinishedWrites(int status)
{
// bytesLeftToBuffer can be greater than _maxBytesPreCompleted
// This allows large writes to complete once they've actually finished.
var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted;
while (_tasksPending.Count > 0 &&
(_tasksPending.Peek().BytesToWrite) <= bytesLeftToBuffer)
{
var waitingTask = _tasksPending.Dequeue();
var bytesToWrite = waitingTask.BytesToWrite;
CompleteNextWrite(ref bytesLeftToBuffer);
}
}
_numBytesPreCompleted += bytesToWrite;
bytesLeftToBuffer -= bytesToWrite;
if (_lastWriteError == null)
{
if (waitingTask.IsSync)
{
waitingTask.CompletionSource.TrySetResult(null);
}
else
{
_threadPool.Complete(waitingTask.CompletionSource);
}
}
else
{
if (waitingTask.IsSync)
{
waitingTask.CompletionSource.TrySetException(_lastWriteError);
}
else
{
_threadPool.Error(waitingTask.CompletionSource, _lastWriteError);
}
}
private void CompleteAllWrites()
{
var writesToComplete = _tasksPending.Count > 0;
var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted;
while (_tasksPending.Count > 0)
{
CompleteNextWrite(ref bytesLeftToBuffer);
}
_log.ConnectionWriteCallback(_connectionId, status);
if (writesToComplete)
{
_log.ConnectionError(_connectionId, new TaskCanceledException("Connetcion"));
}
}
// This is called on the libuv event loop
@ -393,12 +478,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
void ISocketOutput.Write(ArraySegment<byte> buffer, bool immediate, bool chunk)
{
WriteAsync(buffer, immediate, chunk, isSync: true).GetAwaiter().GetResult();
WriteAsync(buffer, CancellationToken.None, immediate, chunk, isSync: true).GetAwaiter().GetResult();
}
Task ISocketOutput.WriteAsync(ArraySegment<byte> buffer, bool immediate, bool chunk, CancellationToken cancellationToken)
{
return WriteAsync(buffer, immediate, chunk);
if (cancellationToken.IsCancellationRequested)
{
_connection?.Abort();
return TaskUtilities.GetCancelledTask(cancellationToken);
}
return WriteAsync(buffer, cancellationToken, immediate, chunk);
}
private static void BytesBetween(MemoryPoolIterator2 start, MemoryPoolIterator2 end, out int bytes, out int buffers)
@ -649,6 +740,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{
public bool IsSync;
public int BytesToWrite;
public CancellationToken CancellationToken;
public IDisposable CancellationRegistration;
public TaskCompletionSource<object> CompletionSource;
}
}

View File

@ -29,6 +29,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
void ConnectionWriteCallback(long connectionId, int status);
void ConnectionError(long connectionId, Exception ex);
void ConnectionDisconnectedWrite(long connectionId, int count, Exception ex);
void ApplicationError(Exception ex);
}
}

View File

@ -9,6 +9,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
public interface IThreadPool
{
void Complete(TaskCompletionSource<object> tcs);
void Cancel(TaskCompletionSource<object> tcs);
void Error(TaskCompletionSource<object> tcs, Exception ex);
void Run(Action action);
}

View File

@ -21,6 +21,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel
private static readonly Action<ILogger, long, int, Exception> _connectionWroteFin;
private static readonly Action<ILogger, long, Exception> _connectionKeepAlive;
private static readonly Action<ILogger, long, Exception> _connectionDisconnect;
private static readonly Action<ILogger, long, Exception> _connectionError;
private static readonly Action<ILogger, long, int, Exception> _connectionDisconnectedWrite;
protected readonly ILogger _logger;
@ -39,6 +41,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel
// ConnectionWrite: Reserved: 11
// ConnectionWriteCallback: Reserved: 12
// ApplicationError: Reserved: 13 - LoggerMessage.Define overload not present
_connectionError = LoggerMessage.Define<long>(LogLevel.Information, 14, @"Connection id ""{ConnectionId}"" communication error");
_connectionDisconnectedWrite = LoggerMessage.Define<long, int>(LogLevel.Debug, 15, @"Connection id ""{ConnectionId}"" write of ""{count}"" bytes to disconnected client.");
}
public KestrelTrace(ILogger logger)
@ -114,6 +118,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel
_logger.LogError(13, "An unhandled exception was thrown by the application.", ex);
}
public virtual void ConnectionError(long connectionId, Exception ex)
{
_connectionError(_logger, connectionId, ex);
}
public virtual void ConnectionDisconnectedWrite(long connectionId, int count, Exception ex)
{
_connectionDisconnectedWrite(_logger, connectionId, count, ex);
}
public virtual void Log(LogLevel logLevel, int eventId, object state, Exception exception, Func<object, Exception, string> formatter)
{
_logger.Log(logLevel, eventId, state, exception, formatter);

View File

@ -12,6 +12,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
private readonly IKestrelTrace _log;
private readonly WaitCallback _runAction;
private readonly WaitCallback _cancelTcs;
private readonly WaitCallback _completeTcs;
public LoggingThreadPool(IKestrelTrace log)
@ -42,6 +43,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
_log.ApplicationError(e);
}
};
_cancelTcs = (o) =>
{
try
{
((TaskCompletionSource<object>)o).TrySetCanceled();
}
catch (Exception e)
{
_log.ApplicationError(e);
}
};
}
public void Run(Action action)
@ -54,6 +67,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
ThreadPool.QueueUserWorkItem(_completeTcs, tcs);
}
public void Cancel(TaskCompletionSource<object> tcs)
{
ThreadPool.QueueUserWorkItem(_cancelTcs, tcs);
}
public void Error(TaskCompletionSource<object> tcs, Exception ex)
{
// ex ang _log are closure captured

View File

@ -724,6 +724,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
public void CopyFrom(byte[] data, int offset, int count)
{
if (IsDefault)
{
return;
}
Debug.Assert(_block != null);
Debug.Assert(_block.Next == null);
Debug.Assert(_block.End == _index);
@ -766,6 +771,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
public unsafe void CopyFromAscii(string data)
{
if (IsDefault)
{
return;
}
Debug.Assert(_block != null);
Debug.Assert(_block.Next == null);
Debug.Assert(_block.End == _index);

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
@ -13,5 +14,24 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure
public static Task CompletedTask = Task.FromResult<object>(null);
#endif
public static Task<int> ZeroTask = Task.FromResult(0);
public static Task GetCancelledTask(CancellationToken cancellationToken)
{
#if DOTNET5_4
return Task.FromCanceled(cancellationToken);
#else
var tcs = new TaskCompletionSource<object>();
tcs.TrySetCanceled();
return tcs.Task;
#endif
}
public static Task<int> GetCancelledZeroTask()
{
// Task<int>.FromCanceled doesn't return Task<int>
var tcs = new TaskCompletionSource<int>();
tcs.TrySetCanceled();
return tcs.Task;
}
}
}

View File

@ -1084,7 +1084,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
await Assert.ThrowsAsync<IOException>(async () => await readTcs.Task);
await Assert.ThrowsAsync<TaskCanceledException>(async () => await readTcs.Task);
// The cancellation token for only the last request should be triggered.
var abortedRequestId = await registrationTcs.Task;
@ -1096,6 +1096,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[FrameworkSkipCondition(RuntimeFrameworks.Mono, SkipReason = "Test hangs after execution on Mono.")]
public async Task FailedWritesResultInAbortedRequest(ServiceContext testContext)
{
const int resetEventTimeout = 2000;
// This should match _maxBytesPreCompleted in SocketOutput
const int maxBytesPreCompleted = 65536;
// Ensure string is long enough to disable write-behind buffering
var largeString = new string('a', maxBytesPreCompleted + 1);
var writeTcs = new TaskCompletionSource<object>();
var registrationWh = new ManualResetEventSlim();
var connectionCloseWh = new ManualResetEventSlim();
@ -1119,7 +1125,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
// Ensure write is long enough to disable write-behind buffering
for (int i = 0; i < 10; i++)
{
await response.WriteAsync(new string('a', 65537));
await response.WriteAsync(largeString).ConfigureAwait(false);
}
}
catch (Exception ex)
@ -1127,7 +1133,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
writeTcs.SetException(ex);
// Give a chance for RequestAborted to trip before the app completes
registrationWh.Wait(1000);
registrationWh.Wait(resetEventTimeout);
throw;
}
@ -1141,16 +1147,16 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
"POST / HTTP/1.1",
"Content-Length: 5",
"",
"Hello");
"Hello").ConfigureAwait(false);
// Don't wait to receive the response. Just close the socket.
}
connectionCloseWh.Set();
// Write failed
await Assert.ThrowsAsync<IOException>(async () => await writeTcs.Task);
await Assert.ThrowsAsync<TaskCanceledException>(async () => await writeTcs.Task);
// RequestAborted tripped
Assert.True(registrationWh.Wait(200));
Assert.True(registrationWh.Wait(resetEventTimeout));
}
}

View File

@ -102,7 +102,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
private static void TestConcurrentFaultedTask(Task t)
{
Assert.True(t.IsFaulted);
Assert.IsType(typeof(System.IO.IOException), t.Exception.InnerException);
Assert.IsType(typeof(System.InvalidOperationException), t.Exception.InnerException);
Assert.Equal(t.Exception.InnerException.Message, "Concurrent reads are not supported.");
}

View File

@ -50,7 +50,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
var completedWh = new ManualResetEventSlim();
// Act
socketOutput.WriteAsync(buffer).ContinueWith(
socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(
(t) =>
{
Assert.Null(t.Exception);
@ -101,14 +101,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
};
// Act
socketOutput.WriteAsync(buffer).ContinueWith(onCompleted);
socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted);
// Assert
// The first write should pre-complete since it is <= _maxBytesPreCompleted.
Assert.True(completedWh.Wait(1000));
// Arrange
completedWh.Reset();
// Act
socketOutput.WriteAsync(buffer).ContinueWith(onCompleted);
socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted);
// Assert
// Too many bytes are already pre-completed for the second write to pre-complete.
Assert.False(completedWh.Wait(1000));
@ -162,28 +162,28 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
};
// Act
socketOutput.WriteAsync(halfBuffer, false).ContinueWith(onCompleted);
socketOutput.WriteAsync(halfBuffer, default(CancellationToken), false).ContinueWith(onCompleted);
// Assert
// The first write should pre-complete since it is not immediate.
Assert.True(completedWh.Wait(1000));
// Arrange
completedWh.Reset();
// Act
socketOutput.WriteAsync(halfBuffer).ContinueWith(onCompleted);
socketOutput.WriteAsync(halfBuffer, default(CancellationToken)).ContinueWith(onCompleted);
// Assert
// The second write should pre-complete since it is <= _maxBytesPreCompleted.
Assert.True(completedWh.Wait(1000));
// Arrange
completedWh.Reset();
// Act
socketOutput.WriteAsync(halfBuffer, false).ContinueWith(onCompleted);
socketOutput.WriteAsync(halfBuffer, default(CancellationToken), false).ContinueWith(onCompleted);
// Assert
// The third write should pre-complete since it is not immediate, even though too many.
Assert.True(completedWh.Wait(1000));
// Arrange
completedWh.Reset();
// Act
socketOutput.WriteAsync(halfBuffer).ContinueWith(onCompleted);
socketOutput.WriteAsync(halfBuffer, default(CancellationToken)).ContinueWith(onCompleted);
// Assert
// Too many bytes are already pre-completed for the fourth write to pre-complete.
Assert.False(completedWh.Wait(1000));
@ -198,6 +198,116 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public async Task OnlyWritesRequestingCancellationAreErroredOnCancellation()
{
// This should match _maxBytesPreCompleted in SocketOutput
var maxBytesPreCompleted = 65536;
var completeQueue = new Queue<Action<int>>();
// Arrange
var mockLibuv = new MockLibuv
{
OnWrite = (socket, buffers, triggerCompleted) =>
{
completeQueue.Enqueue(triggerCompleted);
return 0;
}
};
using (var kestrelEngine = new KestrelEngine(mockLibuv, new TestServiceContext()))
using (var memory = new MemoryPool2())
{
kestrelEngine.Start(count: 1);
var kestrelThread = kestrelEngine.Threads[0];
var socket = new MockSocket(kestrelThread.Loop.ThreadId, new TestKestrelTrace());
var trace = new KestrelTrace(new TestKestrelTrace());
var ltp = new LoggingThreadPool(trace);
ISocketOutput socketOutput = new SocketOutput(kestrelThread, socket, memory, null, 0, trace, ltp, new Queue<UvWriteReq>());
var bufferSize = maxBytesPreCompleted;
var data = new byte[bufferSize];
var fullBuffer = new ArraySegment<byte>(data, 0, bufferSize);
var cts = new CancellationTokenSource();
// Act
var task1Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token);
// task1 should complete sucessfully as < _maxBytesPreCompleted
// First task is completed and sucessful
Assert.True(task1Success.IsCompleted);
Assert.False(task1Success.IsCanceled);
Assert.False(task1Success.IsFaulted);
task1Success.GetAwaiter().GetResult();
// following tasks should wait.
var task2Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token);
var task3Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: default(CancellationToken));
// Give time for tasks to perculate
await Task.Delay(2000).ConfigureAwait(false);
// Second task is not completed
Assert.False(task2Throw.IsCompleted);
Assert.False(task2Throw.IsCanceled);
Assert.False(task2Throw.IsFaulted);
// Third task is not completed
Assert.False(task3Success.IsCompleted);
Assert.False(task3Success.IsCanceled);
Assert.False(task3Success.IsFaulted);
cts.Cancel();
// Give time for tasks to perculate
await Task.Delay(2000).ConfigureAwait(false);
// Second task is now cancelled
Assert.True(task2Throw.IsCompleted);
Assert.True(task2Throw.IsCanceled);
Assert.False(task2Throw.IsFaulted);
// Third task is now completed
Assert.True(task3Success.IsCompleted);
Assert.False(task3Success.IsCanceled);
Assert.False(task3Success.IsFaulted);
// Fourth task immediately cancels as the token is cancelled
var task4Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token);
Assert.True(task4Throw.IsCompleted);
Assert.True(task4Throw.IsCanceled);
Assert.False(task4Throw.IsFaulted);
Assert.Throws<TaskCanceledException>(() => task4Throw.GetAwaiter().GetResult());
var task5Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: default(CancellationToken));
// task5 should complete immedately
Assert.True(task5Success.IsCompleted);
Assert.False(task5Success.IsCanceled);
Assert.False(task5Success.IsFaulted);
cts = new CancellationTokenSource();
var task6Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token);
// task6 should complete immedately but not cancel as its cancelation token isn't set
Assert.True(task6Throw.IsCompleted);
Assert.False(task6Throw.IsCanceled);
Assert.False(task6Throw.IsFaulted);
Assert.Throws<TaskCanceledException>(() => task6Throw.GetAwaiter().GetResult());
Assert.True(true);
}
}
[Fact]
public void WritesDontGetCompletedTooQuickly()
{
@ -247,7 +357,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
};
// Act (Pre-complete the maximum number of bytes in preparation for the rest of the test)
socketOutput.WriteAsync(buffer).ContinueWith(onCompleted);
socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted);
// Assert
// The first write should pre-complete since it is <= _maxBytesPreCompleted.
Assert.True(completedWh.Wait(1000));
@ -257,8 +367,8 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
onWriteWh.Reset();
// Act
socketOutput.WriteAsync(buffer).ContinueWith(onCompleted);
socketOutput.WriteAsync(buffer).ContinueWith(onCompleted2);
socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted);
socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted2);
Assert.True(onWriteWh.Wait(1000));
completeQueue.Dequeue()(0);
@ -320,7 +430,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
socketOutput.ProducingComplete(end);
// A call to Write is required to ensure a write is scheduled
socketOutput.WriteAsync(default(ArraySegment<byte>));
socketOutput.WriteAsync(default(ArraySegment<byte>), default(CancellationToken));
Assert.True(nBufferWh.Wait(1000));
Assert.Equal(2, nBuffers);