// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.Server.Kestrel.Infrastructure;
using Microsoft.AspNet.Server.Kestrel.Networking;

namespace Microsoft.AspNet.Server.Kestrel.Http
{
    /// <summary>
    /// Buffers outgoing bytes in a chain of pooled memory blocks and writes them to a libuv
    /// stream from the owning <see cref="KestrelThread"/>'s event loop. A write may optionally
    /// be followed by a send-side shutdown and/or a disconnect of the socket.
    /// </summary>
    public class SocketOutput : ISocketOutput
    {
        // Maximum number of write operations that may be scheduled on the event loop at once.
        private const int _maxPendingWrites = 3;

        // Buffered-but-unwritten byte budget. Below this, WriteAsync may return a completed
        // task before the data has actually been written to the socket.
        private const int _maxBytesPreCompleted = 65536;

        private readonly KestrelThread _thread;
        private readonly UvStreamHandle _socket;
        private readonly Connection _connection;
        private readonly long _connectionId;
        private readonly IKestrelTrace _log;

        // This locks all access to _tail, _isProducing and _returnFromOnProducingComplete.
        // _head does not require a lock, since it is only used in the ctor and uv thread.
        private readonly object _returnLock = new object();

        private MemoryPoolBlock2 _head;
        private MemoryPoolBlock2 _tail;

        private bool _isProducing;
        private MemoryPoolBlock2 _returnFromOnProducingComplete;

        // This locks access to all of the below fields.
        private readonly object _contextLock = new object();

        // The number of write operations that have been scheduled so far
        // but have not completed.
        private int _writesPending = 0;
        private int _numBytesPreCompleted = 0;
        private Exception _lastWriteError;
        private WriteContext _nextWriteContext;

        // Write tasks that could not be completed immediately. Each task's AsyncState holds
        // the byte count of its write.
        private readonly Queue<TaskCompletionSource<object>> _tasksPending;

        /// <summary>
        /// Creates the output for one connection and leases the first buffer block.
        /// </summary>
        /// <param name="thread">Thread whose event loop performs the socket writes.</param>
        /// <param name="socket">Stream handle the buffered bytes are written to.</param>
        /// <param name="memory">Pool from which the initial block is leased.</param>
        /// <param name="connection">Connection that is aborted when a socket write fails.</param>
        /// <param name="connectionId">Identifier passed to the trace calls.</param>
        /// <param name="log">Trace sink for write, FIN and stop events.</param>
        public SocketOutput(
            KestrelThread thread,
            UvStreamHandle socket,
            MemoryPool2 memory,
            Connection connection,
            long connectionId,
            IKestrelTrace log)
        {
            _thread = thread;
            _socket = socket;
            _connection = connection;
            _connectionId = connectionId;
            _log = log;
            _tasksPending = new Queue<TaskCompletionSource<object>>();

            _head = memory.Lease();
            _tail = _head;
        }

        /// <summary>
        /// Copies <paramref name="buffer"/> into the block chain and, for immediate writes,
        /// schedules a flush on the event loop.
        /// </summary>
        /// <param name="buffer">Bytes to append. They are copied before this method returns.</param>
        /// <param name="immediate">
        /// When false, no flush is scheduled and a completed task is returned. A later
        /// immediate call is expected to flush the data.
        /// </param>
        /// <param name="socketShutdownSend">Request a send-side shutdown after the next write.</param>
        /// <param name="socketDisconnect">Request that the socket be disposed after the next write.</param>
        /// <returns>
        /// A completed task when the write fits within the pre-completion budget; otherwise a
        /// task that completes once enough earlier data has been written.
        /// </returns>
        /// <exception cref="IOException">The blocks have already been returned (socket closed).</exception>
        public Task WriteAsync(
            ArraySegment<byte> buffer,
            bool immediate = true,
            bool socketShutdownSend = false,
            bool socketDisconnect = false)
        {
            var tail = ProducingStart();
            tail = tail.CopyFrom(buffer);
            // We do our own accounting below
            ProducingComplete(tail, count: 0);

            TaskCompletionSource<object> tcs = null;

            lock (_contextLock)
            {
                if (_nextWriteContext == null)
                {
                    _nextWriteContext = new WriteContext(this);
                }

                if (socketShutdownSend)
                {
                    _nextWriteContext.SocketShutdownSend = true;
                }
                if (socketDisconnect)
                {
                    _nextWriteContext.SocketDisconnect = true;
                }

                if (!immediate)
                {
                    // immediate==false calls always return complete tasks, because there is guaranteed
                    // to be a subsequent immediate==true call which will go down one of the previous code-paths
                    _numBytesPreCompleted += buffer.Count;
                }
                else if (_lastWriteError == null &&
                        _tasksPending.Count == 0 &&
                        _numBytesPreCompleted + buffer.Count <= _maxBytesPreCompleted)
                {
                    // Complete the write task immediately if all previous write tasks have been completed,
                    // the buffers haven't grown too large, and the last write to the socket succeeded.
                    _numBytesPreCompleted += buffer.Count;
                }
                else
                {
                    // Immediate write that is not eligible for instant completion above.
                    // The byte count is stored as AsyncState for OnWriteCompleted's accounting.
                    tcs = new TaskCompletionSource<object>(buffer.Count);
                    _tasksPending.Enqueue(tcs);
                }

                if (_writesPending < _maxPendingWrites && immediate)
                {
                    ScheduleWrite();
                    _writesPending++;
                }
            }

            // Return TaskCompletionSource's Task if set, otherwise completed Task
            return tcs?.Task ?? TaskUtilities.CompletedTask;
        }

