diff --git a/src/Kestrel.Transport.Libuv/Internal/LibuvAwaitable.cs b/src/Kestrel.Transport.Libuv/Internal/LibuvAwaitable.cs index 8ee11ff42e..49ca4f8e6b 100644 --- a/src/Kestrel.Transport.Libuv/Internal/LibuvAwaitable.cs +++ b/src/Kestrel.Transport.Libuv/Internal/LibuvAwaitable.cs @@ -36,6 +36,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal public UvWriteResult GetResult() { + Debug.Assert(_callback == _callbackCompleted); + var exception = _exception; var status = _status; diff --git a/src/Kestrel.Transport.Sockets/Internal/BufferExtensions.cs b/src/Kestrel.Transport.Sockets/Internal/BufferExtensions.cs new file mode 100644 index 0000000000..cadf97f0d0 --- /dev/null +++ b/src/Kestrel.Transport.Sockets/Internal/BufferExtensions.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public static class BufferExtensions + { + public static ArraySegment GetArray(this Buffer buffer) + { + ArraySegment result; + if (!buffer.TryGetArray(out result)) + { + throw new InvalidOperationException("Buffer backed by array was expected"); + } + return result; + } + } +} \ No newline at end of file diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketAwaitable.cs b/src/Kestrel.Transport.Sockets/Internal/SocketAwaitable.cs new file mode 100644 index 0000000000..fc68b3c08f --- /dev/null +++ b/src/Kestrel.Transport.Sockets/Internal/SocketAwaitable.cs @@ -0,0 +1,58 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Net.Sockets; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public class SocketAwaitable : ICriticalNotifyCompletion + { + private static readonly Action _callbackCompleted = () => { }; + + private Action _callback; + private int _bytesTransfered; + private SocketError _error; + + public SocketAwaitable GetAwaiter() => this; + public bool IsCompleted => _callback == _callbackCompleted; + + public int GetResult() + { + Debug.Assert(_callback == _callbackCompleted); + + _callback = null; + + if (_error != SocketError.Success) + { + throw new SocketException((int)_error); + } + + return _bytesTransfered; + } + + public void OnCompleted(Action continuation) + { + if (_callback == _callbackCompleted || + Interlocked.CompareExchange(ref _callback, continuation, null) == _callbackCompleted) + { + continuation(); + } + } + + public void UnsafeOnCompleted(Action continuation) + { + OnCompleted(continuation); + } + + public void Complete(int bytesTransferred, SocketError socketError) + { + _error = socketError; + _bytesTransfered = bytesTransferred; + Interlocked.Exchange(ref _callback, _callbackCompleted)?.Invoke(); + } + } +} \ No newline at end of file diff --git a/src/Kestrel.Transport.Sockets/SocketConnection.cs b/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs similarity index 80% rename from src/Kestrel.Transport.Sockets/SocketConnection.cs rename to src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs index 6bca6c7731..ac9ef9a8e2 100644 --- a/src/Kestrel.Transport.Sockets/SocketConnection.cs +++ b/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.IO.Pipelines; @@ -11,10 +10,9 @@ using System.Net.Sockets; using System.Threading.Tasks; using Microsoft.AspNetCore.Protocols; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; -using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal { internal sealed class SocketConnection : TransportConnection { @@ -22,8 +20,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets private readonly Socket _socket; private readonly ISocketsTrace _trace; + private readonly SocketReceiver _receiver; + private readonly SocketSender _sender; - private IList> _sendBufferList; private volatile bool _aborted; internal SocketConnection(Socket socket, PipeFactory pipeFactory, ISocketsTrace trace) @@ -44,6 +43,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets RemoteAddress = remoteEndPoint.Address; RemotePort = remoteEndPoint.Port; + + _receiver = new SocketReceiver(_socket); + _sender = new SocketSender(_socket); } public override PipeFactory PipeFactory { get; } @@ -95,7 +97,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets try { - var bytesReceived = await _socket.ReceiveAsync(GetArraySegment(buffer.Buffer), SocketFlags.None); + var bytesReceived = await _receiver.ReceiveAsync(buffer.Buffer); if (bytesReceived == 0) { @@ -176,25 +178,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets } } - private void SetupSendBuffers(ReadableBuffer buffer) - { - Debug.Assert(!buffer.IsEmpty); - Debug.Assert(!buffer.IsSingleSpan); - - if (_sendBufferList == null) - { - _sendBufferList = new List>(); - } - - // We should always clear the list after the send - Debug.Assert(_sendBufferList.Count == 0); - - foreach (var b in buffer) - { - _sendBufferList.Add(GetArraySegment(b)); - } - } - private async Task DoSend() { Exception error = null; @@ -216,23 +199,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets { if (!buffer.IsEmpty) { - if (buffer.IsSingleSpan) - { - await _socket.SendAsync(GetArraySegment(buffer.First), SocketFlags.None); - } - else - { - SetupSendBuffers(buffer); - - try - { - await _socket.SendAsync(_sendBufferList, SocketFlags.None); - } - finally - { - _sendBufferList.Clear(); - } - } + await _sender.SendAsync(buffer); } else if (result.IsCompleted) { @@ -273,16 +240,5 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets _socket.Shutdown(SocketShutdown.Both); } } - - private static ArraySegment GetArraySegment(Buffer buffer) - { - if (!buffer.TryGetArray(out var segment)) - { - throw new InvalidOperationException(); - } - - return segment; - } - } } diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketReceiver.cs b/src/Kestrel.Transport.Sockets/Internal/SocketReceiver.cs new file mode 100644 index 0000000000..bb9f85a56b --- /dev/null +++ b/src/Kestrel.Transport.Sockets/Internal/SocketReceiver.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net.Sockets; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public class SocketReceiver + { + private readonly Socket _socket; + private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); + private readonly SocketAwaitable _awaitable = new SocketAwaitable(); + + public SocketReceiver(Socket socket) + { + _socket = socket; + _eventArgs.UserToken = _awaitable; + _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); + } + + public SocketAwaitable ReceiveAsync(Buffer buffer) + { + var segment = buffer.GetArray(); + + _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); + + if (!_socket.ReceiveAsync(_eventArgs)) + { + _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); + } + + return _awaitable; + } + } +} \ No newline at end of file diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketSender.cs b/src/Kestrel.Transport.Sockets/Internal/SocketSender.cs new file mode 100644 index 0000000000..04f680f42d --- /dev/null +++ b/src/Kestrel.Transport.Sockets/Internal/SocketSender.cs @@ -0,0 +1,98 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO.Pipelines; +using System.Net.Sockets; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public class SocketSender + { + private readonly Socket _socket; + private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); + private readonly SocketAwaitable _awaitable = new SocketAwaitable(); + + private List> _bufferList; + + public SocketSender(Socket socket) + { + _socket = socket; + _eventArgs.UserToken = _awaitable; + _eventArgs.Completed += (_, e) => SendCompleted(e, (SocketAwaitable)e.UserToken); + } + + public SocketAwaitable SendAsync(ReadableBuffer buffers) + { + if (buffers.IsSingleSpan) + { + return SendAsync(buffers.First); + } + + _eventArgs.BufferList = GetBufferList(buffers); + + if (!_socket.SendAsync(_eventArgs)) + { + SendCompleted(_eventArgs, _awaitable); + } + + return _awaitable; + } + + private SocketAwaitable SendAsync(Buffer buffer) + { + var segment = buffer.GetArray(); + + _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); + + if (!_socket.SendAsync(_eventArgs)) + { + SendCompleted(_eventArgs, _awaitable); + } + + return _awaitable; + } + + private List> GetBufferList(ReadableBuffer buffer) + { + Debug.Assert(!buffer.IsEmpty); + Debug.Assert(!buffer.IsSingleSpan); + + if (_bufferList == null) + { + _bufferList = new List>(); + } + + // We should always clear the list after the send + Debug.Assert(_bufferList.Count == 0); + + foreach (var b in buffer) + { + _bufferList.Add(b.GetArray()); + } + + return _bufferList; + } + + private static void SendCompleted(SocketAsyncEventArgs e, SocketAwaitable awaitable) + { + // Clear buffer(s) to prevent the SetBuffer buffer and BufferList from both being + // set for the next write operation. This is unnecessary for reads since they never + // set BufferList. + + if (e.BufferList != null) + { + e.BufferList.Clear(); + e.BufferList = null; + } + else + { + e.SetBuffer(null, 0, 0); + } + + awaitable.Complete(e.BytesTransferred, e.SocketError); + } + } +} \ No newline at end of file