aspnetcore/src/Microsoft.AspNet.Server.Kes.../Http/SocketOutput.cs

346 lines
12 KiB
C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.Server.Kestrel.Infrastructure;
using Microsoft.AspNet.Server.Kestrel.Networking;
namespace Microsoft.AspNet.Server.Kestrel.Http
{
public class SocketOutput : ISocketOutput
{
private const int _maxPendingWrites = 3;
private const int _maxBytesPreCompleted = 65536;
private readonly KestrelThread _thread;
private readonly UvStreamHandle _socket;
private readonly long _connectionId;
private readonly IKestrelTrace _log;
// This locks access to to all of the below fields
private readonly object _lockObj = new object();
// The number of write operations that have been scheduled so far
// but have not completed.
private int _writesPending = 0;
private int _numBytesPreCompleted = 0;
private Exception _lastWriteError;
private WriteContext _nextWriteContext;
private readonly Queue<TaskCompletionSource<object>> _tasksPending;
public SocketOutput(KestrelThread thread, UvStreamHandle socket, long connectionId, IKestrelTrace log)
{
_thread = thread;
_socket = socket;
_connectionId = connectionId;
_log = log;
_tasksPending = new Queue<TaskCompletionSource<object>>();
}
public Task WriteAsync(
ArraySegment<byte> buffer,
bool immediate = true,
bool socketShutdownSend = false,
bool socketDisconnect = false)
{
//TODO: need buffering that works
if (buffer.Array != null)
{
var copy = new byte[buffer.Count];
Array.Copy(buffer.Array, buffer.Offset, copy, 0, buffer.Count);
buffer = new ArraySegment<byte>(copy);
_log.ConnectionWrite(_connectionId, buffer.Count);
}
TaskCompletionSource<object> tcs = null;
lock (_lockObj)
{
if (_nextWriteContext == null)
{
_nextWriteContext = new WriteContext(this);
}
if (buffer.Array != null)
{
_nextWriteContext.Buffers.Enqueue(buffer);
}
if (socketShutdownSend)
{
_nextWriteContext.SocketShutdownSend = true;
}
if (socketDisconnect)
{
_nextWriteContext.SocketDisconnect = true;
}
// Complete the write task immediately if all previous write tasks have been completed,
// the buffers haven't grown too large, and the last write to the socket succeeded.
if (_lastWriteError == null &&
_tasksPending.Count == 0 &&
_numBytesPreCompleted + buffer.Count <= _maxBytesPreCompleted)
{
_numBytesPreCompleted += buffer.Count;
}
else if (immediate)
{
// immediate write, which is not eligable for instant completion above
tcs = new TaskCompletionSource<object>(buffer.Count);
_tasksPending.Enqueue(tcs);
}
else
{
// immediate==false calls always return complete tasks, because there is guaranteed
// to be a subsequent immediate==true call which will go down one of the previous code-paths
}
if (_writesPending < _maxPendingWrites && immediate)
{
ScheduleWrite();
_writesPending++;
}
}
// Return TaskCompletionSource's Task if set, otherwise completed Task
return tcs?.Task ?? TaskUtilities.CompletedTask;
}
public void End(ProduceEndType endType)
{
switch (endType)
{
case ProduceEndType.SocketShutdownSend:
WriteAsync(default(ArraySegment<byte>),
immediate: true,
socketShutdownSend: true,
socketDisconnect: false);
break;
case ProduceEndType.SocketDisconnect:
WriteAsync(default(ArraySegment<byte>),
immediate: true,
socketShutdownSend: false,
socketDisconnect: true);
break;
}
}
private void ScheduleWrite()
{
_thread.Post(_this => _this.WriteAllPending(), this);
}
// This is called on the libuv event loop
private void WriteAllPending()
{
WriteContext writingContext;
lock (_lockObj)
{
if (_nextWriteContext != null)
{
writingContext = _nextWriteContext;
_nextWriteContext = null;
}
else
{
_writesPending--;
return;
}
}
try
{
writingContext.DoWriteIfNeeded();
}
catch
{
lock (_lockObj)
{
// Lock instead of using Interlocked.Decrement so _writesSending
// doesn't change in the middle of executing other synchronized code.
_writesPending--;
}
throw;
}
}
// This is called on the libuv event loop
private void OnWriteCompleted(Queue<ArraySegment<byte>> writtenBuffers, int status, Exception error)
{
_log.ConnectionWriteCallback(_connectionId, status);
lock (_lockObj)
{
_lastWriteError = error;
if (_nextWriteContext != null)
{
ScheduleWrite();
}
else
{
_writesPending--;
}
foreach (var writeBuffer in writtenBuffers)
{
// _numBytesPreCompleted can temporarily go negative in the event there are
// completed writes that we haven't triggered callbacks for yet.
_numBytesPreCompleted -= writeBuffer.Count;
}
// bytesLeftToBuffer can be greater than _maxBytesPreCompleted
// This allows large writes to complete once they've actually finished.
var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted;
while (_tasksPending.Count > 0 &&
(int)(_tasksPending.Peek().Task.AsyncState) <= bytesLeftToBuffer)
{
var tcs = _tasksPending.Dequeue();
_numBytesPreCompleted += (int)(tcs.Task.AsyncState);
if (error == null)
{
ThreadPool.QueueUserWorkItem(
(o) => ((TaskCompletionSource<object>)o).SetResult(null),
tcs);
}
else
{
// error is closure captured
ThreadPool.QueueUserWorkItem(
(o) => ((TaskCompletionSource<object>)o).SetException(error),
tcs);
}
}
// Now that the while loop has completed the following invariants should hold true:
Debug.Assert(_numBytesPreCompleted >= 0);
Debug.Assert(_numBytesPreCompleted <= _maxBytesPreCompleted);
}
}
void ISocketOutput.Write(ArraySegment<byte> buffer, bool immediate)
{
var task = WriteAsync(buffer, immediate);
if (task.Status == TaskStatus.RanToCompletion)
{
return;
}
else
{
task.GetAwaiter().GetResult();
}
}
Task ISocketOutput.WriteAsync(ArraySegment<byte> buffer, bool immediate, CancellationToken cancellationToken)
{
return WriteAsync(buffer, immediate);
}
private class WriteContext
{
public SocketOutput Self;
public Queue<ArraySegment<byte>> Buffers;
public bool SocketShutdownSend;
public bool SocketDisconnect;
public int WriteStatus;
public Exception WriteError;
public int ShutdownSendStatus;
public WriteContext(SocketOutput self)
{
Self = self;
Buffers = new Queue<ArraySegment<byte>>();
}
/// <summary>
/// First step: initiate async write if needed, otherwise go to next step
/// </summary>
public void DoWriteIfNeeded()
{
if (Buffers.Count == 0 || Self._socket.IsClosed)
{
DoShutdownIfNeeded();
return;
}
var buffers = new ArraySegment<byte>[Buffers.Count];
var i = 0;
foreach (var buffer in Buffers)
{
buffers[i++] = buffer;
}
var writeReq = new UvWriteReq(Self._log);
writeReq.Init(Self._thread.Loop);
writeReq.Write(Self._socket, new ArraySegment<ArraySegment<byte>>(buffers), (_writeReq, status, error, state) =>
{
_writeReq.Dispose();
var _this = (WriteContext)state;
_this.WriteStatus = status;
_this.WriteError = error;
DoShutdownIfNeeded();
}, this);
}
/// <summary>
/// Second step: initiate async shutdown if needed, otherwise go to next step
/// </summary>
public void DoShutdownIfNeeded()
{
if (SocketShutdownSend == false || Self._socket.IsClosed)
{
DoDisconnectIfNeeded();
return;
}
var shutdownReq = new UvShutdownReq(Self._log);
shutdownReq.Init(Self._thread.Loop);
shutdownReq.Shutdown(Self._socket, (_shutdownReq, status, state) =>
{
_shutdownReq.Dispose();
var _this = (WriteContext)state;
_this.ShutdownSendStatus = status;
Self._log.ConnectionWroteFin(Self._connectionId, status);
DoDisconnectIfNeeded();
}, this);
}
/// <summary>
/// Third step: disconnect socket if needed, otherwise this work item is complete
/// </summary>
public void DoDisconnectIfNeeded()
{
if (SocketDisconnect == false || Self._socket.IsClosed)
{
Complete();
return;
}
Self._socket.Dispose();
Self._log.ConnectionStop(Self._connectionId);
Complete();
}
public void Complete()
{
Self.OnWriteCompleted(Buffers, WriteStatus, WriteError);
}
}
}
}