diff --git a/test/Microsoft.AspNet.Server.KestrelTests/StreamSocketOutputTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/StreamSocketOutputTests.cs new file mode 100644 index 0000000000..7b19bf6173 --- /dev/null +++ b/test/Microsoft.AspNet.Server.KestrelTests/StreamSocketOutputTests.cs @@ -0,0 +1,106 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using Microsoft.AspNet.Server.Kestrel.Filter; +using Microsoft.AspNet.Server.Kestrel.Http; +using Xunit; + +namespace Microsoft.AspNet.Server.KestrelTests +{ + public class StreamSocketOutputTests + { + [Fact] + public void DoesNotThrowForNullBuffers() + { + // This test was added because SslStream throws if passed null buffers with (count == 0) + // Which happens if ProduceEnd is called in Frame without _responseStarted == true + // As it calls ProduceStart with write immediate == true + // This happens in WebSocket Upgrade over SSL + + ISocketOutput socketOutput = new StreamSocketOutput(new ThrowsOnNullWriteStream(), null); + + // Should not throw + socketOutput.Write(default(ArraySegment), true); + + Assert.True(true); + } + + private class ThrowsOnNullWriteStream : Stream + { + public override bool CanRead + { + get + { + throw new NotImplementedException(); + } + } + + public override bool CanSeek + { + get + { + throw new NotImplementedException(); + } + } + + public override bool CanWrite + { + get + { + throw new NotImplementedException(); + } + } + + public override long Length + { + get + { + throw new NotImplementedException(); + } + } + + public override long Position + { + get + { + throw new NotImplementedException(); + } + + set + { + throw new NotImplementedException(); + } + } + + public override void Flush() + { + throw new NotImplementedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + if (buffer == null) + { + throw new ArgumentNullException(nameof(buffer)); + } + } + } + } +}