diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs index c16ef2e652..f2c0233744 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs @@ -35,13 +35,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http private FilteredStreamAdapter _filteredStreamAdapter; private Task _readInputTask; - private readonly SocketInput _rawSocketInput; - private readonly SocketOutput _rawSocketOutput; - - private readonly object _stateLock = new object(); - private ConnectionState _connectionState; private TaskCompletionSource _socketClosedTcs; - private BufferSizeControl _bufferSizeControl; public Connection(ListenerContext context, UvStreamHandle socket) : base(context) @@ -57,8 +51,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _bufferSizeControl = new BufferSizeControl(ServerOptions.MaxRequestBufferSize.Value, this, Thread); } - _rawSocketInput = new SocketInput(Memory, ThreadPool, _bufferSizeControl); - _rawSocketOutput = new SocketOutput(Thread, _socket, Memory, this, ConnectionId, Log, ThreadPool, WriteReqPool); + SocketInput = new SocketInput(Memory, ThreadPool, _bufferSizeControl); + SocketOutput = new SocketOutput(Thread, _socket, Memory, this, ConnectionId, Log, ThreadPool, WriteReqPool); + + var tcpHandle = _socket as UvTcpHandle; + if (tcpHandle != null) + { + RemoteEndPoint = tcpHandle.GetPeerIPEndPoint(); + LocalEndPoint = tcpHandle.GetSockIPEndPoint(); + } + + _frame = FrameFactory(this); } // Internal for testing @@ -73,35 +76,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // Start socket prior to applying the ConnectionFilter _socket.ReadStart(_allocCallback, _readCallback, this); - var tcpHandle = _socket as UvTcpHandle; - if (tcpHandle != null) - { - RemoteEndPoint = tcpHandle.GetPeerIPEndPoint(); - LocalEndPoint = tcpHandle.GetSockIPEndPoint(); - } - // Don't initialize _frame until SocketInput and SocketOutput are set to their final values. if (ServerOptions.ConnectionFilter == null) { - lock (_stateLock) - { - if (_connectionState != ConnectionState.CreatingFrame) - { - throw new InvalidOperationException("Invalid connection state: " + _connectionState); - } + _frame.SocketInput = SocketInput; + _frame.SocketOutput = SocketOutput; - _connectionState = ConnectionState.Open; + _frame.Start(); - SocketInput = _rawSocketInput; - SocketOutput = _rawSocketOutput; - - _frame = CreateFrame(); - _frame.Start(); - } } else { - _libuvStream = new LibuvStream(_rawSocketInput, _rawSocketOutput); + _libuvStream = new LibuvStream(SocketInput, SocketOutput); _filterContext = new ConnectionFilterContext { @@ -141,24 +127,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public Task StopAsync() { - lock (_stateLock) + if (_socketClosedTcs == null) { - switch (_connectionState) - { - case ConnectionState.SocketClosed: - return TaskUtilities.CompletedTask; - case ConnectionState.CreatingFrame: - _connectionState = ConnectionState.ToDisconnect; - break; - case ConnectionState.Open: - _frame.Stop(); - SocketInput.CompleteAwaiting(); - break; - } - _socketClosedTcs = new TaskCompletionSource(); - return _socketClosedTcs.Task; + + _frame.Stop(); + _frame.SocketInput.CompleteAwaiting(); } + + return _socketClosedTcs.Task; } public virtual void Abort() @@ -168,18 +145,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http ThreadPool.Run(() => { var connection = this; - - lock (connection._stateLock) - { - if (connection._connectionState == ConnectionState.CreatingFrame) - { - connection._connectionState = ConnectionState.ToDisconnect; - } - else - { - connection._frame?.Abort(); - } - } + connection._frame.Abort(); }); } @@ -189,65 +155,45 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http if (_filteredStreamAdapter != null) { _filteredStreamAdapter.Abort(); - _rawSocketInput.IncomingFin(); + SocketInput.IncomingFin(); _readInputTask.ContinueWith((task, state) => { ((Connection)state)._filterContext.Connection.Dispose(); ((Connection)state)._filteredStreamAdapter.Dispose(); - ((Connection)state)._rawSocketInput.Dispose(); + ((Connection)state).SocketInput.Dispose(); }, this); } else { - _rawSocketInput.Dispose(); + SocketInput.Dispose(); } - lock (_stateLock) - { - _connectionState = ConnectionState.SocketClosed; - - if (_socketClosedTcs != null) - { - // This is always waited on synchronously, so it's safe to - // call on the libuv thread. - _socketClosedTcs.TrySetResult(null); - } - } + _socketClosedTcs?.TrySetResult(null); } private void ApplyConnectionFilter() { - lock (_stateLock) + if (_filterContext.Connection != _libuvStream) { - if (_connectionState == ConnectionState.CreatingFrame) - { - _connectionState = ConnectionState.Open; + _filteredStreamAdapter = new FilteredStreamAdapter(ConnectionId, _filterContext.Connection, Memory, Log, ThreadPool, _bufferSizeControl); - if (_filterContext.Connection != _libuvStream) - { - _filteredStreamAdapter = new FilteredStreamAdapter(ConnectionId, _filterContext.Connection, Memory, Log, ThreadPool, _bufferSizeControl); + _frame.SocketInput = _filteredStreamAdapter.SocketInput; + _frame.SocketOutput = _filteredStreamAdapter.SocketOutput; - SocketInput = _filteredStreamAdapter.SocketInput; - SocketOutput = _filteredStreamAdapter.SocketOutput; - - _readInputTask = _filteredStreamAdapter.ReadInputAsync(); - } - else - { - SocketInput = _rawSocketInput; - SocketOutput = _rawSocketOutput; - } - - PrepareRequest = _filterContext.PrepareRequest; - - _frame = CreateFrame(); - _frame.Start(); - } - else - { - ConnectionControl.End(ProduceEndType.SocketDisconnect); - } + _readInputTask = _filteredStreamAdapter.ReadInputAsync(); } + else + { + _frame.SocketInput = SocketInput; + _frame.SocketOutput = SocketOutput; + } + + _frame.PrepareRequest = _filterContext.PrepareRequest; + + // Reset needs to be called here so prepare request gets applied + _frame.Reset(); + + _frame.Start(); } private static Libuv.uv_buf_t AllocCallback(UvStreamHandle handle, int suggestedSize, object state) @@ -257,7 +203,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http private Libuv.uv_buf_t OnAlloc(UvStreamHandle handle, int suggestedSize) { - var result = _rawSocketInput.IncomingStart(); + var result = SocketInput.IncomingStart(); return handle.Libuv.buf_init( result.DataArrayPtr + result.End, @@ -277,7 +223,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // there is no data to be read right now. // See the note at http://docs.libuv.org/en/v1.x/stream.html#c.uv_read_cb. // We need to clean up whatever was allocated by OnAlloc. - _rawSocketInput.IncomingDeferred(); + SocketInput.IncomingDeferred(); return; } @@ -307,7 +253,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http Log.ConnectionError(ConnectionId, error); } - _rawSocketInput.IncomingComplete(readCount, error); + SocketInput.IncomingComplete(readCount, error); if (errorDone) { @@ -338,7 +284,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // ReadStart() can throw a UvException in some cases (e.g. socket is no longer connected). // This should be treated the same as OnRead() seeing a "normalDone" condition. Log.ConnectionReadFin(ConnectionId); - _rawSocketInput.IncomingComplete(0, null); + SocketInput.IncomingComplete(0, null); } } @@ -347,28 +293,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http switch (endType) { case ProduceEndType.ConnectionKeepAlive: - if (_connectionState != ConnectionState.Open) - { - return; - } - Log.ConnectionKeepAlive(ConnectionId); break; case ProduceEndType.SocketShutdown: case ProduceEndType.SocketDisconnect: - lock (_stateLock) - { - if (_connectionState == ConnectionState.Disconnecting || - _connectionState == ConnectionState.SocketClosed) - { - return; - } - _connectionState = ConnectionState.Disconnecting; - - Log.ConnectionDisconnect(ConnectionId); - _rawSocketOutput.End(endType); - break; - } + Log.ConnectionDisconnect(ConnectionId); + ((SocketOutput)SocketOutput).End(endType); + break; } } @@ -398,14 +329,5 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // string ctor overload that takes char* return new string(charBuffer, 0, 13); } - - private enum ConnectionState - { - CreatingFrame, - ToDisconnect, - Open, - Disconnecting, - SocketClosed - } } }