Extract a layer of IIS Async IO handling (#818)

This commit is contained in:
Pavel Krymets 2018-05-11 13:02:12 -07:00 committed by GitHub
parent fbf727e073
commit 22a865b832
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1684 additions and 870 deletions

View File

@ -9,8 +9,9 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal static class NativeMethods
{
private const int HR_NOT_FOUND = unchecked((int)0x80070490);
private const int HR_OK = 0;
internal const int HR_OK = 0;
internal const int ERROR_NOT_FOUND = unchecked((int)0x80070490);
internal const int ERROR_OPERATION_ABORTED = unchecked((int)0x800703E3);
private const string KERNEL32 = "kernel32.dll";
@ -238,7 +239,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
var hr = http_cancel_io(pInProcessHandler);
// Async operation finished
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa363792(v=vs.85).aspx
if (hr == HR_NOT_FOUND)
if (hr == ERROR_NOT_FOUND)
{
return false;
}

View File

@ -1,119 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
// Primarily copied from https://github.com/aspnet/KestrelHttpServer/blob/dev/src/Kestrel.Transport.Libuv/Internal/LibuvAwaitable.cs
internal class IISAwaitable : ICriticalNotifyCompletion
{
private readonly static Action _callbackCompleted = () => { };
private Action _callback;
private Exception _exception;
private int _cbBytes;
private int _hr;
public static readonly NativeMethods.PFN_WEBSOCKET_ASYNC_COMPLETION ReadCallback = (IntPtr pHttpContext, IntPtr pCompletionInfo, IntPtr pvCompletionContext) =>
{
var context = (IISHttpContext)GCHandle.FromIntPtr(pvCompletionContext).Target;
NativeMethods.HttpGetCompletionInfo(pCompletionInfo, out int cbBytes, out int hr);
context.CompleteReadWebSockets(hr, cbBytes);
return NativeMethods.REQUEST_NOTIFICATION_STATUS.RQ_NOTIFICATION_PENDING;
};
public static readonly NativeMethods.PFN_WEBSOCKET_ASYNC_COMPLETION WriteCallback = (IntPtr pHttpContext, IntPtr pCompletionInfo, IntPtr pvCompletionContext) =>
{
var context = (IISHttpContext)GCHandle.FromIntPtr(pvCompletionContext).Target;
NativeMethods.HttpGetCompletionInfo(pCompletionInfo, out int cbBytes, out int hr);
context.CompleteWriteWebSockets(hr, cbBytes);
return NativeMethods.REQUEST_NOTIFICATION_STATUS.RQ_NOTIFICATION_PENDING;
};
public IISAwaitable GetAwaiter() => this;
public bool IsCompleted => _callback == _callbackCompleted;
public bool HasContinuation => _callback != null && !IsCompleted;
public int GetResult()
{
var exception = _exception;
var cbBytes = _cbBytes;
var hr = _hr;
// Reset the awaitable state
_exception = null;
_cbBytes = 0;
_callback = null;
_hr = 0;
if (exception != null)
{
// If the exception was an aborted read operation,
// return -1 to notify NativeReadAsync that the write was cancelled.
// E_OPERATIONABORTED == 0x800703e3 == -2147023901
// We also don't throw the exception here as this is expected behavior
// and can negatively impact perf if we catch an exception for each
// cann
if (hr != IISServerConstants.HResultCancelIO)
{
throw exception;
}
else
{
cbBytes = -1;
}
}
return cbBytes;
}
public void OnCompleted(Action continuation)
{
// There should never be a race between IsCompleted and OnCompleted since both operations
// should always be on the libuv thread
if (_callback == _callbackCompleted ||
Interlocked.CompareExchange(ref _callback, continuation, null) == _callbackCompleted)
{
// Just run it inline
Task.Run(continuation);
}
}
public void UnsafeOnCompleted(Action continuation)
{
OnCompleted(continuation);
}
public void Complete(int hr, int cbBytes)
{
_hr = hr;
_exception = Marshal.GetExceptionForHR(hr);
_cbBytes = cbBytes;
var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted);
continuation?.Invoke();
}
public Action GetCompletion(int hr, int cbBytes)
{
_hr = hr;
_exception = Marshal.GetExceptionForHR(hr);
_cbBytes = cbBytes;
return Interlocked.Exchange(ref _callback, _callbackCompleted);
}
}
}

View File

@ -4,6 +4,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Security.Claims;
@ -250,16 +251,23 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
{
throw new InvalidOperationException("CoreStrings.UpgradeCannotBeCalledMultipleTimes");
}
_wasUpgraded = true;
StatusCode = StatusCodes.Status101SwitchingProtocols;
ReasonPhrase = ReasonPhrases.GetReasonPhrase(StatusCodes.Status101SwitchingProtocols);
_readWebSocketsOperation = new IISAwaitable();
_writeWebSocketsOperation = new IISAwaitable();
NativeMethods.HttpEnableWebsockets(_pInProcessHandler);
// If we started reading before calling Upgrade Task should be completed at this point
// because read would return 0 syncronosly
Debug.Assert(_readBodyTask == null || _readBodyTask.IsCompleted);
// Reset reading status to allow restarting with new IO
_hasRequestReadingStarted = false;
// Upgrade async will cause the stream processing to go into duplex mode
await UpgradeAsync();
AsyncIO = new WebSocketsAsyncIOEngine(_pInProcessHandler);
await InitializeResponse(flushHeaders: true);
return new DuplexStream(RequestBody, ResponseBody);
}

View File