        /// <summary>
        /// Issues an empty immediate write that requests a send-side shutdown or a disconnect.
        /// Other <see cref="ProduceEndType"/> values are ignored here.
        /// </summary>
        public void End(ProduceEndType endType)
        {
            switch (endType)
            {
                case ProduceEndType.SocketShutdownSend:
                    WriteAsync(default(ArraySegment<byte>),
                        immediate: true,
                        socketShutdownSend: true,
                        socketDisconnect: false);
                    break;
                case ProduceEndType.SocketDisconnect:
                    WriteAsync(default(ArraySegment<byte>),
                        immediate: true,
                        socketShutdownSend: false,
                        socketDisconnect: true);
                    break;
            }
        }

        /// <summary>
        /// Begins a producing section and returns an iterator positioned at the end of the
        /// buffered data. Must be paired with <see cref="ProducingComplete"/>.
        /// </summary>
        /// <exception cref="IOException">The blocks have already been returned (socket closed).</exception>
        public MemoryPoolIterator2 ProducingStart()
        {
            lock (_returnLock)
            {
                Debug.Assert(!_isProducing);

                if (_tail == null)
                {
                    // Checked before marking the producer active. Throwing after setting
                    // _isProducing would leave it stuck at true with no matching ProducingComplete.
                    throw new IOException("The socket has been closed.");
                }

                _isProducing = true;
                return new MemoryPoolIterator2(_tail, _tail.End);
            }
        }

        /// <summary>
        /// Ends the producing section started by <see cref="ProducingStart"/>. This publishes
        /// <paramref name="end"/> as the new tail, or returns the blocks to their pool when
        /// <see cref="ReturnAllBlocks"/> ran in the meantime.
        /// </summary>
        /// <param name="end">Position just past the last byte produced.</param>
        /// <param name="count">Bytes to add to the pre-completed byte accounting (0 to skip).</param>
        public void ProducingComplete(MemoryPoolIterator2 end, int count)
        {
            lock (_returnLock)
            {
                Debug.Assert(_isProducing);
                _isProducing = false;

                if (_returnFromOnProducingComplete == null)
                {
                    _tail = end.Block;
                    _tail.End = end.Index;

                    if (count != 0)
                    {
                        lock (_contextLock)
                        {
                            _numBytesPreCompleted += count;
                        }
                    }
                }
                else
                {
                    // ReturnAllBlocks deferred returning the tail chain while we were producing.
                    var block = _returnFromOnProducingComplete;
                    while (block != null)
                    {
                        var returnBlock = block;
                        block = block.Next;

                        returnBlock.Pool?.Return(returnBlock);
                    }

                    _returnFromOnProducingComplete = null;
                }
            }
        }

        private void ScheduleWrite()
        {
            _thread.Post(_this => _this.WriteAllPending(), this);
        }

