233 lines
7.6 KiB
C#
233 lines
7.6 KiB
C#
// Copyright (c) .NET Foundation. All rights reserved.
|
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
using System.Linq;
|
|
using System.Runtime.InteropServices;
|
|
using System.Text;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
|
|
namespace Microsoft.AspNetCore.Server.Kestrel.Tls
|
|
{
|
|
public class TlsStream : Stream
|
|
{
|
|
private static unsafe OpenSsl.alpn_select_cb_t _alpnSelectCallback = AlpnSelectCallback;
|
|
|
|
private readonly Stream _innerStream;
|
|
private readonly byte[] _protocols;
|
|
private readonly GCHandle _protocolsHandle;
|
|
|
|
private IntPtr _ctx;
|
|
private IntPtr _ssl;
|
|
private IntPtr _inputBio;
|
|
private IntPtr _outputBio;
|
|
|
|
private readonly byte[] _inputBuffer = new byte[1024 * 1024];
|
|
private readonly byte[] _outputBuffer = new byte[1024 * 1024];
|
|
|
|
static TlsStream()
|
|
{
|
|
OpenSsl.SSL_library_init();
|
|
OpenSsl.SSL_load_error_strings();
|
|
OpenSsl.ERR_load_BIO_strings();
|
|
OpenSsl.OpenSSL_add_all_algorithms();
|
|
}
|
|
|
|
public TlsStream(Stream innerStream, string certificatePath, string password, IEnumerable<string> protocols)
|
|
{
|
|
_innerStream = innerStream;
|
|
_protocols = ToWireFormat(protocols);
|
|
_protocolsHandle = GCHandle.Alloc(_protocols);
|
|
|
|
_ctx = OpenSsl.SSL_CTX_new(OpenSsl.TLSv1_2_method());
|
|
|
|
if (_ctx == IntPtr.Zero)
|
|
{
|
|
throw new Exception("Unable to create SSL context.");
|
|
}
|
|
|
|
if(OpenSsl.SSL_CTX_Set_Pfx(_ctx, certificatePath, password) != 1)
|
|
{
|
|
throw new InvalidOperationException("Unable to load PFX");
|
|
}
|
|
|
|
OpenSsl.SSL_CTX_set_ecdh_auto(_ctx, 1);
|
|
|
|
OpenSsl.SSL_CTX_set_alpn_select_cb(_ctx, _alpnSelectCallback, GCHandle.ToIntPtr(_protocolsHandle));
|
|
|
|
_ssl = OpenSsl.SSL_new(_ctx);
|
|
|
|
_inputBio = OpenSsl.BIO_new(OpenSsl.BIO_s_mem());
|
|
OpenSsl.BIO_set_mem_eof_return(_inputBio, -1);
|
|
|
|
_outputBio = OpenSsl.BIO_new(OpenSsl.BIO_s_mem());
|
|
OpenSsl.BIO_set_mem_eof_return(_outputBio, -1);
|
|
|
|
OpenSsl.SSL_set_bio(_ssl, _inputBio, _outputBio);
|
|
}
|
|
|
|
~TlsStream()
|
|
{
|
|
if (_ssl != IntPtr.Zero)
|
|
{
|
|
OpenSsl.SSL_free(_ssl);
|
|
}
|
|
|
|
if (_ctx != IntPtr.Zero)
|
|
{
|
|
// This frees the BIOs.
|
|
OpenSsl.SSL_CTX_free(_ctx);
|
|
}
|
|
|
|
if (_protocolsHandle.IsAllocated)
|
|
{
|
|
_protocolsHandle.Free();
|
|
}
|
|
}
|
|
|
|
public override bool CanRead => true;
|
|
public override bool CanWrite => true;
|
|
|
|
public override bool CanSeek => false;
|
|
public override long Length => throw new NotSupportedException();
|
|
public override long Position
|
|
{
|
|
get => throw new NotSupportedException();
|
|
set => throw new NotSupportedException();
|
|
}
|
|
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
|
|
public override void SetLength(long value) => throw new NotSupportedException();
|
|
|
|
public override void Flush()
|
|
{
|
|
FlushAsync(default(CancellationToken)).GetAwaiter().GetResult();
|
|
}
|
|
|
|
public override int Read(byte[] buffer, int offset, int count)
|
|
{
|
|
return ReadAsync(buffer, offset, count).GetAwaiter().GetResult();
|
|
}
|
|
|
|
public override void Write(byte[] buffer, int offset, int count)
|
|
{
|
|
WriteAsync(buffer, offset, count).GetAwaiter().GetResult();
|
|
}
|
|
|
|
public override async Task FlushAsync(CancellationToken cancellationToken)
|
|
{
|
|
var pending = OpenSsl.BIO_ctrl_pending(_outputBio);
|
|
|
|
while (pending > 0)
|
|
{
|
|
var count = OpenSsl.BIO_read(_outputBio, _outputBuffer, 0, _outputBuffer.Length);
|
|
await _innerStream.WriteAsync(_outputBuffer, 0, count, cancellationToken);
|
|
|
|
pending = OpenSsl.BIO_ctrl_pending(_outputBio);
|
|
}
|
|
}
|
|
|
|
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
|
{
|
|
if (OpenSsl.BIO_ctrl_pending(_inputBio) == 0)
|
|
{
|
|
var bytesRead = await _innerStream.ReadAsync(_inputBuffer, 0, _inputBuffer.Length, cancellationToken);
|
|
|
|
if (bytesRead == 0)
|
|
{
|
|
return 0;
|
|
}
|
|
|
|
OpenSsl.BIO_write(_inputBio, _inputBuffer, 0, bytesRead);
|
|
}
|
|
|
|
return OpenSsl.SSL_read(_ssl, buffer, offset, count);
|
|
}
|
|
|
|
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
|
{
|
|
OpenSsl.SSL_write(_ssl, buffer, offset, count);
|
|
|
|
return FlushAsync(cancellationToken);
|
|
}
|
|
|
|
public async Task DoHandshakeAsync(CancellationToken cancellationToken = default(CancellationToken))
|
|
{
|
|
OpenSsl.SSL_set_accept_state(_ssl);
|
|
|
|
var count = 0;
|
|
|
|
try
|
|
{
|
|
while ((count = await _innerStream.ReadAsync(_inputBuffer, 0, _inputBuffer.Length, cancellationToken)) > 0)
|
|
{
|
|
if (count == 0)
|
|
{
|
|
throw new IOException("TLS handshake failed: the inner stream was closed.");
|
|
}
|
|
|
|
OpenSsl.BIO_write(_inputBio, _inputBuffer, 0, count);
|
|
|
|
var ret = OpenSsl.SSL_do_handshake(_ssl);
|
|
|
|
if (ret != 1)
|
|
{
|
|
var error = OpenSsl.SSL_get_error(_ssl, ret);
|
|
|
|
if (error != 2)
|
|
{
|
|
throw new IOException($"TLS handshake failed: {nameof(OpenSsl.SSL_do_handshake)} error {error}.");
|
|
}
|
|
}
|
|
|
|
await FlushAsync(cancellationToken);
|
|
|
|
if (ret == 1)
|
|
{
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
finally
|
|
{
|
|
_protocolsHandle.Free();
|
|
}
|
|
}
|
|
|
|
public string GetNegotiatedApplicationProtocol()
|
|
{
|
|
OpenSsl.SSL_get0_alpn_selected(_ssl, out var protocol);
|
|
return protocol;
|
|
}
|
|
|
|
private static unsafe int AlpnSelectCallback(IntPtr ssl, out byte* @out, out byte outlen, byte* @in, uint inlen, IntPtr arg)
|
|
{
|
|
var protocols = GCHandle.FromIntPtr(arg);
|
|
var server = (byte[])protocols.Target;
|
|
|
|
fixed (byte* serverPtr = server)
|
|
{
|
|
return OpenSsl.SSL_select_next_proto(out @out, out outlen, serverPtr, (uint)server.Length, @in, (uint)inlen) == OpenSsl.OPENSSL_NPN_NEGOTIATED
|
|
? OpenSsl.SSL_TLSEXT_ERR_OK
|
|
: OpenSsl.SSL_TLSEXT_ERR_NOACK;
|
|
}
|
|
}
|
|
|
|
private static byte[] ToWireFormat(IEnumerable<string> protocols)
|
|
{
|
|
var buffer = new byte[protocols.Count() + protocols.Sum(protocol => protocol.Length)];
|
|
|
|
var offset = 0;
|
|
foreach (var protocol in protocols)
|
|
{
|
|
buffer[offset++] = (byte)protocol.Length;
|
|
offset += Encoding.ASCII.GetBytes(protocol, 0, protocol.Length, buffer, offset);
|
|
}
|
|
|
|
return buffer;
|
|
}
|
|
}
|
|
}
|