@ -3,17 +3,13 @@
using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.HttpSys.Internal;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class IISHttpContext
{
private const int HttpDataChunkStackLimit = 128; // 16 bytes per HTTP_DATA_CHUNK
/// <summary>
/// Reads data from the Input pipe to the user.
/// </summary>
@ -22,11 +18,14 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
/// <returns></returns>
internal async Task<int> ReadAsync(Memory<byte> memory, CancellationToken cancellationToken)
{
StartProcessingRequestAndResponseBody();
if (!_hasRequestReadingStarted)
{
InitializeRequestIO();
}
while (true)
{
var result = await Input.Reader.ReadAsync();
var result = await _bodyInputPipe.Reader.ReadAsync();
var readableBuffer = result.Buffer;
try
{
@ -44,7 +43,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
}
finally
{
Input.Reader.AdvanceTo(readableBuffer.End, readableBuffer.End);
_bodyInputPipe.Reader.AdvanceTo(readableBuffer.End, readableBuffer.End);
}
}
}
@ -57,18 +56,13 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
/// <returns></returns>
internal Task WriteAsync(ReadOnlyMemory<byte> memory, CancellationToken cancellationToken = default(CancellationToken))
{
// Want to keep exceptions consistent,
if (!_hasResponseStarted)
async Task WriteFirstAsync()
{
return WriteAsyncAwaited(memory, cancellationToken);
await InitializeResponse(flushHeaders: false);
await _bodyOutput.WriteAsync(memory, cancellationToken);
}
lock (_stateSync)
{
DisableReads();
return Output.WriteAsync(memory, cancellationToken);
}
return !HasResponseStarted ? WriteFirstAsync() : _bodyOutput.WriteAsync(memory, cancellationToken);
}
/// <summary>
@ -78,397 +72,99 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
/// <returns></returns>
internal Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken))
{
if (!_hasResponseStarted)
async Task FlushFirstAsync()
{
return FlushAsyncAwaited(cancellationToken);
}
lock (_stateSync)
{
DisableReads();
return Output.FlushAsync(cancellationToken);
await InitializeResponse(flushHeaders: true);
await _bodyOutput.FlushAsync(cancellationToken);
}
return !HasResponseStarted ? FlushFirstAsync() : _bodyOutput.FlushAsync(cancellationToken);
}
private void StartProcessingRequestAndResponseBody()
{
if (_processBodiesTask == null)
{
lock (_createReadWriteBodySync)
{
if (_processBodiesTask == null)
{
_processBodiesTask = ConsumeAsync();
}
}
}
}
private async Task FlushAsyncAwaited(CancellationToken cancellationToken)
{
await InitializeResponseAwaited();
Task flushTask;
lock (_stateSync)
{
DisableReads();
// Want to guarantee that data has been written to the pipe before releasing the lock.
flushTask = Output.FlushAsync(cancellationToken: cancellationToken);
}
await flushTask;
}
private async Task WriteAsyncAwaited(ReadOnlyMemory<byte> data, CancellationToken cancellationToken)
{
// WriteAsyncAwaited is only called for the first write to the body.
// Ensure headers are flushed if Write(Chunked)Async isn't called.
await InitializeResponseAwaited();
Task writeTask;
lock (_stateSync)
{
DisableReads();
// Want to guarantee that data has been written to the pipe before releasing the lock.
writeTask = Output.WriteAsync(data, cancellationToken: cancellationToken);
}
await writeTask;
}
// ConsumeAsync is called when either the first read or first write is done.
// There are two modes for reading and writing to the request/response bodies without upgrade.
// 1. Await all reads and try to read from the Output pipe
// 2. Done reading and await all writes.
// If the request is upgraded, we will start bidirectional streams for the input and output.
private async Task ConsumeAsync()
{
await ReadAndWriteLoopAsync();
// The ReadAndWriteLoop can return due to being upgraded. Check if _wasUpgraded is true to determine
// whether we go to a bidirectional stream or only write.
if (_wasUpgraded)
{
await StartBidirectionalStream();
}
}
private unsafe IISAwaitable ReadFromIISAsync(int length)
{
Action completion = null;
lock (_stateSync)
{
// We don't want to read if there is data available in the output pipe
// Therefore, we mark the current operation as cancelled to allow for the read
// to be requeued.
if (Output.Reader.TryRead(out var result))
{
// If the buffer is empty, it is considered a write of zero.
// we still want to cancel and allow the write to occur.
completion = _operation.GetCompletion(hr: IISServerConstants.HResultCancelIO, cbBytes: 0);
Output.Reader.AdvanceTo(result.Buffer.Start);
}
else
{
var hr = NativeMethods.HttpReadRequestBytes(
_pInProcessHandler,
(byte*)_inputHandle.Pointer,
length,
out var dwReceivedBytes,
out bool fCompletionExpected);
// if we complete the read synchronously, there is no need to set the reading flag
// as there is no cancelable operation.
if (!fCompletionExpected)
{
completion = _operation.GetCompletion(hr, dwReceivedBytes);
}
else
{
_reading = true;
}
}
}
// Invoke the completion outside of the lock if the reead finished synchronously.
completion?.Invoke();
return _operation;
}
private unsafe IISAwaitable WriteToIISAsync(ReadOnlySequence<byte> buffer)
{
var fCompletionExpected = false;
var hr = 0;
var nChunks = 0;
// Count the number of chunks in memory.
if (buffer.IsSingleSegment)
{
nChunks = 1;
}
else
{
foreach (var memory in buffer)
{
nChunks++;
}
}
if (nChunks == 1)
{
// If there is only a single chunk, use fixed to get a pointer to the buffer
var pDataChunks = stackalloc HttpApiTypes.HTTP_DATA_CHUNK[1];
fixed (byte* pBuffer = &MemoryMarshal.GetReference(buffer.First.Span))
{
ref var chunk = ref pDataChunks[0];
chunk.DataChunkType = HttpApiTypes.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
chunk.fromMemory.pBuffer = (IntPtr)pBuffer;
chunk.fromMemory.BufferLength = (uint)buffer.Length;
hr = NativeMethods.HttpWriteResponseBytes(_pInProcessHandler, pDataChunks, nChunks, out fCompletionExpected);
}
}
else if (nChunks < HttpDataChunkStackLimit)
{
// To avoid stackoverflows, we will only stackalloc if the write size is less than the StackChunkLimit
// The stack size is IIS is by default 128/256 KB, so we are generous with this threshold.
var pDataChunks = stackalloc HttpApiTypes.HTTP_DATA_CHUNK[nChunks];
hr = WriteSequenceToIIS(nChunks, buffer, pDataChunks, out fCompletionExpected);
}
else
{
// Otherwise allocate the chunks on the heap.
var chunks = new HttpApiTypes.HTTP_DATA_CHUNK[nChunks];
fixed (HttpApiTypes.HTTP_DATA_CHUNK* pDataChunks = chunks)
{
hr = WriteSequenceToIIS(nChunks, buffer, pDataChunks, out fCompletionExpected);
}
}
if (!fCompletionExpected)
{
_operation.Complete(hr, 0);
}
return _operation;
}
private unsafe int WriteSequenceToIIS(int nChunks, ReadOnlySequence<byte> buffer, HttpApiTypes.HTTP_DATA_CHUNK* pDataChunks, out bool fCompletionExpected)
{
var currentChunk = 0;
var hr = 0;
// REVIEW: We don't really need this list since the memory is already pinned with the default pool,
// but shouldn't assume the pool implementation right now. Unfortunately, this causes a heap allocation...
var handles = new MemoryHandle[nChunks];
foreach (var b in buffer)
{
ref var handle = ref handles[currentChunk];
ref var chunk = ref pDataChunks[currentChunk];
handle = b.Pin();
chunk.DataChunkType = HttpApiTypes.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
chunk.fromMemory.BufferLength = (uint)b.Length;
chunk.fromMemory.pBuffer = (IntPtr)handle.Pointer;
currentChunk++;
}
hr = NativeMethods.HttpWriteResponseBytes(_pInProcessHandler, pDataChunks, nChunks, out fCompletionExpected);
// Free the handles
foreach (var handle in handles)
{
handle.Dispose();
}
return hr;
}
private unsafe IISAwaitable FlushToIISAsync()
{
// Calls flush
var hr = 0;
hr = NativeMethods.HttpFlushResponseBytes(_pInProcessHandler, out var fCompletionExpected);
if (!fCompletionExpected)
{
_operation.Complete(hr, 0);
}
return _operation;
}
/// <summary>
/// Main function for control flow with IIS.
/// Uses two Pipes (Input and Output) between application calls to Read/Write/FlushAsync
/// Control Flow:
/// Try to see if there is data written by the application code (using TryRead)
/// and write it to IIS.
/// Check if the connection has been upgraded and call StartBidirectionalStreams
/// if it has.
/// Await reading from IIS, which will be cancelled if application code calls Write/FlushAsync.
/// </summary>
/// <returns>The Reading and Writing task.</returns>
private async Task ReadAndWriteLoopAsync()
private async Task ReadBody()
{
try
{
while (true)
{
// First we check if there is anything to write from the Output pipe
// If there is, we call WriteToIISAsync
// Check if Output pipe has anything to write to IIS.
if (Output.Reader.TryRead(out var readResult))
{
var buffer = readResult.Buffer;
var memory = _bodyInputPipe.Writer.GetMemory();
try
{
if (!buffer.IsEmpty)
{
// Write to IIS buffers
// Guaranteed to write the entire buffer to IIS
await WriteToIISAsync(buffer);
}
else if (readResult.IsCompleted)
{
break;
}
else
{
// Flush of zero bytes
await FlushToIISAsync();
}
}
finally
{
// Always Advance the data pointer to the end of the buffer.
Output.Reader.AdvanceTo(buffer.End);
}
var read = await AsyncIO.ReadAsync(memory);
// End of body
if (read == 0)
{
break;
}
// Check if there was an upgrade. If there is, we will replace the request and response bodies with
// two seperate loops. These will still be using the same Input and Output pipes here.
if (_upgradeTcs?.TrySetResult(null) == true)
// Read was not canceled because of incoming write or IO stopping
if (read != -1)
{
// _wasUpgraded will be set at this point, exit the loop and we will check if we upgraded or not
// when going to next read/write type.
return;
_bodyInputPipe.Writer.Advance(read);
}
// Now we handle the read.
var memory = Input.Writer.GetMemory();
_inputHandle = memory.Pin();
var result = await _bodyInputPipe.Writer.FlushAsync();
try
{
// Lock around invoking ReadFromIISAsync as we don't want to call CancelIo
// when calling read
var read = await ReadFromIISAsync(memory.Length);
// read value of 0 == done reading
// read value of -1 == read cancelled, still allowed to read but we
// need a write to occur first.
if (read == 0)
{
break;
}
else if (read == -1)
{
continue;
}
Input.Writer.Advance(read);
}
finally
{
// Always commit any changes to the Input pipe
_inputHandle.Dispose();
}
// Flush the read data for the Input Pipe writer
var flushResult = await Input.Writer.FlushAsync();
// If the pipe was closed, we are done reading,
if (flushResult.IsCompleted || flushResult.IsCanceled)
if (result.IsCompleted || result.IsCanceled)
{
break;
}
}
// Complete the input writer as we are done reading the request body.
Input.Writer.Complete();
}
catch (Exception ex)
{
Input.Writer.Complete(ex);
_bodyInputPipe.Writer.Complete(ex);
}
finally
{
_bodyInputPipe.Writer.Complete();
}
await WriteLoopAsync();
}
/// <summary>
/// Secondary function for control flow with IIS. This is only called once we are done
/// reading the request body. We now await reading from the Output pipe.
/// </summary>
/// <returns></returns>
private async Task WriteLoopAsync()
private async Task WriteBody(bool flush = false)
{
try
{
while (true)
{
// Reading is done, so we will await all reads from the output pipe
var readResult = await Output.Reader.ReadAsync();
var result = await _bodyOutput.Reader.ReadAsync();
// Get data from pipe
var buffer = readResult.Buffer;
var buffer = result.Buffer;
try
{
if (!buffer.IsEmpty)
{
// Write to IIS buffers
// Guaranteed to write the entire buffer to IIS
await WriteToIISAsync(buffer);
await AsyncIO.WriteAsync(buffer);
}
else if (readResult.IsCompleted)
// if request is done no need to flush, http.sys would do it for us
if (result.IsCompleted)
{
break;
}
else
flush = flush | result.IsCanceled;
if (flush)
{
// Flush of zero bytes will
await FlushToIISAsync();
await AsyncIO.FlushAsync();
flush = false;
}
}
finally
{
// Always Advance the data pointer to the end of the buffer.
Output.Reader.AdvanceTo(buffer.End);
_bodyOutput.Reader.AdvanceTo(buffer.End);
}
}
// Close the output pipe as we are done reading from it.
Output.Reader.Complete();
}
catch (Exception ex)
{
Output.Reader.Complete(ex);
_bodyOutput.Reader.Complete(ex);
}
}
// Always called from within a lock
private void DisableReads()
{
// To avoid concurrent reading and writing, if we have a pending read,
// we must cancel it.
// _reading will always be false if we upgrade to websockets, so we don't need to check wasUpgrade
// Also, we set _reading to false after cancelling to detect redundant calls
if (_reading)
finally
{
_reading = false;
// Calls IHttpContext->CancelIo(), which will cause the OnAsyncCompletion handler to fire.
NativeMethods.HttpTryCancelIO(_pInProcessHandler);
_bodyOutput.Reader.Complete();
}
}
}

View File