        // This is called on the libuv event loop
        private void WriteAllPending()
        {
            WriteContext writingContext;

            lock (_contextLock)
            {
                if (_nextWriteContext != null)
                {
                    writingContext = _nextWriteContext;
                    _nextWriteContext = null;
                }
                else
                {
                    _writesPending--;
                    return;
                }
            }

            try
            {
                writingContext.DoWriteIfNeeded();
            }
            catch
            {
                lock (_contextLock)
                {
                    // Lock instead of using Interlocked.Decrement so _writesPending
                    // doesn't change in the middle of executing other synchronized code.
                    _writesPending--;
                }

                throw;
            }
        }

        // This is called on the libuv event loop
        private void OnWriteCompleted(int bytesWritten, int status, Exception error)
        {
            _log.ConnectionWriteCallback(_connectionId, status);

            if (error != null)
            {
                lock (_contextLock)
                {
                    // _lastWriteError is guarded by _contextLock; WriteAsync reads it under this lock.
                    _lastWriteError = new IOException(error.Message, error);
                }

                // Abort the connection for any failed write. Done outside _contextLock so
                // no external code runs while the lock is held.
                _connection.Abort();
            }

            lock (_contextLock)
            {
                if (_nextWriteContext != null)
                {
                    ScheduleWrite();
                }
                else
                {
                    _writesPending--;
                }

                // _numBytesPreCompleted can temporarily go negative in the event there are
                // completed writes that we haven't triggered callbacks for yet.
                _numBytesPreCompleted -= bytesWritten;

                // bytesLeftToBuffer can be greater than _maxBytesPreCompleted
                // This allows large writes to complete once they've actually finished.
                var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted;
                while (_tasksPending.Count > 0 &&
                       (int)(_tasksPending.Peek().Task.AsyncState) <= bytesLeftToBuffer)
                {
                    var tcs = _tasksPending.Dequeue();
                    var bytesToWrite = (int)tcs.Task.AsyncState;

                    _numBytesPreCompleted += bytesToWrite;
                    bytesLeftToBuffer -= bytesToWrite;

                    if (_lastWriteError == null)
                    {
                        ThreadPool.QueueUserWorkItem(
                            (o) => ((TaskCompletionSource<object>)o).SetResult(null),
                            tcs);
                    }
                    else
                    {
                        // Capture the error now so the thread-pool callback does not re-read the
                        // field later, after it may have been overwritten.
                        var writeError = _lastWriteError;
                        ThreadPool.QueueUserWorkItem(
                            (o) => ((TaskCompletionSource<object>)o).SetException(writeError),
                            tcs);
                    }
                }
            }
        }

        // This is called on the libuv event loop
        private void ReturnAllBlocks()
        {
            lock (_returnLock)
            {
                var block = _head;
                while (block != _tail)
                {
                    var returnBlock = block;
                    block = block.Next;

                    returnBlock.Unpin();
                    returnBlock.Pool?.Return(returnBlock);
                }

                _tail.Unpin();

                if (_isProducing)
                {
                    // A producer still holds an iterator into _tail; ProducingComplete returns
                    // the chain once it finishes.
                    _returnFromOnProducingComplete = _tail;
                }
                else
                {
                    _tail.Pool?.Return(_tail);
                }

                _head = null;
                _tail = null;
            }
        }

        /// <summary>
        /// Synchronous write. Blocks the calling thread when the write task did not complete
        /// synchronously, and rethrows its exception if it faulted.
        /// </summary>
        void ISocketOutput.Write(ArraySegment<byte> buffer, bool immediate)
        {
            var task = WriteAsync(buffer, immediate);

            if (task.Status != TaskStatus.RanToCompletion)
            {
                task.GetAwaiter().GetResult();
            }
        }

        // NOTE(review): cancellationToken is not observed by this implementation.
        Task ISocketOutput.WriteAsync(ArraySegment<byte> buffer, bool immediate, CancellationToken cancellationToken)
        {
            return WriteAsync(buffer, immediate);
        }

        private class WriteContext
        {
            private MemoryPoolIterator2 _lockedStart;
            private MemoryPoolIterator2 _lockedEnd;
            private int _bufferCount;
            private int _byteCount;

            public SocketOutput Self;

            public bool SocketShutdownSend;
            public bool SocketDisconnect;

            public int WriteStatus;
            public Exception WriteError;

            public int ShutdownSendStatus;

            public WriteContext(SocketOutput self)
            {
                Self = self;
            }

