// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Networking;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.Server.Kestrel.Http
{
    /// <summary>
    /// Buffers outgoing bytes in a chain of pooled memory blocks and flushes them to a libuv
    /// stream on the connection's <see cref="KestrelThread"/>. Write tasks are completed
    /// immediately while the amount of buffered-but-unwritten data stays within
    /// <c>_maxBytesPreCompleted</c> (64 KiB); beyond that, callers receive a task that completes
    /// once earlier data has actually been written.
    /// </summary>
    public class SocketOutput : ISocketOutput
    {
        public const int MaxPooledWriteReqs = 1024;

        // Upper bound on bytes whose write tasks may be completed before the data
        // has actually been written to the socket.
        private const int _maxBytesPreCompleted = 65536;
        private const int _initialTaskQueues = 64;
        private const int _maxPooledWriteContexts = 32;

        private static readonly WaitCallback _returnBlocks = (state) => ReturnBlocks((MemoryPoolBlock2)state);
        private static readonly Action<object> _connectionCancellation = (state) => ((SocketOutput)state).CancellationTriggered();

        private readonly KestrelThread _thread;
        private readonly UvStreamHandle _socket;
        private readonly Connection _connection;
        private readonly long _connectionId;
        private readonly IKestrelTrace _log;
        private readonly IThreadPool _threadPool;

        // This locks all access to _tail and _lastStart.
        // _head does not require a lock, since it is only used in the ctor and uv thread.
        private readonly object _returnLock = new object();

        private MemoryPoolBlock2 _head;
        private MemoryPoolBlock2 _tail;

        private MemoryPoolIterator2 _lastStart;

        // This locks access to all of the below fields
        private readonly object _contextLock = new object();

        // True while a WriteAllPending callback has been posted to the libuv thread
        // but has not yet started running (it resets this flag first thing).
        private bool _writePending = false;
        private bool _cancelled = false;
        // Bytes already accounted as "completed" to writers, minus bytes confirmed written
        // to the socket. May go temporarily negative (see OnWriteCompleted).
        private int _numBytesPreCompleted = 0;
        private Exception _lastWriteError;
        private WriteContext _nextWriteContext;
        private readonly Queue<WaitingTask> _tasksPending;
        private readonly Queue<WriteContext> _writeContextPool;
        private readonly Queue<UvWriteReq> _writeReqPool;

        public SocketOutput(
            KestrelThread thread,
            UvStreamHandle socket,
            MemoryPool2 memory,
            Connection connection,
            long connectionId,
            IKestrelTrace log,
            IThreadPool threadPool,
            Queue<UvWriteReq> writeReqPool)
        {
            _thread = thread;
            _socket = socket;
            _connection = connection;
            _connectionId = connectionId;
            _log = log;
            _threadPool = threadPool;
            _tasksPending = new Queue<WaitingTask>(_initialTaskQueues);
            _writeContextPool = new Queue<WriteContext>(_maxPooledWriteContexts);
            _writeReqPool = writeReqPool;

            _head = memory.Lease();
            _tail = _head;
        }

        /// <summary>
        /// Copies <paramref name="buffer"/> (optionally wrapped in chunked-encoding framing) into the
        /// output block chain and schedules a flush on the libuv thread if one is not already pending.
        /// </summary>
        /// <param name="buffer">Bytes to send; may be empty (e.g. to only request shutdown/disconnect).</param>
        /// <param name="cancellationToken">
        /// If the write cannot complete immediately and this token is already cancelled, the connection is aborted.
        /// </param>
        /// <param name="chunk">Wrap the bytes in chunk begin/end framing.</param>
        /// <param name="socketShutdownSend">Shut down the send side of the socket after this flush.</param>
        /// <param name="socketDisconnect">Dispose the socket after this flush.</param>
        /// <param name="isSync">Complete the returned task inline rather than via the thread pool.</param>
        /// <returns>
        /// A completed task if the socket is closed or the data fits the pre-completion budget; otherwise a
        /// task that completes once enough earlier data has been written (or is cancelled).
        /// </returns>
        public Task WriteAsync(
            ArraySegment<byte> buffer,
            CancellationToken cancellationToken,
            bool chunk = false,
            bool socketShutdownSend = false,
            bool socketDisconnect = false,
            bool isSync = false)
        {
            TaskCompletionSource<object> tcs = null;
            var scheduleWrite = false;

            lock (_contextLock)
            {
                if (_socket.IsClosed)
                {
                    _log.ConnectionDisconnectedWrite(_connectionId, buffer.Count, _lastWriteError);

                    return TaskUtilities.CompletedTask;
                }

                if (buffer.Count > 0)
                {
                    var tail = ProducingStart();
                    if (tail.IsDefault)
                    {
                        // Blocks were already returned (socket closed); nothing to buffer into.
                        return TaskUtilities.CompletedTask;
                    }

                    if (chunk)
                    {
                        _numBytesPreCompleted += ChunkWriter.WriteBeginChunkBytes(ref tail, buffer.Count);
                    }

                    tail.CopyFrom(buffer);

                    if (chunk)
                    {
                        ChunkWriter.WriteEndChunkBytes(ref tail);
                        _numBytesPreCompleted += 2;
                    }

                    // We do our own accounting below
                    ProducingCompleteNoPreComplete(tail);
                }

                if (_nextWriteContext == null)
                {
                    if (_writeContextPool.Count > 0)
                    {
                        _nextWriteContext = _writeContextPool.Dequeue();
                    }
                    else
                    {
                        _nextWriteContext = new WriteContext(this);
                    }
                }

                if (socketShutdownSend)
                {
                    _nextWriteContext.SocketShutdownSend = true;
                }
                if (socketDisconnect)
                {
                    _nextWriteContext.SocketDisconnect = true;
                }

                if (_lastWriteError == null &&
                    _tasksPending.Count == 0 &&
                    _numBytesPreCompleted + buffer.Count <= _maxBytesPreCompleted)
                {
                    // Complete the write task immediately if all previous write tasks have been completed,
                    // the buffers haven't grown too large, and the last write to the socket succeeded.
                    _numBytesPreCompleted += buffer.Count;
                }
                else
                {
                    if (cancellationToken.CanBeCanceled)
                    {
                        if (cancellationToken.IsCancellationRequested)
                        {
                            _connection.Abort();
                            _cancelled = true;

                            return TaskUtilities.GetCancelledTask(cancellationToken);
                        }
                        else
                        {
                            // immediate write, which is not eligible for instant completion above
                            tcs = new TaskCompletionSource<object>();
                            _tasksPending.Enqueue(new WaitingTask()
                            {
                                CancellationToken = cancellationToken,
                                CancellationRegistration = cancellationToken.Register(_connectionCancellation, this),
                                BytesToWrite = buffer.Count,
                                CompletionSource = tcs
                            });
                        }
                    }
                    else
                    {
                        tcs = new TaskCompletionSource<object>();
                        _tasksPending.Enqueue(new WaitingTask()
                        {
                            IsSync = isSync,
                            BytesToWrite = buffer.Count,
                            CompletionSource = tcs
                        });
                    }
                }

                if (!_writePending)
                {
                    _writePending = true;
                    scheduleWrite = true;
                }
            }

            if (scheduleWrite)
            {
                ScheduleWrite();
            }

            // Return TaskCompletionSource's Task if set, otherwise completed Task
            return tcs?.Task ?? TaskUtilities.CompletedTask;
        }

        /// <summary>
        /// Queues an empty write whose flush is followed by shutting down the send side of the
        /// socket or disconnecting it, depending on <paramref name="endType"/>. Other values are ignored.
        /// </summary>
        public void End(ProduceEndType endType)
        {
            switch (endType)
            {
                case ProduceEndType.SocketShutdownSend:
                    WriteAsync(default(ArraySegment<byte>),
                        default(CancellationToken),
                        socketShutdownSend: true,
                        socketDisconnect: false);
                    break;
                case ProduceEndType.SocketDisconnect:
                    WriteAsync(default(ArraySegment<byte>),
                        default(CancellationToken),
                        socketShutdownSend: false,
                        socketDisconnect: true);
                    break;
            }
        }

        /// <summary>
        /// Begins writing directly into the output block chain. Returns an iterator positioned at the end
        /// of the current tail block, or a default iterator if the blocks have already been returned.
        /// Must be followed by <see cref="ProducingComplete"/> before the next call.
        /// </summary>
        public MemoryPoolIterator2 ProducingStart()
        {
            lock (_returnLock)
            {
                Debug.Assert(_lastStart.IsDefault);

                if (_tail == null)
                {
                    return default(MemoryPoolIterator2);
                }

                _lastStart = new MemoryPoolIterator2(_tail, _tail.End);

                return _lastStart;
            }
        }

        /// <summary>
        /// Ends a direct write started by <see cref="ProducingStart"/>: counts the produced bytes as
        /// pre-completed and publishes <paramref name="end"/> as the new tail. Does not schedule a flush.
        /// </summary>
        public void ProducingComplete(MemoryPoolIterator2 end)
        {
            Debug.Assert(!_lastStart.IsDefault);

            int bytesProduced, buffersIncluded;
            BytesBetween(_lastStart, end, out bytesProduced, out buffersIncluded);

            lock (_contextLock)
            {
                _numBytesPreCompleted += bytesProduced;
            }

            ProducingCompleteNoPreComplete(end);
        }

        // Publishes the produced range as the new tail without touching _numBytesPreCompleted.
        // If the blocks were returned meanwhile (_tail == null), the produced chain is returned instead.
        private void ProducingCompleteNoPreComplete(MemoryPoolIterator2 end)
        {
            MemoryPoolBlock2 blockToReturn = null;

            lock (_returnLock)
            {
                Debug.Assert(!_lastStart.IsDefault);

                // If the socket has been closed, return the produced blocks
                // instead of advancing the now non-existent tail.
                if (_tail != null)
                {
                    _tail = end.Block;
                    _tail.End = end.Index;
                }
                else
                {
                    blockToReturn = _lastStart.Block;
                }

                _lastStart = default(MemoryPoolIterator2);
            }

            if (blockToReturn != null)
            {
                ThreadPool.QueueUserWorkItem(_returnBlocks, blockToReturn);
            }
        }

        // Invoked through a write's cancellation-token registration. Aborts the connection once
        // and releases every pending write task.
        private void CancellationTriggered()
        {
            lock (_contextLock)
            {
                if (!_cancelled)
                {
                    // Abort the connection for any failed write
                    // Queued on threadpool so get it in as first op.
                    _connection.Abort();
                    _cancelled = true;

                    CompleteAllWrites();

                    _log.ConnectionError(_connectionId, new TaskCanceledException("Write operation canceled. Aborting connection."));
                }
            }
        }

        // Returns every block in the singly linked chain starting at block to its pool.
        private static void ReturnBlocks(MemoryPoolBlock2 block)
        {
            while (block != null)
            {
                var returningBlock = block;
                block = returningBlock.Next;

                returningBlock.Pool.Return(returningBlock);
            }
        }

        private void ScheduleWrite()
        {
            _thread.Post(_this => _this.WriteAllPending(), this);
        }

        // This is called on the libuv event loop
        private void WriteAllPending()
        {
            WriteContext writingContext = null;

            if (Monitor.TryEnter(_contextLock))
            {
                _writePending = false;

                if (_nextWriteContext != null)
                {
                    writingContext = _nextWriteContext;
                    _nextWriteContext = null;
                }

                Monitor.Exit(_contextLock);
            }
            else
            {
                // Don't block the loop on the lock; try again on a later loop iteration.
                ScheduleWrite();
            }

            if (writingContext != null)
            {
                writingContext.DoWriteIfNeeded();
            }
        }

        // This may be called on the libuv event loop
        private void OnWriteCompleted(WriteContext writeContext)
        {
            // Called inside _contextLock
            var bytesWritten = writeContext.ByteCount;
            var status = writeContext.WriteStatus;
            var error = writeContext.WriteError;

            if (error != null)
            {
                // Abort the connection for any failed write
                // Queued on threadpool so get it in as first op.
                _connection.Abort();
                _cancelled = true;
                _lastWriteError = error;
            }

            PoolWriteContext(writeContext);

            // _numBytesPreCompleted can temporarily go negative in the event there are
            // completed writes that we haven't triggered callbacks for yet.
            _numBytesPreCompleted -= bytesWritten;

            if (error == null)
            {
                CompleteFinishedWrites(status);
                _log.ConnectionWriteCallback(_connectionId, status);
            }
            else
            {
                CompleteAllWrites();
                _log.ConnectionError(_connectionId, error);
            }
        }

        // Dequeues the oldest pending write, credits its bytes as pre-completed, disposes its
        // cancellation registration and completes its task (cancelled if its token was cancelled).
        // Sync writes are completed inline; others are completed through _threadPool.
        private void CompleteNextWrite(ref int bytesLeftToBuffer)
        {
            // Called inside _contextLock
            var waitingTask = _tasksPending.Dequeue();
            var bytesToWrite = waitingTask.BytesToWrite;

            _numBytesPreCompleted += bytesToWrite;
            bytesLeftToBuffer -= bytesToWrite;

            // Dispose registration if there is one
            waitingTask.CancellationRegistration?.Dispose();

            if (waitingTask.CancellationToken.IsCancellationRequested)
            {
                if (waitingTask.IsSync)
                {
                    waitingTask.CompletionSource.TrySetCanceled();
                }
                else
                {
                    _threadPool.Cancel(waitingTask.CompletionSource);
                }
            }
            else
            {
                if (waitingTask.IsSync)
                {
                    waitingTask.CompletionSource.TrySetResult(null);
                }
                else
                {
                    _threadPool.Complete(waitingTask.CompletionSource);
                }
            }
        }

        // Completes pending writes, oldest first, while each still fits in the remaining budget.
        private void CompleteFinishedWrites(int status)
        {
            // Called inside _contextLock
            // bytesLeftToBuffer can be greater than _maxBytesPreCompleted
            // This allows large writes to complete once they've actually finished.
            var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted;
            while (_tasksPending.Count > 0 &&
                   (_tasksPending.Peek().BytesToWrite) <= bytesLeftToBuffer)
            {
                CompleteNextWrite(ref bytesLeftToBuffer);
            }
        }

        // Completes every pending write regardless of the buffering budget.
        private void CompleteAllWrites()
        {
            // Called inside _contextLock
            var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted;
            while (_tasksPending.Count > 0)
            {
                CompleteNextWrite(ref bytesLeftToBuffer);
            }
        }

        // This is called on the libuv event loop
        private void ReturnAllBlocks()
        {
            lock (_returnLock)
            {
                var block = _head;
                while (block != _tail)
                {
                    var returnBlock = block;
                    block = block.Next;

                    returnBlock.Pool.Return(returnBlock);
                }

                // Only return the _tail if we aren't between ProducingStart/Complete calls
                if (_lastStart.IsDefault)
                {
                    _tail.Pool.Return(_tail);
                }

                _head = null;
                _tail = null;
            }
        }

        private void PoolWriteContext(WriteContext writeContext)
        {
            // Called inside _contextLock
            if (_writeContextPool.Count < _maxPooledWriteContexts)
            {
                writeContext.Reset();
                _writeContextPool.Enqueue(writeContext);
            }
        }

        void ISocketOutput.Write(ArraySegment<byte> buffer, bool chunk)
        {
            WriteAsync(buffer, default(CancellationToken), chunk, isSync: true).GetAwaiter().GetResult();
        }

        Task ISocketOutput.WriteAsync(ArraySegment<byte> buffer, bool chunk, CancellationToken cancellationToken)
        {
            // _cancelled is guarded by _contextLock (see the field comment), so read and update it
            // under that lock like WriteAsync, CancellationTriggered and OnWriteCompleted do.
            lock (_contextLock)
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _connection.Abort();
                    _cancelled = true;
                    return TaskUtilities.GetCancelledTask(cancellationToken);
                }

                if (_cancelled)
                {
                    return TaskUtilities.CompletedTask;
                }
            }

            return WriteAsync(buffer, cancellationToken, chunk);
        }

        // Counts the bytes and the number of blocks spanned from start (inclusive) to end (exclusive),
        // walking the Next links between the two blocks.
        private static void BytesBetween(MemoryPoolIterator2 start, MemoryPoolIterator2 end, out int bytes, out int buffers)
        {
            if (start.Block == end.Block)
            {
                bytes = end.Index - start.Index;
                buffers = 1;
                return;
            }

            bytes = start.Block.Data.Offset + start.Block.Data.Count - start.Index;
            buffers = 1;

            for (var block = start.Block.Next; block != end.Block; block = block.Next)
            {
                bytes += block.Data.Count;
                buffers++;
            }

            bytes += end.Index - end.Block.Data.Offset;
            buffers++;
        }

        /// <summary>
        /// One flush operation: write the buffered range, then optionally shut down the send side,
        /// then optionally disconnect, and finally report the outcome back to the owning
        /// <see cref="SocketOutput"/> under its context lock. Instances are pooled and reused.
        /// </summary>
        private class WriteContext
        {
            private static readonly WaitCallback _returnWrittenBlocks = (state) => ReturnWrittenBlocks((MemoryPoolBlock2)state);
            private static readonly WaitCallback _completeWrite = (state) => ((WriteContext)state).CompleteOnThreadPool();

            private readonly SocketOutput Self;
            private UvWriteReq _writeReq;
            private MemoryPoolIterator2 _lockedStart;
            private MemoryPoolIterator2 _lockedEnd;
            private int _bufferCount;

            public int ByteCount;
            public bool SocketShutdownSend;
            public bool SocketDisconnect;

            public int WriteStatus;
            public Exception WriteError;

            public int ShutdownSendStatus;

            public WriteContext(SocketOutput self)
            {
                Self = self;
            }

            /// <summary>
            /// First step: initiate async write if needed, otherwise go to next step
            /// </summary>
            public void DoWriteIfNeeded()
            {
                LockWrite();

                if (ByteCount == 0 || Self._socket.IsClosed)
                {
                    DoShutdownIfNeeded();
                    return;
                }

                // Sample values locally in case write completes inline
                // to allow block to be Reset and still complete this function
                var lockedEndBlock = _lockedEnd.Block;
                var lockedEndIndex = _lockedEnd.Index;

                if (Self._writeReqPool.Count > 0)
                {
                    _writeReq = Self._writeReqPool.Dequeue();
                }
                else
                {
                    _writeReq = new UvWriteReq(Self._log);
                    _writeReq.Init(Self._thread.Loop);
                }

                _writeReq.Write(Self._socket, _lockedStart, _lockedEnd, _bufferCount, (req, status, error, state) =>
                {
                    var writeContext = (WriteContext)state;
                    writeContext.PoolWriteReq(writeContext._writeReq);
                    writeContext._writeReq = null;
                    writeContext.ScheduleReturnFullyWrittenBlocks();
                    writeContext.WriteStatus = status;
                    writeContext.WriteError = error;
                    writeContext.DoShutdownIfNeeded();
                }, this);

                // The range up to the locked end has been handed to the write; advance _head past it.
                Self._head = lockedEndBlock;
                Self._head.Start = lockedEndIndex;
            }

            /// <summary>
            /// Second step: initiate async shutdown if needed, otherwise go to next step
            /// </summary>
            public void DoShutdownIfNeeded()
            {
                if (SocketShutdownSend == false || Self._socket.IsClosed)
                {
                    DoDisconnectIfNeeded();
                    return;
                }

                var shutdownReq = new UvShutdownReq(Self._log);
                shutdownReq.Init(Self._thread.Loop);
                shutdownReq.Shutdown(Self._socket, (req, status, state) =>
                {
                    req.Dispose();
                    var writeContext = (WriteContext)state;
                    writeContext.ShutdownSendStatus = status;

                    writeContext.Self._log.ConnectionWroteFin(writeContext.Self._connectionId, status);

                    writeContext.DoDisconnectIfNeeded();
                }, this);
            }

            /// <summary>
            /// Third step: disconnect socket if needed, otherwise this work item is complete
            /// </summary>
            public void DoDisconnectIfNeeded()
            {
                if (SocketDisconnect == false || Self._socket.IsClosed)
                {
                    CompleteWithContextLock();
                    return;
                }

                Self._socket.Dispose();
                Self._connection.OnSocketClosed();
                Self.ReturnAllBlocks();
                Self._log.ConnectionStop(Self._connectionId);
                CompleteWithContextLock();
            }

            /// <summary>
            /// Reports this write's outcome to the owner on the current thread if the context lock is
            /// immediately available; otherwise defers to the thread pool rather than blocking on it.
            /// </summary>
            public void CompleteWithContextLock()
            {
                if (Monitor.TryEnter(Self._contextLock))
                {
                    try
                    {
                        Self.OnWriteCompleted(this);
                    }
                    finally
                    {
                        Monitor.Exit(Self._contextLock);
                    }
                }
                else
                {
                    ThreadPool.QueueUserWorkItem(_completeWrite, this);
                }
            }

            /// <summary>
            /// Thread-pool fallback for <see cref="CompleteWithContextLock"/>: waits for the context lock,
            /// reports the outcome, and logs (rather than propagates) any exception.
            /// </summary>
            public void CompleteOnThreadPool()
            {
                lock (Self._contextLock)
                {
                    try
                    {
                        Self.OnWriteCompleted(this);
                    }
                    catch (Exception ex)
                    {
                        Self._log.LogError(0, ex, "SocketOutput.OnWriteCompleted");
                    }
                }
            }

            // Keeps the write request for reuse unless the shared pool is full, in which case it is disposed.
            private void PoolWriteReq(UvWriteReq writeReq)
            {
                if (Self._writeReqPool.Count < MaxPooledWriteReqs)
                {
                    Self._writeReqPool.Enqueue(writeReq);
                }
                else
                {
                    writeReq.Dispose();
                }
            }

            // Detaches every block before the locked end block from the chain and returns that prefix
            // to its pool on the thread pool; the end block stays in the chain as the new head.
            // NOTE(review): blocks strictly between start and end are unpinned here and again in
            // ReturnWrittenBlocks, while the end block is only unpinned in the single-block case —
            // confirm this matches how UvWriteReq.Write pins blocks.
            private void ScheduleReturnFullyWrittenBlocks()
            {
                var block = _lockedStart.Block;
                var end = _lockedEnd.Block;
                if (block == end)
                {
                    end.Unpin();
                    return;
                }

                while (block.Next != end)
                {
                    block = block.Next;
                    block.Unpin();
                }
                block.Next = null;

                ThreadPool.QueueUserWorkItem(_returnWrittenBlocks, _lockedStart.Block);
            }

            // Unpins and returns every block in the detached chain starting at block.
            private static void ReturnWrittenBlocks(MemoryPoolBlock2 block)
            {
                while (block != null)
                {
                    var returnBlock = block;
                    block = block.Next;

                    returnBlock.Unpin();
                    returnBlock.Pool.Return(returnBlock);
                }
            }

            // Snapshots the range [head.Start, tail.End) to be written by this context and computes
            // ByteCount/_bufferCount for it.
            private void LockWrite()
            {
                var head = Self._head;
                MemoryPoolBlock2 tail;
                int tailEnd = 0;

                // _tail and its End are updated together under _returnLock by producers
                // (ProducingCompleteNoPreComplete); read them under the same lock so the pair is consistent.
                lock (Self._returnLock)
                {
                    tail = Self._tail;
                    if (tail != null)
                    {
                        tailEnd = tail.End;
                    }
                }

                if (head == null || tail == null)
                {
                    // ReturnAllBlocks has already been called. Nothing to do here.
                    // Write will no-op since ByteCount will remain 0.
                    return;
                }

                _lockedStart = new MemoryPoolIterator2(head, head.Start);
                _lockedEnd = new MemoryPoolIterator2(tail, tailEnd);

                BytesBetween(_lockedStart, _lockedEnd, out ByteCount, out _bufferCount);
            }

            // Clears all per-write state so the context can be pooled and reused.
            public void Reset()
            {
                _lockedStart = default(MemoryPoolIterator2);
                _lockedEnd = default(MemoryPoolIterator2);
                _bufferCount = 0;
                ByteCount = 0;

                SocketShutdownSend = false;
                SocketDisconnect = false;

                WriteStatus = 0;
                WriteError = null;

                ShutdownSendStatus = 0;
            }
        }

        // A write task that could not be completed immediately, waiting for buffer space.
        private struct WaitingTask
        {
            public bool IsSync;
            public int BytesToWrite;
            public CancellationToken CancellationToken;
            public IDisposable CancellationRegistration;
            public TaskCompletionSource<object> CompletionSource;
        }
    }
}