@ -1,225 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Threading.Tasks;
using Microsoft.AspNetCore.HttpSys.Internal;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
/// <summary>
/// Represents the websocket portion of the <see cref="IISHttpContext"/>
/// </summary>
internal partial class IISHttpContext
{
private bool _wasUpgraded; // Used for detecting repeated upgrades in IISHttpContext
private IISAwaitable _readWebSocketsOperation;
private IISAwaitable _writeWebSocketsOperation;
private TaskCompletionSource<object> _upgradeTcs;
private Task StartBidirectionalStream()
{
// IIS allows for websocket support and duplex channels only on Win8 and above
// This allows us to have two tasks for reading the request and writing the response
var readWebsocketTask = ReadWebSockets();
var writeWebsocketTask = WriteWebSockets();
return Task.WhenAll(readWebsocketTask, writeWebsocketTask);
}
public async Task UpgradeAsync()
{
if (_upgradeTcs == null)
{
_upgradeTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
// Flush any contents of the OutputPipe before upgrading to websockets.
await FlushAsync();
await _upgradeTcs.Task;
}
}
private unsafe IISAwaitable ReadWebSocketsFromIISAsync(int length)
{
var hr = 0;
int dwReceivedBytes;
bool fCompletionExpected;
// For websocket calls, we can directly provide a callback function to be called once the websocket operation completes.
hr = NativeMethods.HttpWebsocketsReadBytes(
_pInProcessHandler,
(byte*)_inputHandle.Pointer,
length,
IISAwaitable.ReadCallback,
(IntPtr)_thisHandle,
out dwReceivedBytes,
out fCompletionExpected);
if (!fCompletionExpected)
{
CompleteReadWebSockets(hr, dwReceivedBytes);
}
return _readWebSocketsOperation;
}
private unsafe IISAwaitable WriteWebSocketsFromIISAsync(ReadOnlySequence<byte> buffer)
{
var fCompletionExpected = false;
var hr = 0;
var nChunks = 0;
if (buffer.IsSingleSegment)
{
nChunks = 1;
}
else
{
foreach (var memory in buffer)
{
nChunks++;
}
}
if (buffer.IsSingleSegment)
{
var pDataChunks = stackalloc HttpApiTypes.HTTP_DATA_CHUNK[1];
fixed (byte* pBuffer = &MemoryMarshal.GetReference(buffer.First.Span))
{
ref var chunk = ref pDataChunks[0];
chunk.DataChunkType = HttpApiTypes.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
chunk.fromMemory.pBuffer = (IntPtr)pBuffer;
chunk.fromMemory.BufferLength = (uint)buffer.Length;
hr = NativeMethods.HttpWebsocketsWriteBytes(_pInProcessHandler, pDataChunks, nChunks, IISAwaitable.WriteCallback, (IntPtr)_thisHandle, out fCompletionExpected);
}
}
else
{
// REVIEW: Do we need to guard against this getting too big? It seems unlikely that we'd have more than say 10 chunks in real life
var pDataChunks = stackalloc HttpApiTypes.HTTP_DATA_CHUNK[nChunks];
var currentChunk = 0;
// REVIEW: We don't really need this list since the memory is already pinned with the default pool,
// but shouldn't assume the pool implementation right now. Unfortunately, this causes a heap allocation...
var handles = new MemoryHandle[nChunks];
foreach (var b in buffer)
{
ref var handle = ref handles[currentChunk];
ref var chunk = ref pDataChunks[currentChunk];
handle = b.Pin();
chunk.DataChunkType = HttpApiTypes.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
chunk.fromMemory.BufferLength = (uint)b.Length;
chunk.fromMemory.pBuffer = (IntPtr)handle.Pointer;
currentChunk++;
}
hr = NativeMethods.HttpWebsocketsWriteBytes(_pInProcessHandler, pDataChunks, nChunks, IISAwaitable.WriteCallback, (IntPtr)_thisHandle, out fCompletionExpected);
foreach (var handle in handles)
{
handle.Dispose();
}
}
if (!fCompletionExpected)
{
CompleteWriteWebSockets(hr, 0);
}
return _writeWebSocketsOperation;
}
internal void CompleteWriteWebSockets(int hr, int cbBytes)
{
_writeWebSocketsOperation.Complete(hr, cbBytes);
}
internal void CompleteReadWebSockets(int hr, int cbBytes)
{
_readWebSocketsOperation.Complete(hr, cbBytes);
}
private async Task ReadWebSockets()
{
try
{
while (true)
{
var memory = Input.Writer.GetMemory();
_inputHandle = memory.Pin();
try
{
int read = 0;
read = await ReadWebSocketsFromIISAsync(memory.Length);
if (read == 0)
{
break;
}
Input.Writer.Advance(read);
}
finally
{
_inputHandle.Dispose();
}
var result = await Input.Writer.FlushAsync();
if (result.IsCompleted || result.IsCanceled)
{
break;
}
}
Input.Writer.Complete();
}
catch (Exception ex)
{
Input.Writer.Complete(ex);
}
}
private async Task WriteWebSockets()
{
try
{
while (true)
{
var result = await Output.Reader.ReadAsync();
var buffer = result.Buffer;
var consumed = buffer.End;
try
{
if (!buffer.IsEmpty)
{
await WriteWebSocketsFromIISAsync(buffer);
}
else if (result.IsCompleted)
{
break;
}
}
finally
{
Output.Reader.AdvanceTo(consumed);
}
}
Output.Reader.Complete();
}
catch (Exception ex)
{
Output.Reader.Complete(ex);
}
}
}
}

View File