            /// <summary>
            /// First step: initiate async write if needed, otherwise go to next step
            /// </summary>
            public void DoWriteIfNeeded()
            {
                LockWrite();

                if (_byteCount == 0 || Self._socket.IsClosed)
                {
                    DoShutdownIfNeeded();
                    return;
                }

                var writeReq = new UvWriteReq(Self._log);
                writeReq.Init(Self._thread.Loop);

                writeReq.Write(Self._socket, _lockedStart, _lockedEnd, _bufferCount, (_writeReq, status, error, state) =>
                {
                    _writeReq.Dispose();
                    var _this = (WriteContext)state;
                    _this.ReturnFullyWrittenBlocks();
                    _this.WriteStatus = status;
                    _this.WriteError = error;
                    _this.DoShutdownIfNeeded();
                }, this);

                // Blocks before _lockedEnd.Block are returned by the write callback; the
                // partially consumed end block becomes the new head.
                Self._head = _lockedEnd.Block;
                Self._head.Start = _lockedEnd.Index;
            }

            /// <summary>
            /// Second step: initiate async shutdown if needed, otherwise go to next step
            /// </summary>
            public void DoShutdownIfNeeded()
            {
                if (SocketShutdownSend == false || Self._socket.IsClosed)
                {
                    DoDisconnectIfNeeded();
                    return;
                }

                var shutdownReq = new UvShutdownReq(Self._log);
                shutdownReq.Init(Self._thread.Loop);
                shutdownReq.Shutdown(Self._socket, (_shutdownReq, status, state) =>
                {
                    _shutdownReq.Dispose();
                    var _this = (WriteContext)state;
                    _this.ShutdownSendStatus = status;

                    // Read everything through the state argument so the callback captures nothing.
                    _this.Self._log.ConnectionWroteFin(_this.Self._connectionId, status);

                    _this.DoDisconnectIfNeeded();
                }, this);
            }

            /// <summary>
            /// Third step: disconnect socket if needed, otherwise this work item is complete
            /// </summary>
            public void DoDisconnectIfNeeded()
            {
                if (SocketDisconnect == false || Self._socket.IsClosed)
                {
                    Complete();
                    return;
                }

                Self._socket.Dispose();
                Self.ReturnAllBlocks();
                Self._log.ConnectionStop(Self._connectionId);
                Complete();
            }

            public void Complete()
            {
                Self.OnWriteCompleted(_byteCount, WriteStatus, WriteError);
            }

            private void ReturnFullyWrittenBlocks()
            {
                var block = _lockedStart.Block;
                while (block != _lockedEnd.Block)
                {
                    var returnBlock = block;
                    block = block.Next;

                    returnBlock.Unpin();
                    returnBlock.Pool?.Return(returnBlock);
                }
            }

            // Snapshots the range [head.Start, tail.End) and computes its byte and block counts.
            private void LockWrite()
            {
                var head = Self._head;

                MemoryPoolBlock2 tail;
                int tailEnd;

                // _tail is guarded by _returnLock, and ProducingComplete updates _tail and
                // _tail.End as two separate writes. Read both under the lock so they are consistent.
                lock (Self._returnLock)
                {
                    tail = Self._tail;
                    tailEnd = tail?.End ?? 0;
                }

                if (head == null || tail == null)
                {
                    // ReturnAllBlocks has already been called. Nothing to do here.
                    // Write will no-op since _byteCount will remain 0.
                    return;
                }

                _lockedStart = new MemoryPoolIterator2(head, head.Start);
                _lockedEnd = new MemoryPoolIterator2(tail, tailEnd);

                if (_lockedStart.Block == _lockedEnd.Block)
                {
                    _byteCount = _lockedEnd.Index - _lockedStart.Index;
                    _bufferCount = 1;
                    return;
                }

                _byteCount = _lockedStart.Block.Data.Offset + _lockedStart.Block.Data.Count - _lockedStart.Index;
                _bufferCount = 1;

                for (var block = _lockedStart.Block.Next; block != _lockedEnd.Block; block = block.Next)
                {
                    _byteCount += block.Data.Count;
                    _bufferCount++;
                }

                _byteCount += _lockedEnd.Index - _lockedEnd.Block.Data.Offset;
                _bufferCount++;
            }
        }
    }
}