// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.AspNetCore.Server.HttpSys
{
    /// <summary>
    /// A <see cref="Stream"/> wrapper that awaits a start callback before every write or flush
    /// and otherwise forwards all operations unchanged to an inner stream.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the callback presumably starts the HTTP response (headers) before the first
    /// body bytes go out. It runs on EVERY Write/WriteAsync/BeginWrite/Flush/FlushAsync call, so it
    /// is presumably expected to be idempotent — confirm with the code that supplies it.
    /// </remarks>
    internal class ResponseStream : Stream
    {
        private readonly Stream _innerStream;
        private readonly Func<Task> _onStart;

        /// <summary>
        /// Creates a wrapper around <paramref name="innerStream"/>.
        /// </summary>
        /// <param name="innerStream">The stream every operation is forwarded to.</param>
        /// <param name="onStart">Callback awaited before each write or flush reaches <paramref name="innerStream"/>.</param>
        internal ResponseStream(Stream innerStream, Func<Task> onStart)
        {
            _innerStream = innerStream;
            _onStart = onStart;
        }

        public override bool CanRead => _innerStream.CanRead;

        public override bool CanSeek => _innerStream.CanSeek;

        public override bool CanWrite => _innerStream.CanWrite;

        public override long Length => _innerStream.Length;

        public override long Position
        {
            get { return _innerStream.Position; }
            set { _innerStream.Position = value; }
        }

        public override long Seek(long offset, SeekOrigin origin) => _innerStream.Seek(offset, origin);

        public override void SetLength(long value) => _innerStream.SetLength(value);

        // Reads never trigger the start callback; they are forwarded as-is.
        public override int Read(byte[] buffer, int offset, int count) => _innerStream.Read(buffer, offset, count);

        public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
        {
            return _innerStream.BeginRead(buffer, offset, count, callback, state);
        }

        public override int EndRead(IAsyncResult asyncResult)
        {
            return _innerStream.EndRead(asyncResult);
        }

        /// <summary>
        /// Runs the start callback, then flushes the inner stream.
        /// </summary>
        public override void Flush()
        {
            // Synchronous Stream API: there is no way to avoid blocking on the async callback here.
            _onStart().GetAwaiter().GetResult();
            _innerStream.Flush();
        }

        /// <summary>
        /// Awaits the start callback, then flushes the inner stream asynchronously.
        /// </summary>
        /// <param name="cancellationToken">Forwarded to the inner flush; the start callback takes no token.</param>
        public override async Task FlushAsync(CancellationToken cancellationToken)
        {
            await _onStart();
            await _innerStream.FlushAsync(cancellationToken);
        }

        /// <summary>
        /// Runs the start callback, then writes <paramref name="count"/> bytes to the inner stream.
        /// </summary>
        public override void Write(byte[] buffer, int offset, int count)
        {
            // Synchronous Stream API: there is no way to avoid blocking on the async callback here.
            _onStart().GetAwaiter().GetResult();
            _innerStream.Write(buffer, offset, count);
        }

        /// <summary>
        /// Awaits the start callback, then writes <paramref name="count"/> bytes to the inner stream asynchronously.
        /// </summary>
        /// <param name="cancellationToken">Forwarded to the inner write; the start callback takes no token.</param>
        public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            await _onStart();
            await _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
        }

        /// <summary>
        /// APM entry point for writes; routes through <see cref="WriteAsync(byte[], int, int, CancellationToken)"/>
        /// so the start callback runs exactly as it does for the task-based API.
        /// </summary>
        /// <returns>A <see cref="Task"/> (as <see cref="IAsyncResult"/>) whose AsyncState is <paramref name="state"/>.</returns>
        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
        {
            return ToIAsyncResult(WriteAsync(buffer, offset, count), callback, state);
        }

        /// <summary>
        /// Completes a write started by <see cref="BeginWrite"/>, rethrowing any failure from the write.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="asyncResult"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="asyncResult"/> was not produced by <see cref="BeginWrite"/>.</exception>
        public override void EndWrite(IAsyncResult asyncResult)
        {
            if (asyncResult == null)
            {
                throw new ArgumentNullException(nameof(asyncResult));
            }

            // BeginWrite always hands out a Task; anything else did not originate here. Report that as an
            // argument error (per the Stream.EndWrite contract) instead of an InvalidCastException.
            var task = asyncResult as Task;
            if (task == null)
            {
                throw new ArgumentException("The IAsyncResult was not returned by BeginWrite on this stream.", nameof(asyncResult));
            }

            // GetAwaiter().GetResult() rethrows the original exception rather than an AggregateException.
            task.GetAwaiter().GetResult();
        }

        /// <summary>
        /// Adapts <paramref name="task"/> to the Begin/End (APM) pattern.
        /// </summary>
        /// <param name="task">The operation being wrapped.</param>
        /// <param name="callback">Optional callback invoked with the returned task once it has completed.</param>
        /// <param name="state">Value exposed through the returned result's <see cref="IAsyncResult.AsyncState"/>.</param>
        /// <returns>
        /// A task carrying <paramref name="state"/> that completes with the same outcome
        /// (success, fault, or cancellation) as <paramref name="task"/>.
        /// </returns>
        private static IAsyncResult ToIAsyncResult(Task task, AsyncCallback callback, object state)
        {
            // A fresh TaskCompletionSource is needed because a Task's AsyncState cannot be set after creation.
            var tcs = new TaskCompletionSource<int>(state);
            task.ContinueWith(t =>
            {
                if (t.IsFaulted)
                {
                    tcs.TrySetException(t.Exception.InnerExceptions);
                }
                else if (t.IsCanceled)
                {
                    tcs.TrySetCanceled();
                }
                else
                {
                    tcs.TrySetResult(0);
                }

                // The callback runs only after tcs is complete, so EndWrite called from it will not block.
                if (callback != null)
                {
                    callback(tcs.Task);
                }
            }, CancellationToken.None, TaskContinuationOptions.None, TaskScheduler.Default);
            return tcs.Task;
        }
    }
}