@ -19,6 +19,7 @@ using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.HttpSys.Internal;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
@ -33,15 +34,13 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
private readonly IISOptions _options;
private bool _reading; // To know whether we are currently in a read operation.
private volatile bool _hasResponseStarted;
private volatile bool _hasRequestReadingStarted;
private int _statusCode;
private string _reasonPhrase;
private readonly object _onStartingSync = new object();
private readonly object _onCompletedSync = new object();
private readonly object _stateSync = new object();
protected readonly object _createReadWriteBodySync = new object();
protected Stack<KeyValuePair<Func<object, Task>, object>> _onStarting;
protected Stack<KeyValuePair<Func<object, Task>, object>> _onCompleted;
@ -51,16 +50,20 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
private readonly IISHttpServer _server;
private GCHandle _thisHandle;
private MemoryHandle _inputHandle;
private IISAwaitable _operation = new IISAwaitable();
protected Task _processBodiesTask;
protected Task _readBodyTask;
protected Task _writeBodyTask;
private bool _wasUpgraded;
protected int _requestAborted;
protected Pipe _bodyInputPipe;
protected OutputProducer _bodyOutput;
private const string NtlmString = "NTLM";
private const string NegotiateString = "Negotiate";
private const string BasicString = "Basic";
internal unsafe IISHttpContext(MemoryPool<byte> memoryPool, IntPtr pInProcessHandler, IISOptions options, IISHttpServer server)
: base((HttpApiTypes.HTTP_REQUEST*)NativeMethods.HttpGetRawRequest(pInProcessHandler))
{
@ -89,8 +92,8 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
internal WindowsPrincipal WindowsUser { get; set; }
public Stream RequestBody { get; set; }
public Stream ResponseBody { get; set; }
public Pipe Input { get; set; }
public OutputProducer Output { get; set; }
protected IAsyncIOEngine AsyncIO { get; set; }
public IHeaderDictionary RequestHeaders { get; set; }
public IHeaderDictionary ResponseHeaders { get; set; }
@ -153,7 +156,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
RequestBody = new IISHttpRequestBody(this);
ResponseBody = new IISHttpResponseBody(this);
Input = new Pipe(new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.ThreadPool, minimumSegmentSize: MinAllocBufferSize));
var pipe = new Pipe(
new PipeOptions(
_memoryPool,
@ -161,7 +164,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
pauseWriterThreshold: PauseWriterThreshold,
resumeWriterThreshold: ResumeWriterTheshold,
minimumSegmentSize: MinAllocBufferSize));
Output = new OutputProducer(pipe);
_bodyOutput = new OutputProducer(pipe);
}
public int StatusCode
@ -169,7 +172,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
get { return _statusCode; }
set
{
if (_hasResponseStarted)
if (HasResponseStarted)
{
ThrowResponseAlreadyStartedException(nameof(StatusCode));
}
@ -182,7 +185,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
get { return _reasonPhrase; }
set
{
if (_hasResponseStarted)
if (HasResponseStarted)
{
ThrowResponseAlreadyStartedException(nameof(ReasonPhrase));
}
@ -190,12 +193,9 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
}
}
internal IISHttpServer Server
{
get { return _server; }
}
internal IISHttpServer Server => _server;
private async Task InitializeResponseAwaited()
private async Task InitializeResponse(bool flushHeaders)
{
await FireOnStarting();
@ -204,7 +204,46 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
ThrowResponseAbortedException();
}
await ProduceStart(appCompleted: false);
await ProduceStart(flushHeaders);
}
private async Task ProduceStart(bool flushHeaders)
{
Debug.Assert(_hasResponseStarted == false);
_hasResponseStarted = true;
SetResponseHeaders();
EnsureIOInitialized();
if (flushHeaders)
{
await AsyncIO.FlushAsync();
}
_writeBodyTask = WriteBody(!flushHeaders);
}
private void InitializeRequestIO()
{
Debug.Assert(!_hasRequestReadingStarted);
_hasRequestReadingStarted = true;
EnsureIOInitialized();
_bodyInputPipe = new Pipe(new PipeOptions(_memoryPool, readerScheduler: PipeScheduler.ThreadPool, minimumSegmentSize: MinAllocBufferSize));
_readBodyTask = ReadBody();
}
private void EnsureIOInitialized()
{
// If at this point request was not upgraded just start a normal IO engine
if (AsyncIO == null)
{
AsyncIO = new AsyncIOEngine(_pInProcessHandler);
}
}
private void ThrowResponseAbortedException()
@ -212,38 +251,11 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
throw new ObjectDisposedException("Unhandled application exception", _applicationException);
}
private async Task ProduceStart(bool appCompleted)
{
if (_hasResponseStarted)
{
return;
}
_hasResponseStarted = true;
SendResponseHeaders(appCompleted);
// On first flush for websockets, we need to flush the headers such that
// IIS will know that an upgrade occured.
// If we don't have anything on the Output pipe, the TryRead in ReadAndWriteLoopAsync
// will fail and we will signal the upgradeTcs that we are upgrading. However, we still
// didn't flush. To fix this, we flush 0 bytes right after writing the headers.
Task flushTask;
lock (_stateSync)
{
DisableReads();
flushTask = Output.FlushAsync();
}
await flushTask;
StartProcessingRequestAndResponseBody();
}
protected Task ProduceEnd()
{
if (_applicationException != null)
{
if (_hasResponseStarted)
if (HasResponseStarted)
{
// We can no longer change the response, so we simply close the connection.
return Task.CompletedTask;
@ -258,7 +270,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
}
}
if (!_hasResponseStarted)
if (!HasResponseStarted)
{
return ProduceEndAwaited();
}
@ -275,27 +287,11 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
private async Task ProduceEndAwaited()
{
if (_hasResponseStarted)
{
return;
}
_hasResponseStarted = true;
SendResponseHeaders(appCompleted: true);
StartProcessingRequestAndResponseBody();
Task flushAsync;
lock (_stateSync)
{
DisableReads();
flushAsync = Output.FlushAsync();
}
await flushAsync;
await ProduceStart(flushHeaders: true);
await _bodyOutput.FlushAsync(default);
}
public unsafe void SendResponseHeaders(bool appCompleted)
public unsafe void SetResponseHeaders()
{
// Verifies we have sent the statuscode before writing a header
var reasonPhrase = string.IsNullOrEmpty(ReasonPhrase) ? ReasonPhrases.GetReasonPhrase(StatusCode) : ReasonPhrase;
@ -348,7 +344,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
{
lock (_onStartingSync)
{
if (_hasResponseStarted)
if (HasResponseStarted)
{
throw new InvalidOperationException("Response already started");
}
@ -439,8 +435,6 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
public void PostCompletion(NativeMethods.REQUEST_NOTIFICATION_STATUS requestNotificationStatus)
{
Debug.Assert(!_operation.HasContinuation, "Pending async operation!");
NativeMethods.HttpSetCompletionStatus(_pInProcessHandler, requestNotificationStatus);
NativeMethods.HttpPostCompletion(_pInProcessHandler, 0);
}
@ -450,18 +444,9 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
NativeMethods.HttpIndicateCompletion(_pInProcessHandler, notificationStatus);
}
internal void OnAsyncCompletion(int hr, int cbBytes)
internal void OnAsyncCompletion(int hr, int bytes)
{
// Must acquire the _stateSync here as anytime we call complete, we need to hold the lock
// to avoid races with cancellation.
Action continuation;
lock (_stateSync)
{
_reading = false;
continuation = _operation.GetCompletion(hr, cbBytes);
}
continuation?.Invoke();
AsyncIO.NotifyCompletion(hr, bytes);
}
private bool disposedValue = false; // To detect redundant calls

View File

@ -80,18 +80,24 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
}
finally
{
// The app is finished and there should be nobody writing to the response pipe
Output.Dispose();
// Complete response writer and request reader pipe sides
_bodyOutput.Dispose();
_bodyInputPipe?.Reader.Complete();
// The app is finished and there should be nobody reading from the request pipe
Input.Reader.Complete();
Task processBodiesTask;
lock (_createReadWriteBodySync)
// Allow writes to drain
if (_writeBodyTask != null)
{
processBodiesTask = _processBodiesTask;
await _writeBodyTask;
}
// Cancell all remaining IO, thre might be reads pending if not entire request body was sent
// by client
AsyncIO.Dispose();
if (_readBodyTask != null)
{
await _readBodyTask;
}
await processBodiesTask;
}
return success;
}

View File

@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
return ReadAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult();
}
public override unsafe Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
var memory = new Memory<byte>(buffer, offset, count);
@ -54,7 +54,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
throw new NotSupportedException();
}
public unsafe override void Write(byte[] buffer, int offset, int count)
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}

View File

@ -1,10 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal static class IISServerConstants
{
internal const int HResultCancelIO = -2147023901;
}
}

View File

@ -0,0 +1,43 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class AsyncIOEngine
{
internal class AsyncFlushOperation : AsyncIOOperation
{
private readonly AsyncIOEngine _engine;
private IntPtr _requestHandler;
public AsyncFlushOperation(AsyncIOEngine engine)
{
_engine = engine;
}
public void Initialize(IntPtr requestHandler)
{
_requestHandler = requestHandler;
}
protected override bool InvokeOperation(out int hr, out int bytes)
{
bytes = 0;
hr = NativeMethods.HttpFlushResponseBytes(_requestHandler, out var fCompletionExpected);
return !fCompletionExpected;
}
protected override void ResetOperation()
{
base.ResetOperation();
_requestHandler = default;
_engine.ReturnOperation(this);
}
}
}
}

View File

@ -0,0 +1,63 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class AsyncIOEngine
{
internal class AsyncReadOperation : AsyncIOOperation
{
private readonly AsyncIOEngine _engine;
private MemoryHandle _inputHandle;
private IntPtr _requestHandler;
private Memory<byte> _memory;
public AsyncReadOperation(AsyncIOEngine engine)
{
_engine = engine;
}
public void Initialize(IntPtr requestHandler, Memory<byte> memory)
{
_requestHandler = requestHandler;
_memory = memory;
}
protected override unsafe bool InvokeOperation(out int hr, out int bytes)
{
_inputHandle = _memory.Pin();
hr = NativeMethods.HttpReadRequestBytes(
_requestHandler,
(byte*)_inputHandle.Pointer,
_memory.Length,
out bytes,
out bool completionExpected);
return !completionExpected;
}
protected override void ResetOperation()
{
base.ResetOperation();
_memory = default;
_inputHandle.Dispose();
_inputHandle = default;
_requestHandler = default;
_engine.ReturnOperation(this);
}
public override void FreeOperationResources(int hr, int bytes)
{
_inputHandle.Dispose();
}
}
}
}

View File

@ -0,0 +1,34 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.HttpSys.Internal;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class AsyncIOEngine
{
private class AsyncWriteOperation : AsyncWriteOperationBase
{
private readonly AsyncIOEngine _engine;
public AsyncWriteOperation(AsyncIOEngine engine)
{
_engine = engine;
}
protected override unsafe int WriteChunks(IntPtr requestHandler, int chunkCount, HttpApiTypes.HTTP_DATA_CHUNK* dataChunks,
out bool completionExpected)
{
return NativeMethods.HttpWriteResponseBytes(requestHandler, dataChunks, chunkCount, out completionExpected);
}
protected override void ResetOperation()
{
base.ResetOperation();
_engine.ReturnOperation(this);
}
}
}
}

View File

@ -0,0 +1,172 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class AsyncIOEngine : IAsyncIOEngine
{
private readonly IntPtr _handler;
private bool _stopped;
private AsyncIOOperation _nextOperation;
private AsyncIOOperation _runningOperation;
private AsyncReadOperation _cachedAsyncReadOperation;
private AsyncWriteOperation _cachedAsyncWriteOperation;
private AsyncFlushOperation _cachedAsyncFlushOperation;
public AsyncIOEngine(IntPtr handler)
{
_handler = handler;
}
public ValueTask<int> ReadAsync(Memory<byte> memory)
{
var read = GetReadOperation();
read.Initialize(_handler, memory);
Run(read);
return new ValueTask<int>(read, 0);
}
public ValueTask<int> WriteAsync(ReadOnlySequence<byte> data)
{
var write = GetWriteOperation();
write.Initialize(_handler, data);
Run(write);
return new ValueTask<int>(write, 0);
}
private void Run(AsyncIOOperation ioOperation)
{
lock (this)
{
if (_stopped)
{
throw new IOException("IO stopped", NativeMethods.ERROR_OPERATION_ABORTED);
}
if (_runningOperation != null)
{
if (_nextOperation == null)
{
_nextOperation = ioOperation;
// If there is an active read cancel it
if (_runningOperation is AsyncReadOperation)
{
NativeMethods.HttpTryCancelIO(_handler);
}
}
else
{
throw new InvalidOperationException("Only one queued operation is allowed");
}
}
else
{
// we are just starting operation so there would be no
// continuation registered
var completed = ioOperation.Invoke() != null;
// operation went async
if (!completed)
{
_runningOperation = ioOperation;
}
}
}
}
public ValueTask FlushAsync()
{
var flush = GetFlushOperation();
flush.Initialize(_handler);
Run(flush);
return new ValueTask(flush, 0);
}
public void NotifyCompletion(int hr, int bytes)
{
AsyncIOOperation.AsyncContinuation continuation;
AsyncIOOperation.AsyncContinuation? nextContinuation = null;
lock (this)
{
Debug.Assert(_runningOperation != null);
continuation = _runningOperation.Complete(hr, bytes);
var next = _nextOperation;
_nextOperation = null;
_runningOperation = null;
if (next != null)
{
if (_stopped)
{
// Abort next operation if IO is stopped
nextContinuation = next.Complete(NativeMethods.ERROR_OPERATION_ABORTED, 0);
}
else
{
nextContinuation = next.Invoke();
// operation went async
if (nextContinuation == null)
{
_runningOperation = next;
}
}
}
}
continuation.Invoke();
nextContinuation?.Invoke();
}
public void Dispose()
{
lock (this)
{
_stopped = true;
NativeMethods.HttpTryCancelIO(_handler);
}
}
private AsyncReadOperation GetReadOperation() =>
Interlocked.Exchange(ref _cachedAsyncReadOperation, null) ??
new AsyncReadOperation(this);
private AsyncWriteOperation GetWriteOperation() =>
Interlocked.Exchange(ref _cachedAsyncWriteOperation, null) ??
new AsyncWriteOperation(this);
private AsyncFlushOperation GetFlushOperation() =>
Interlocked.Exchange(ref _cachedAsyncFlushOperation, null) ??
new AsyncFlushOperation(this);
private void ReturnOperation(AsyncReadOperation operation)
{
Volatile.Write(ref _cachedAsyncReadOperation, operation);
}
private void ReturnOperation(AsyncWriteOperation operation)
{
Volatile.Write(ref _cachedAsyncWriteOperation, operation);
}
private void ReturnOperation(AsyncFlushOperation operation)
{
Volatile.Write(ref _cachedAsyncFlushOperation, operation);
}
}
}

View File

@ -0,0 +1,158 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks.Sources;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal abstract class AsyncIOOperation: IValueTaskSource<int>, IValueTaskSource
{
private static readonly Action<object> CallbackCompleted = _ => { Debug.Assert(false, "Should not be invoked"); };
private Action<object> _continuation;
private object _state;
private int _result;
private Exception _exception;
public ValueTaskSourceStatus GetStatus(short token)
{
if (ReferenceEquals(Volatile.Read(ref _continuation), null))
{
return ValueTaskSourceStatus.Pending;
}
return _exception != null ? ValueTaskSourceStatus.Succeeded : ValueTaskSourceStatus.Faulted;
}
public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags)
{
if (_state != null)
{
ThrowMultipleContinuations();
}
_state = state;
var previousContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null);
if (previousContinuation != null)
{
if (!ReferenceEquals(previousContinuation, CallbackCompleted))
{
ThrowMultipleContinuations();
}
new AsyncContinuation(continuation, state).Invoke();
}
}
private static void ThrowMultipleContinuations()
{
throw new InvalidOperationException("Multiple awaiters are not allowed");
}
void IValueTaskSource.GetResult(short token)
{
var exception = _exception;
ResetOperation();
if (exception != null)
{
throw exception;
}
}
public int GetResult(short token)
{
var exception = _exception;
var result = _result;
ResetOperation();
if (exception != null)
{
throw exception;
}
return result;
}
public AsyncContinuation? Invoke()
{
if (InvokeOperation(out var hr, out var bytes))
{
return Complete(hr, bytes);
}
return null;
}
protected abstract bool InvokeOperation(out int hr, out int bytes);
public AsyncContinuation Complete(int hr, int bytes)
{
if (hr != NativeMethods.ERROR_OPERATION_ABORTED)
{
_result = bytes;
if (hr != NativeMethods.HR_OK)
{
_exception = new IOException("IO exception occurred", hr);
}
}
else
{
_result = -1;
_exception = null;
}
AsyncContinuation asyncContinuation = default;
var continuation = Interlocked.CompareExchange(ref _continuation, CallbackCompleted, null);
if (continuation != null)
{
asyncContinuation = new AsyncContinuation(continuation, _state);
}
FreeOperationResources(hr, bytes);
return asyncContinuation;
}
public virtual void FreeOperationResources(int hr, int bytes) { }
protected virtual void ResetOperation()
{
_exception = null;
_result = int.MinValue;
_state = null;
_continuation = null;
}
public readonly struct AsyncContinuation
{
public Action<object> Continuation { get; }
public object State { get; }
public AsyncContinuation(Action<object> continuation, object state)
{
Continuation = continuation;
State = state;
}
public void Invoke()
{
if (Continuation != null)
{
// TODO: use generic overload when code moved to be netcoreapp only
var continuation = Continuation;
var state = State;
ThreadPool.QueueUserWorkItem(_ => continuation(state));
}
}
}
}
}

View File

@ -0,0 +1,119 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using Microsoft.AspNetCore.HttpSys.Internal;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal abstract class AsyncWriteOperationBase : AsyncIOOperation
{
private const int HttpDataChunkStackLimit = 128; // 16 bytes per HTTP_DATA_CHUNK
private IntPtr _requestHandler;
private ReadOnlySequence<byte> _buffer;
private MemoryHandle[] _handles;
public void Initialize(IntPtr requestHandler, ReadOnlySequence<byte> buffer)
{
_requestHandler = requestHandler;
_buffer = buffer;
}
protected override unsafe bool InvokeOperation(out int hr, out int bytes)
{
if (_buffer.Length > int.MaxValue)
{
throw new InvalidOperationException($"Writes larger then {int.MaxValue} are not supported.");
}
bool completionExpected;
var chunkCount = GetChunkCount();
var bufferLength = (int)_buffer.Length;
if (chunkCount < HttpDataChunkStackLimit)
{
// To avoid stackoverflows, we will only stackalloc if the write size is less than the StackChunkLimit
// The stack size is IIS is by default 128/256 KB, so we are generous with this threshold.
var chunks = stackalloc HttpApiTypes.HTTP_DATA_CHUNK[chunkCount];
hr = WriteSequence(chunkCount, _buffer, chunks, out completionExpected);
}
else
{
// Otherwise allocate the chunks on the heap.
var chunks = new HttpApiTypes.HTTP_DATA_CHUNK[chunkCount];
fixed (HttpApiTypes.HTTP_DATA_CHUNK* pDataChunks = chunks)
{
hr = WriteSequence(chunkCount, _buffer, pDataChunks, out completionExpected);
}
}
bytes = bufferLength;
return !completionExpected;
}
public override void FreeOperationResources(int hr, int bytes)
{
// Free the handles
foreach (var handle in _handles)
{
handle.Dispose();
}
}
protected override void ResetOperation()
{
base.ResetOperation();
_requestHandler = default;
_buffer = default;
_handles.AsSpan().Clear();
}
private int GetChunkCount()
{
if (_buffer.IsSingleSegment)
{
return 1;
}
var count = 0;
foreach (var _ in _buffer)
{
count++;
}
return count;
}
private unsafe int WriteSequence(int nChunks, ReadOnlySequence<byte> buffer, HttpApiTypes.HTTP_DATA_CHUNK* pDataChunks, out bool fCompletionExpected)
{
var currentChunk = 0;
if (_handles == null || _handles.Length < nChunks)
{
_handles = new MemoryHandle[nChunks];
}
foreach (var readOnlyMemory in buffer)
{
ref var handle = ref _handles[currentChunk];
ref var chunk = ref pDataChunks[currentChunk];
handle = readOnlyMemory.Pin();
chunk.DataChunkType = HttpApiTypes.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory;
chunk.fromMemory.BufferLength = (uint)readOnlyMemory.Length;
chunk.fromMemory.pBuffer = (IntPtr)handle.Pointer;
currentChunk++;
}
return WriteChunks(_requestHandler, nChunks, pDataChunks, out fCompletionExpected);
}
protected abstract unsafe int WriteChunks(IntPtr requestHandler, int chunkCount, HttpApiTypes.HTTP_DATA_CHUNK* dataChunks, out bool completionExpected);
}
}

View File

@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal interface IAsyncIOEngine: IDisposable
{
ValueTask<int> ReadAsync(Memory<byte> memory);
ValueTask<int> WriteAsync(ReadOnlySequence<byte> data);
ValueTask FlushAsync();
void NotifyCompletion(int hr, int bytes);
}
}

View File

@ -0,0 +1,42 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class WebSocketsAsyncIOEngine
{
internal class AsyncInitializeOperation : AsyncIOOperation
{
private readonly WebSocketsAsyncIOEngine _engine;
private IntPtr _requestHandler;
public AsyncInitializeOperation(WebSocketsAsyncIOEngine engine)
{
_engine = engine;
}
public void Initialize(IntPtr requestHandler)
{
_requestHandler = requestHandler;
}
protected override bool InvokeOperation(out int hr, out int bytes)
{
hr = NativeMethods.HttpFlushResponseBytes(_requestHandler, out var completionExpected);
bytes = 0;
return !completionExpected;
}
protected override void ResetOperation()
{
base.ResetOperation();
_requestHandler = default;
_engine.ReturnOperation(this);
}
}
}
}

View File

@ -0,0 +1,79 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Runtime.InteropServices;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class WebSocketsAsyncIOEngine
{
internal class WebSocketReadOperation : AsyncIOOperation
{
public static readonly NativeMethods.PFN_WEBSOCKET_ASYNC_COMPLETION ReadCallback = (httpContext, completionInfo, completionContext) =>
{
var context = (WebSocketReadOperation)GCHandle.FromIntPtr(completionContext).Target;
NativeMethods.HttpGetCompletionInfo(completionInfo, out var cbBytes, out var hr);
var continuation = context.Complete(hr, cbBytes);
continuation.Invoke();
return NativeMethods.REQUEST_NOTIFICATION_STATUS.RQ_NOTIFICATION_PENDING;
};
private readonly WebSocketsAsyncIOEngine _engine;
private readonly GCHandle _thisHandle;
private MemoryHandle _inputHandle;
private IntPtr _requestHandler;
private Memory<byte> _memory;
public WebSocketReadOperation(WebSocketsAsyncIOEngine engine)
{
_engine = engine;
_thisHandle = GCHandle.Alloc(this);
}
protected override unsafe bool InvokeOperation(out int hr, out int bytes)
{
_inputHandle = _memory.Pin();
hr = NativeMethods.HttpWebsocketsReadBytes(
_requestHandler,
(byte*)_inputHandle.Pointer,
_memory.Length,
ReadCallback,
(IntPtr)_thisHandle,
out bytes,
out var completionExpected);
return !completionExpected;
}
public void Initialize(IntPtr requestHandler, Memory<byte> memory)
{
_requestHandler = requestHandler;
_memory = memory;
}
public override void FreeOperationResources(int hr, int bytes)
{
_inputHandle.Dispose();
}
protected override void ResetOperation()
{
base.ResetOperation();
_memory = default;
_inputHandle.Dispose();
_inputHandle = default;
_requestHandler = default;
_engine.ReturnOperation(this);
}
}
}
}

View File

@ -0,0 +1,49 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Runtime.InteropServices;
using Microsoft.AspNetCore.HttpSys.Internal;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class WebSocketsAsyncIOEngine
{
internal sealed class WebSocketWriteOperation : AsyncWriteOperationBase
{
private static readonly NativeMethods.PFN_WEBSOCKET_ASYNC_COMPLETION WriteCallback = (httpContext, completionInfo, completionContext) =>
{
var context = (WebSocketWriteOperation)GCHandle.FromIntPtr(completionContext).Target;
NativeMethods.HttpGetCompletionInfo(completionInfo, out var cbBytes, out var hr);
var continuation = context.Complete(hr, cbBytes);
continuation.Invoke();
return NativeMethods.REQUEST_NOTIFICATION_STATUS.RQ_NOTIFICATION_PENDING;
};
private readonly WebSocketsAsyncIOEngine _engine;
private readonly GCHandle _thisHandle;
public WebSocketWriteOperation(WebSocketsAsyncIOEngine engine)
{
_engine = engine;
_thisHandle = GCHandle.Alloc(this);
}
protected override unsafe int WriteChunks(IntPtr requestHandler, int chunkCount, HttpApiTypes.HTTP_DATA_CHUNK* dataChunks, out bool completionExpected)
{
return NativeMethods.HttpWebsocketsWriteBytes(requestHandler, dataChunks, chunkCount, WriteCallback, (IntPtr)_thisHandle, out completionExpected);
}
protected override void ResetOperation()
{
base.ResetOperation();
_engine.ReturnOperation(this);
}
}
}
}

View File

@ -0,0 +1,127 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal partial class WebSocketsAsyncIOEngine: IAsyncIOEngine
{
private readonly IntPtr _handler;
private bool _isInitialized;
private AsyncInitializeOperation _initializationFlush;
private WebSocketWriteOperation _cachedWebSocketWriteOperation;
private WebSocketReadOperation _cachedWebSocketReadOperation;
private AsyncInitializeOperation _cachedAsyncInitializeOperation;
public WebSocketsAsyncIOEngine(IntPtr handler)
{
_handler = handler;
}
public ValueTask<int> ReadAsync(Memory<byte> memory)
{
ThrowIfNotInitialized();
var read = GetReadOperation();
read.Initialize(_handler, memory);
read.Invoke();
return new ValueTask<int>(read, 0);
}
public ValueTask<int> WriteAsync(ReadOnlySequence<byte> data)
{
ThrowIfNotInitialized();
var write = GetWriteOperation();
write.Initialize(_handler, data);
write.Invoke();
return new ValueTask<int>(write, 0);
}
public ValueTask FlushAsync()
{
if (_isInitialized)
{
return new ValueTask(Task.CompletedTask);
}
NativeMethods.HttpEnableWebsockets(_handler);
var init = GetInitializeOperation();
init.Initialize(_handler);
var continuation = init.Invoke();
if (continuation != null)
{
_isInitialized = true;
}
else
{
_initializationFlush = init;
}
return new ValueTask(init, 0);
}
public void NotifyCompletion(int hr, int bytes)
{
_isInitialized = true;
var init = _initializationFlush;
if (init == null)
{
throw new InvalidOperationException("Unexpected completion for WebSocket operation");
}
var continuation = init.Complete(hr, bytes);
_initializationFlush = null;
continuation.Invoke();
}
private void ThrowIfNotInitialized()
{
if (!_isInitialized)
{
throw new InvalidOperationException("Socket IO not initialized yet");
}
}
public void Dispose()
{
NativeMethods.HttpTryCancelIO(_handler);
}
private WebSocketReadOperation GetReadOperation() =>
Interlocked.Exchange(ref _cachedWebSocketReadOperation, null) ??
new WebSocketReadOperation(this);
private WebSocketWriteOperation GetWriteOperation() =>
Interlocked.Exchange(ref _cachedWebSocketWriteOperation, null) ??
new WebSocketWriteOperation(this);
private AsyncInitializeOperation GetInitializeOperation() =>
Interlocked.Exchange(ref _cachedAsyncInitializeOperation, null) ??
new AsyncInitializeOperation(this);
private void ReturnOperation(AsyncInitializeOperation operation) =>
Volatile.Write(ref _cachedAsyncInitializeOperation, operation);
private void ReturnOperation(WebSocketWriteOperation operation) =>
Volatile.Write(ref _cachedWebSocketWriteOperation, operation);
private void ReturnOperation(WebSocketReadOperation operation) =>
Volatile.Write(ref _cachedWebSocketReadOperation, operation);
}
}

View File

@ -11,8 +11,6 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
{
internal class OutputProducer
{
private static readonly ArraySegment<byte> _emptyData = new ArraySegment<byte>(new byte[0]);
// This locks access to to all of the below fields
private readonly object _contextLock = new object();
@ -36,9 +34,11 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
public PipeReader Reader => _pipe.Reader;
public Task FlushAsync(CancellationToken cancellationToken = default(CancellationToken))
public Task FlushAsync(CancellationToken cancellationToken)
{
return WriteAsync(_emptyData, cancellationToken);
_pipe.Reader.CancelPendingRead();
// Await backpressure
return FlushAsync(_pipe.Writer, cancellationToken);
}
public void Dispose()
@ -71,9 +71,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
}
}
public Task WriteAsync(
ReadOnlyMemory<byte> buffer,
CancellationToken cancellationToken)
public Task WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
lock (_contextLock)
{
@ -88,8 +86,7 @@ namespace Microsoft.AspNetCore.Server.IISIntegration
return FlushAsync(_pipe.Writer, cancellationToken);
}
private Task FlushAsync(PipeWriter pipeWriter,
CancellationToken cancellationToken)
private Task FlushAsync(PipeWriter pipeWriter, CancellationToken cancellationToken)
{
var awaitable = pipeWriter.FlushAsync(cancellationToken);
if (awaitable.IsCompleted)

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>$(StandardTestTfms)</TargetFrameworks>
@ -23,6 +23,7 @@
<PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="$(MicrosoftExtensionsLoggingDebugPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Testing" Version="$(MicrosoftExtensionsLoggingTestingPackageVersion)" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="$(MicrosoftNETTestSdkPackageVersion)" />
<PackageReference Include="System.Net.WebSockets.WebSocketProtocol" Version="$(SystemNetWebSocketsWebSocketProtocolPackageVersion)" />
<PackageReference Include="xunit" Version="$(XunitPackageVersion)" />
<PackageReference Include="xunit.runner.visualstudio" Version="$(XunitRunnerVisualStudioPackageVersion)" />
</ItemGroup>

View File

@ -2,11 +2,9 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Net;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Testing.xunit;
using Xunit;
@ -70,55 +68,129 @@ namespace Microsoft.AspNetCore.Server.IISIntegration.FunctionalTests
}
[ConditionalFact]
public void ReadAndWriteSlowConnection()
public async Task ReadSetHeaderWrite()
{
var ipHostEntry = Dns.GetHostEntry(_fixture.Client.BaseAddress.Host);
var body = "Body text";
var content = new StringContent(body);
var response = await _fixture.Client.PostAsync("SetHeaderFromBody", content);
var responseText = await response.Content.ReadAsStringAsync();
using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
Assert.Equal(body, response.Headers.GetValues("BodyAsString").Single());
Assert.Equal(body, responseText);
}
[ConditionalFact]
public async Task ReadAndWriteSlowConnection()
{
using (var connection = _fixture.CreateTestConnection())
{
foreach (var hostEntry in ipHostEntry.AddressList)
{
try
{
socket.Connect(hostEntry, _fixture.Client.BaseAddress.Port);
break;
}
catch (Exception)
{
// Exceptions can be thrown based on ipv6 support
}
}
Assert.True(socket.Connected);
var testString = "hello world";
var request = $"POST /ReadAndWriteSlowConnection HTTP/1.0\r\n" +
$"Content-Length: {testString.Length}\r\n" +
"Host: " + "localhost\r\n" +
"\r\n";
var bytes = 0;
var requestStringBytes = Encoding.ASCII.GetBytes(request);
var testStringBytes = Encoding.ASCII.GetBytes(testString);
"\r\n" + testString;
while ((bytes += socket.Send(requestStringBytes, bytes, 1, SocketFlags.None)) < requestStringBytes.Length)
foreach (var c in request)
{
await connection.Send(c.ToString());
await Task.Delay(10);
}
bytes = 0;
while ((bytes += socket.Send(testStringBytes, bytes, 1, SocketFlags.None)) < testStringBytes.Length)
await connection.Receive(
"HTTP/1.1 200 OK",
"");
await connection.ReceiveHeaders();
for (int i = 0; i < 100; i++)
{
Thread.Sleep(100);
foreach (var c in testString)
{
await connection.Receive(c.ToString());
}
await Task.Delay(10);
}
await connection.WaitForConnectionClose();
}
}
[ConditionalFact]
public async Task ReadAndWriteInterleaved()
{
using (var connection = _fixture.CreateTestConnection())
{
var requestLength = 0;
var messages = new List<string>();
for (var i = 1; i < 100; i++)
{
var message = i + Environment.NewLine;
requestLength += message.Length;
messages.Add(message);
}
var stringBuilder = new StringBuilder();
var buffer = new byte[4096];
int size;
while ((size = socket.Receive(buffer, buffer.Length, SocketFlags.None)) != 0)
await connection.Send(
"POST /ReadAndWriteEchoLines HTTP/1.0",
$"Content-Length: {requestLength}",
"Host: localhost",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
"");
await connection.ReceiveHeaders();
foreach (var message in messages)
{
stringBuilder.Append(Encoding.ASCII.GetString(buffer, 0, size));
await connection.Send(message);
await connection.Receive(message);
}
Assert.Contains(new StringBuilder().Insert(0, "hello world", 100).ToString(), stringBuilder.ToString());
await connection.Send("\r\n");
await connection.WaitForConnectionClose();
}
}
[ConditionalFact]
public async Task ConsumePartialBody()
{
using (var connection = _fixture.CreateTestConnection())
{
var message = "Hello";
await connection.Send(
"POST /ReadPartialBody HTTP/1.1",
$"Content-Length: {100}",
"Host: localhost",
"Connection: close",
"",
"");
await connection.Send(message);
await connection.Receive(
"HTTP/1.1 200 OK",
"");
// This test can return both content length or chunked response
// depending on if appfunc managed to complete before write was
// issued
var headers = await connection.ReceiveHeaders();
if (headers.Contains("Content-Length: 5"))
{
await connection.Receive("Hello");
}
else
{
await connection.Receive(
"5",
message,
"");
await connection.Receive(
"0",
"",
"");
}
await connection.WaitForConnectionClose();
}
}
}

View File

@ -0,0 +1,78 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Testing.xunit;
using Xunit;
namespace Microsoft.AspNetCore.Server.IISIntegration.FunctionalTests
{
[Collection(IISTestSiteCollection.Name)]
public class WebSocketsTests
{
private readonly string _webSocketUri;
public WebSocketsTests(IISTestSiteFixture fixture)
{
_webSocketUri = fixture.BaseUri.Replace("http:", "ws:");
}
[ConditionalFact]
public async Task OnStartedCalledForWebSocket()
{
var cws = new ClientWebSocket();
await cws.ConnectAsync(new Uri(_webSocketUri + "WebSocketLifetimeEvents"), default);
await ReceiveMessage(cws, "OnStarting");
await ReceiveMessage(cws, "Upgraded");
}
[ConditionalFact]
public async Task WebReadBeforeUpgrade()
{
var cws = new ClientWebSocket();
await cws.ConnectAsync(new Uri(_webSocketUri + "WebReadBeforeUpgrade"), default);
await ReceiveMessage(cws, "Yay");
}
[ConditionalFact]
public async Task CanSendAndReceieveData()
{
var cws = new ClientWebSocket();
await cws.ConnectAsync(new Uri(_webSocketUri + "WebSocketEcho"), default);
for (int i = 0; i < 1000; i++)
{
var mesage = i.ToString();
await SendMessage(cws, mesage);
await ReceiveMessage(cws, mesage);
}
}
private async Task SendMessage(ClientWebSocket webSocket, string message)
{
await webSocket.SendAsync(new ArraySegment<byte>(Encoding.ASCII.GetBytes(message)), WebSocketMessageType.Text, true, default);
}
private async Task ReceiveMessage(ClientWebSocket webSocket, string expectedMessage)
{
var received = new byte[expectedMessage.Length];
var offset = 0;
WebSocketReceiveResult result;
do
{
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(received, offset, received.Length - offset), default);
offset += result.Count;
} while (!result.EndOfMessage);
Assert.Equal(expectedMessage, Encoding.ASCII.GetString(received));
}
}
}

View File

@ -40,6 +40,11 @@ namespace Microsoft.AspNetCore.Server.IISIntegration.FunctionalTests
public CancellationToken ShutdownToken { get; }
public DeploymentResult DeploymentResult { get; }
public TestConnection CreateTestConnection()
{
return new TestConnection(Client.BaseAddress.Port);
}
public void Dispose()
{
_deployer.Dispose();

View File

@ -0,0 +1,212 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Testing;
using Xunit;
namespace Microsoft.AspNetCore.Server.IISIntegration.FunctionalTests
{
/// <summary>
/// Summary description for TestConnection
/// </summary>
public class TestConnection : IDisposable
{
private static readonly TimeSpan Timeout = TimeSpan.FromMinutes(1);
private readonly bool _ownsSocket;
private readonly Socket _socket;
private readonly NetworkStream _stream;
private readonly StreamReader _reader;
public TestConnection(int port)
: this(port, AddressFamily.InterNetwork)
{
}
public TestConnection(int port, AddressFamily addressFamily)
: this(CreateConnectedLoopbackSocket(port, addressFamily), ownsSocket: true)
{
}
public TestConnection(Socket socket)
: this(socket, ownsSocket: false)
{
}
private TestConnection(Socket socket, bool ownsSocket)
{
_ownsSocket = ownsSocket;
_socket = socket;
_stream = new NetworkStream(_socket, ownsSocket: false);
_reader = new StreamReader(_stream, Encoding.ASCII);
}
public Socket Socket => _socket;
public StreamReader Reader => _reader;
public void Dispose()
{
_stream.Dispose();
if (_ownsSocket)
{
_socket.Dispose();
}
}
public async Task Send(params string[] lines)
{
var text = string.Join("\r\n", lines);
var writer = new StreamWriter(_stream, Encoding.GetEncoding("iso-8859-1"));
for (var index = 0; index < text.Length; index++)
{
var ch = text[index];
writer.Write(ch);
await writer.FlushAsync().ConfigureAwait(false);
// Re-add delay to help find socket input consumption bugs more consistently
//await Task.Delay(TimeSpan.FromMilliseconds(5));
}
await writer.FlushAsync().ConfigureAwait(false);
await _stream.FlushAsync().ConfigureAwait(false);
}
public async Task Receive(params string[] lines)
{
var expected = string.Join("\r\n", lines);
var actual = new char[expected.Length];
var offset = 0;
try
{
while (offset < expected.Length)
{
var data = new byte[expected.Length];
var task = _reader.ReadAsync(actual, offset, actual.Length - offset);
if (!Debugger.IsAttached)
{
task = task.TimeoutAfter(Timeout);
}
var count = await task.ConfigureAwait(false);
if (count == 0)
{
break;
}
offset += count;
}
}
catch (TimeoutException ex) when (offset != 0)
{
throw new TimeoutException($"Did not receive a complete response within {Timeout}.{Environment.NewLine}{Environment.NewLine}" +
$"Expected:{Environment.NewLine}{expected}{Environment.NewLine}{Environment.NewLine}" +
$"Actual:{Environment.NewLine}{new string(actual, 0, offset)}{Environment.NewLine}",
ex);
}
Assert.Equal(expected, new string(actual, 0, offset));
}
public async Task ReceiveStartsWith(string prefix, int maxLineLength = 1024)
{
var actual = new char[maxLineLength];
var offset = 0;
while (offset < maxLineLength)
{
// Read one char at a time so we don't read past the end of the line.
var task = _reader.ReadAsync(actual, offset, 1);
if (!Debugger.IsAttached)
{
Assert.True(task.Wait(4000), "timeout");
}
var count = await task.ConfigureAwait(false);
if (count == 0)
{
break;
}
Assert.True(count == 1);
offset++;
if (actual[offset - 1] == '\n')
{
break;
}
}
var actualLine = new string(actual, 0, offset);
Assert.StartsWith(prefix, actualLine);
}
public async Task<string[]> ReceiveHeaders(params string[] lines)
{
List<string> headers = new List<string>();
string line;
do
{
line = await _reader.ReadLineAsync();
headers.Add(line);
} while (line != "");
foreach (var s in lines)
{
Assert.Contains(s, headers);
}
return headers.ToArray();
}
public Task WaitForConnectionClose()
{
var tcs = new TaskCompletionSource<object>();
var eventArgs = new SocketAsyncEventArgs();
eventArgs.SetBuffer(new byte[128], 0, 128);
eventArgs.Completed += ReceiveAsyncCompleted;
eventArgs.UserToken = tcs;
if (!_socket.ReceiveAsync(eventArgs))
{
ReceiveAsyncCompleted(this, eventArgs);
}
return tcs.Task;
}
private void ReceiveAsyncCompleted(object sender, SocketAsyncEventArgs e)
{
var tcs = (TaskCompletionSource<object>)e.UserToken;
if (e.BytesTransferred == 0)
{
tcs.SetResult(null);
}
else
{
tcs.SetException(new IOException(
$"Expected connection close, received data instead: \"{_reader.CurrentEncoding.GetString(e.Buffer, 0, e.BytesTransferred)}\""));
}
}
public static Socket CreateConnectedLoopbackSocket(int port, AddressFamily addressFamily)
{
if (addressFamily != AddressFamily.InterNetwork && addressFamily != AddressFamily.InterNetworkV6)
{
throw new ArgumentException($"TestConnection does not support address family of type {addressFamily}", nameof(addressFamily));
}
var socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp);
var address = addressFamily == AddressFamily.InterNetworkV6
? IPAddress.IPv6Loopback
: IPAddress.Loopback;
socket.Connect(new IPEndPoint(address, port));
return socket;
}
}
}

View File

@ -4,17 +4,22 @@
<PropertyGroup>
<TargetFrameworks>$(StandardTestTfms)</TargetFrameworks>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\..\src\Microsoft.AspNetCore.Server.IISIntegration\Microsoft.AspNetCore.Server.IISIntegration.csproj" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\shared\**\*.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Hosting" Version="$(MicrosoftAspNetCoreHostingPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.WebUtilities" Version="$(MicrosoftAspNetCoreWebUtilitiesPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="$(MicrosoftExtensionsConfigurationEnvironmentVariablesPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="$(MicrosoftExtensionsConfigurationJsonPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="$(MicrosoftExtensionsLoggingConsolePackageVersion)" />
<PackageReference Include="System.Net.WebSockets.WebSocketProtocol" Version="$(SystemNetWebSocketsWebSocketProtocolPackageVersion)" />
<PackageReference Include="xunit" Version="$(XunitPackageVersion)" />
</ItemGroup>

View File

@ -0,0 +1,154 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.WebSockets;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using IISIntegration.FunctionalTests;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.IIS;
using Microsoft.AspNetCore.Server.IISIntegration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Primitives;
using Xunit;
namespace IISTestSite
{
public partial class Startup
{
private void WebsocketRequest(IApplicationBuilder app)
{
app.Run(async context =>
{
await context.Response.WriteAsync("test");
});
}
private void WebReadBeforeUpgrade(IApplicationBuilder app)
{
app.Run(async context => {
var singleByteArray = new byte[1];
Assert.Equal(0, await context.Request.Body.ReadAsync(singleByteArray, 0, 1));
var ws = await Upgrade(context);
await SendMessages(ws, "Yay");
});
}
private void WebSocketEcho(IApplicationBuilder app)
{
app.Run(async context =>
{
var ws = await Upgrade(context);
var appLifetime = app.ApplicationServices.GetRequiredService<IApplicationLifetime>();
await Echo(ws, appLifetime.ApplicationStopping);
});
}
private void WebSocketLifetimeEvents(IApplicationBuilder app)
{
app.Run(async context => {
var messages = new List<string>();
context.Response.OnStarting(() => {
context.Response.Headers["custom-header"] = "value";
messages.Add("OnStarting");
return Task.CompletedTask;
});
var ws = await Upgrade(context);
messages.Add("Upgraded");
await SendMessages(ws, messages.ToArray());
});
}
private static async Task SendMessages(WebSocket webSocket, params string[] messages)
{
foreach (var message in messages)
{
await webSocket.SendAsync(new ArraySegment<byte>(Encoding.ASCII.GetBytes(message)), WebSocketMessageType.Text, true, CancellationToken.None);
}
}
private static async Task<WebSocket> Upgrade(HttpContext context)
{
var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>();
// Generate WebSocket response headers
string key = context.Request.Headers[Constants.Headers.SecWebSocketKey].ToString();
var responseHeaders = HandshakeHelpers.GenerateResponseHeaders(key);
foreach (var headerPair in responseHeaders)
{
context.Response.Headers[headerPair.Key] = headerPair.Value;
}
// Upgrade the connection
Stream opaqueTransport = await upgradeFeature.UpgradeAsync();
// Get the WebSocket object
var ws = WebSocketProtocol.CreateFromStream(opaqueTransport, isServer: true, subProtocol: null, keepAliveInterval: TimeSpan.FromMinutes(2));
return ws;
}
private async Task Echo(WebSocket webSocket, CancellationToken token)
{
var buffer = new byte[1024 * 4];
var result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), token);
bool closeFromServer = false;
string closeFromServerCmd = "CloseFromServer";
int closeFromServerLength = closeFromServerCmd.Length;
while (!result.CloseStatus.HasValue && !token.IsCancellationRequested && !closeFromServer)
{
if (result.Count == closeFromServerLength &&
Encoding.ASCII.GetString(buffer).Substring(0, result.Count) == closeFromServerCmd)
{
// The client sent "CloseFromServer" text message to request the server to close (a test scenario).
closeFromServer = true;
}
else
{
await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, token);
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), token);
}
}
if (result.CloseStatus.HasValue)
{
// Client-initiated close handshake
await webSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None);
}
else
{
// Server-initiated close handshake due to either of the two conditions:
// (1) The applicaton host is performing a graceful shutdown.
// (2) The client sent "CloseFromServer" text message to request the server to close (a test scenario).
await webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, closeFromServerCmd, CancellationToken.None);
// The server has sent the Close frame.
// Stop sending but keep receiving until we get the Close frame from the client.
while (!result.CloseStatus.HasValue)
{
result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
}
}
}
}
}

View File

@ -2,25 +2,30 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Net.WebSockets;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using IISIntegration.FunctionalTests;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.IIS;
using Microsoft.AspNetCore.Server.IISIntegration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Primitives;
using Xunit;
namespace IISTestSite
{
public class Startup
public partial class Startup
{
public void Configure(IApplicationBuilder app)
{
@ -369,8 +374,7 @@ namespace IISTestSite
private void ReadAndWriteEcho(IApplicationBuilder app)
{
app.Run(async context =>
{
app.Run(async context => {
var readBuffer = new byte[4096];
var result = await context.Request.Body.ReadAsync(readBuffer, 0, readBuffer.Length);
while (result != 0)
@ -381,6 +385,51 @@ namespace IISTestSite
});
}
private void ReadAndWriteEchoLines(IApplicationBuilder app)
{
app.Run(async context => {
//Send headers
await context.Response.Body.FlushAsync();
var reader = new StreamReader(context.Request.Body);
while (!reader.EndOfStream)
{
var line = await reader.ReadLineAsync();
if (line == "")
{
return;
}
await context.Response.WriteAsync(line + Environment.NewLine);
await context.Response.Body.FlushAsync();
}
});
}
private void ReadPartialBody(IApplicationBuilder app)
{
app.Run(async context => {
var data = new byte[5];
var count = 0;
do
{
count += await context.Request.Body.ReadAsync(data, count, data.Length - count);
} while (count != data.Length);
await context.Response.Body.WriteAsync(data, 0, data.Length);
});
}
private void SetHeaderFromBody(IApplicationBuilder app)
{
app.Run(async context => {
using (var reader = new StreamReader(context.Request.Body))
{
var value = await reader.ReadToEndAsync();
context.Response.Headers["BodyAsString"] = value;
await context.Response.WriteAsync(value);
}
});
}
private void ReadAndWriteEchoTwice(IApplicationBuilder app)
{
app.Run(async context =>
@ -416,14 +465,6 @@ namespace IISTestSite
}
}
private void WebsocketRequest(IApplicationBuilder app)
{
app.Run(async context =>
{
await context.Response.WriteAsync("test");
});
}
private void ReadAndWriteCopyToAsync(IApplicationBuilder app)
{
app.Run(async context =>

View File

@ -15,7 +15,7 @@
"nativeDebugging": true,
"environmentVariables": {
"IIS_SITE_PATH": "$(MSBuildThisFileDirectory)",
"ANCM_PATH": "$(TargetDir)$(AncmPath)",
"ANCM_PATH": "$(TargetDir)$(AncmV2Path)",
"LAUNCHER_ARGS": "$(TargetPath)",
"ASPNETCORE_ENVIRONMENT": "Development",
"LAUNCHER_PATH": "$(DotNetPath)"
@ -27,7 +27,7 @@
"commandLineArgs": "$(IISArguments)",
"environmentVariables": {
"IIS_SITE_PATH": "$(MSBuildThisFileDirectory)",
"ANCM_PATH": "$(TargetDir)$(AncmPath)",
"ANCM_PATH": "$(TargetDir)$(AncmV2Path)",
"LAUNCHER_ARGS": "$(TargetPath)",
"ASPNETCORE_ENVIRONMENT": "Development",
"LAUNCHER_PATH": "$(DotNetPath)"

View File

@ -8,6 +8,7 @@ using System.Threading.Tasks;
using System.Threading;
using System.Text;
using System.Net.WebSockets;
using IISIntegration.FunctionalTests;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Project Sdk="Microsoft.NET.Sdk.Web">
<Import Project="..\..\..\build\testsite.props" />
@ -10,6 +10,10 @@
<ProjectReference Include="..\..\..\src\Microsoft.AspNetCore.Server.IISIntegration\Microsoft.AspNetCore.Server.IISIntegration.csproj" />
</ItemGroup>
<ItemGroup>
<Compile Include="..\shared\**\*.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel" Version="$(MicrosoftAspNetCoreServerKestrelPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="$(MicrosoftExtensionsLoggingConsolePackageVersion)" />

View File

@ -1,7 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace ANCMStressTestApp
namespace IISIntegration.FunctionalTests
{
public static class Constants
{

View File

@ -6,9 +6,8 @@ using System.Collections.Generic;
using System.Security.Cryptography;
using System.Text;
namespace ANCMStressTestApp
namespace IISIntegration.FunctionalTests
{
// Removed all the
internal static class HandshakeHelpers
{
public static IEnumerable<KeyValuePair<string, string>> GenerateResponseHeaders(string key)