// aspnetcore/src/Microsoft.AspNet.Server.Kes.../Http/Frame.cs
//
// 712 lines
// 24 KiB
// C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNet.Server.Kestrel.Infrastructure;
using Microsoft.Framework.Logging;
using Microsoft.Framework.Primitives;
using System.Numerics;
using Microsoft.AspNet.Hosting.Builder;
// ReSharper disable AccessToModifiedClosure
namespace Microsoft.AspNet.Server.Kestrel.Http
{
public class Frame : FrameContext, IFrameControl
{
// ASCII encoding used when decoding raw header bytes.
private static readonly Encoding _ascii = Encoding.ASCII;

// Pre-encoded protocol fragments that are written to the socket verbatim.
private static readonly ArraySegment<byte> _endChunkBytes = CreateAsciiByteArraySegment("\r\n");
private static readonly ArraySegment<byte> _endChunkedResponseBytes = CreateAsciiByteArraySegment("0\r\n\r\n");
private static readonly ArraySegment<byte> _continueBytes = CreateAsciiByteArraySegment("HTTP/1.1 100 Continue\r\n\r\n");

// Zero-length payload; writing it with immediate: true is how Flush/FlushAsync push buffered output.
private static readonly ArraySegment<byte> _emptyData = new ArraySegment<byte>(new byte[0]);

// Lower-case hex digits indexed by nibble value; used to format chunk sizes.
private static readonly byte[] _hex = Encoding.ASCII.GetBytes("0123456789abcdef");

// Guards for the two callback lists below.
private readonly object _onStartingSync = new object();
private readonly object _onCompletedSync = new object();

// Header collections are allocated once per frame and cleared by Reset().
private readonly FrameRequestHeaders _requestHeaders = new FrameRequestHeaders();
private readonly FrameResponseHeaders _responseHeaders = new FrameResponseHeaders();

// Callback registrations (callback + state); created lazily, null when nothing is registered.
private List<KeyValuePair<Func<object, Task>, object>> _onStarting;
private List<KeyValuePair<Func<object, Task>, object>> _onCompleted;

// True once the status line and headers have been handed to the socket output.
private bool _responseStarted;

// Whether the connection is to be reused after the current response.
private bool _keepAlive;

// True when the response body is being wrapped in chunked transfer encoding.
private bool _autoChunk;
/// <summary>
/// Creates a frame on top of the given connection context, registers the frame as its own
/// <c>FrameControl</c>, and puts all per-request state into its initial condition.
/// </summary>
/// <param name="context">Connection context forwarded to the base class.</param>
public Frame(ConnectionContext context)
    : base(context)
{
    FrameControl = this;
    Reset();
}
// Request method token; assigned from the request line by TakeStartLine.
public string Method { get; set; }
// Request target text taken from the request line by TakeStartLine (the part up to the first ' ' or '?').
public string RequestUri { get; set; }
// Request path; cleared by Reset().
public string Path { get; set; }
// Query portion of the request line as captured by TakeStartLine.
public string QueryString { get; set; }
// Protocol version text from the request line (compared against "HTTP/1.1" / "HTTP/1.0" in this class).
public string HttpVersion { get; set; }
// Request header collection; Reset() points this at the frame's own FrameRequestHeaders instance.
public IDictionary<string, StringValues> RequestHeaders { get; set; }
// Body reader for the current request; created in ProcessFraming via MessageBody.For.
public MessageBody MessageBody { get; set; }
// Stream over MessageBody handed to the application (a FrameRequestStream in ProcessFraming).
public Stream RequestBody { get; set; }
// Response status code; Reset() sets 200, ProduceEnd sets 500 when the application failed before the response started.
public int StatusCode { get; set; }
// Optional reason phrase; passed together with StatusCode to ReasonPhrases.ToStatus.
public string ReasonPhrase { get; set; }
// Response header collection; Reset() points this at the frame's own FrameResponseHeaders instance.
public IDictionary<string, StringValues> ResponseHeaders { get; set; }
// Stream the application writes the response to (a FrameResponseStream in ProcessFraming).
public Stream ResponseBody { get; set; }
// Combined request/response stream (a FrameDuplexStream in ProcessFraming).
public Stream DuplexStream { get; set; }
/// <summary>
/// Gets whether the response head has already been produced for the current request.
/// </summary>
public bool HasResponseStarted
{
    get
    {
        return _responseStarted;
    }
}
/// <summary>
/// Returns the frame to its initial per-request state so it can serve the next request
/// on the same connection.
/// </summary>
public void Reset()
{
    // Callback registrations and response bookkeeping.
    _onStarting = null;
    _onCompleted = null;
    _responseStarted = false;
    _keepAlive = false;
    _autoChunk = false;

    // Header collections are reused, not reallocated.
    _requestHeaders.Reset();
    ResetResponseHeaders();

    // Request side.
    Method = null;
    RequestUri = null;
    Path = null;
    QueryString = null;
    HttpVersion = null;
    RequestHeaders = _requestHeaders;
    MessageBody = null;
    RequestBody = null;

    // Response side.
    StatusCode = 200;
    ReasonPhrase = null;
    ResponseHeaders = _responseHeaders;
    ResponseBody = null;
    DuplexStream = null;
}
/// <summary>
/// Clears the response headers and re-adds the default <c>Server</c> and <c>Date</c> values.
/// </summary>
public void ResetResponseHeaders()
{
    _responseHeaders.Reset();

    // "r" is the RFC1123 date pattern.
    var currentDate = DateTime.UtcNow.ToString("r");
    _responseHeaders.HeaderServer = "Kestrel";
    _responseHeaders.HeaderDate = currentDate;
}
/// <summary>
/// Request loop for one connection: parses the request line and headers, runs the
/// application, produces the end of the response, and repeats while keep-alive applies.
/// When the loop ends the connection is shut down and disconnected.
/// </summary>
/// <returns>A task that completes once the connection has been terminated.</returns>
public async Task ProcessFraming()
{
    var terminated = false;
    while (!terminated)
    {
        // Wait until a full request line is available or the remote side stops sending.
        while (!terminated && !TakeStartLine(SocketInput))
        {
            terminated = SocketInput.RemoteIntakeFin;
            if (!terminated)
            {
                await SocketInput;
            }
        }

        // Same wait for the complete header block.
        while (!terminated && !TakeMessageHeader2(SocketInput))
        {
            terminated = SocketInput.RemoteIntakeFin;
            if (!terminated)
            {
                await SocketInput;
            }
        }

        if (!terminated)
        {
            MessageBody = MessageBody.For(HttpVersion, _requestHeaders, this);
            _keepAlive = MessageBody.RequestKeepAlive;
            RequestBody = new FrameRequestStream(MessageBody);
            ResponseBody = new FrameResponseStream(this);
            DuplexStream = new FrameDuplexStream(RequestBody, ResponseBody);

            Exception error = null;
            try
            {
                await Application.Invoke(this).ConfigureAwait(false);

                // Trigger FireOnStarting if ProduceStart hasn't been called yet.
                // We call it here, so it can go through our normal error handling
                // and respond with a 500 if an OnStarting callback throws.
                if (!_responseStarted)
                {
                    FireOnStarting();
                }
            }
            catch (Exception ex)
            {
                error = ex;
            }
            finally
            {
                FireOnCompleted();
                ProduceEnd(error);
            }

            // _keepAlive may have been cleared while the response was produced.
            terminated = !_keepAlive;
        }

        Reset();
    }

    // Connection Terminated!
    ConnectionControl.End(ProduceEndType.SocketShutdownSend);
    ConnectionControl.End(ProduceEndType.SocketDisconnect);
}
/// <summary>
/// Registers a callback to be invoked right before the response head is produced.
/// </summary>
/// <param name="callback">Delegate to invoke; it receives <paramref name="state"/>.</param>
/// <param name="state">Opaque value handed to the callback.</param>
public void OnStarting(Func<object, Task> callback, object state)
{
    lock (_onStartingSync)
    {
        // The list is created on first registration.
        var registrations = _onStarting;
        if (registrations == null)
        {
            registrations = new List<KeyValuePair<Func<object, Task>, object>>();
            _onStarting = registrations;
        }

        registrations.Add(new KeyValuePair<Func<object, Task>, object>(callback, state));
    }
}
/// <summary>
/// Registers a callback to be invoked after the application has finished with the request.
/// </summary>
/// <param name="callback">Delegate to invoke; it receives <paramref name="state"/>.</param>
/// <param name="state">Opaque value handed to the callback.</param>
public void OnCompleted(Func<object, Task> callback, object state)
{
    lock (_onCompletedSync)
    {
        // Lazily allocate the list, then append the new registration.
        _onCompleted = _onCompleted ?? new List<KeyValuePair<Func<object, Task>, object>>();
        _onCompleted.Add(new KeyValuePair<Func<object, Task>, object>(callback, state));
    }
}
/// <summary>
/// Runs the registered OnStarting callbacks in registration order, blocking on each one.
/// The list is detached first, so every registration runs at most once. An exception
/// from a callback propagates to the caller.
/// </summary>
private void FireOnStarting()
{
    List<KeyValuePair<Func<object, Task>, object>> pending;
    lock (_onStartingSync)
    {
        pending = _onStarting;
        _onStarting = null;
    }

    if (pending == null)
    {
        return;
    }

    foreach (var registration in pending)
    {
        var callback = registration.Key;
        var state = registration.Value;
        callback(state).Wait();
    }
}
/// <summary>
/// Runs the registered OnCompleted callbacks in registration order, blocking on each one.
/// The list is detached first, so every registration runs at most once. A failing callback
/// does not prevent the remaining ones from running.
/// </summary>
private void FireOnCompleted()
{
    List<KeyValuePair<Func<object, Task>, object>> pending;
    lock (_onCompletedSync)
    {
        pending = _onCompleted;
        _onCompleted = null;
    }

    if (pending == null)
    {
        return;
    }

    foreach (var registration in pending)
    {
        var callback = registration.Key;
        var state = registration.Value;
        try
        {
            callback(state).Wait();
        }
        catch
        {
            // Ignore exceptions
        }
    }
}
/// <summary>
/// Makes sure the response head has been produced, then issues an empty immediate write
/// to the socket output.
/// </summary>
public void Flush()
{
    // The head itself is queued non-immediately; the empty write below is the immediate one.
    ProduceStart(immediate: false, appCompleted: false);
    SocketOutput.Write(_emptyData, immediate: true);
}
/// <summary>
/// Asynchronous counterpart of <see cref="Flush"/>: makes sure the response head has been
/// produced, then issues an empty immediate write to the socket output.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the underlying socket write.</param>
/// <returns>The task returned by the socket output for the empty write.</returns>
public Task FlushAsync(CancellationToken cancellationToken)
{
    ProduceStart(immediate: false);

    // Forward the caller's token instead of dropping it, matching WriteAsync.
    return SocketOutput.WriteAsync(_emptyData, immediate: true, cancellationToken: cancellationToken);
}
/// <summary>
/// Writes response body bytes, producing the response head first if necessary.
/// When automatic chunking is active the data is framed as one chunk.
/// </summary>
/// <param name="data">Body bytes to send.</param>
public void Write(ArraySegment<byte> data)
{
    ProduceStart(immediate: false);

    if (!_autoChunk)
    {
        SocketOutput.Write(data, immediate: true);
        return;
    }

    // A zero-length chunk would terminate the chunked body, so empty writes are skipped.
    if (data.Count != 0)
    {
        WriteChunked(data);
    }
}
/// <summary>
/// Asynchronously writes response body bytes, producing the response head first if necessary.
/// When automatic chunking is active the data is framed as one chunk.
/// </summary>
/// <param name="data">Body bytes to send.</param>
/// <param name="cancellationToken">Token forwarded to the socket output.</param>
/// <returns>A task representing the write.</returns>
public Task WriteAsync(ArraySegment<byte> data, CancellationToken cancellationToken)
{
    ProduceStart(immediate: false);

    if (!_autoChunk)
    {
        return SocketOutput.WriteAsync(data, immediate: true, cancellationToken: cancellationToken);
    }

    // A zero-length chunk would terminate the chunked body, so empty writes are skipped.
    if (data.Count == 0)
    {
        return TaskUtilities.CompletedTask;
    }

    return WriteChunkedAsync(data, cancellationToken);
}
/// <summary>
/// Sends <paramref name="data"/> as a single chunk: hex size line, payload, trailing CRLF.
/// Only the final write is immediate.
/// </summary>
/// <param name="data">Chunk payload.</param>
private void WriteChunked(ArraySegment<byte> data)
{
    var sizeLine = BeginChunkBytes(data.Count);

    SocketOutput.Write(sizeLine, immediate: false);
    SocketOutput.Write(data, immediate: false);
    SocketOutput.Write(_endChunkBytes, immediate: true);
}
/// <summary>
/// Asynchronously sends <paramref name="data"/> as a single chunk: hex size line, payload,
/// trailing CRLF. Each write is awaited before the next one starts; only the last is immediate.
/// </summary>
/// <param name="data">Chunk payload.</param>
/// <param name="cancellationToken">Token forwarded to every socket write.</param>
private async Task WriteChunkedAsync(ArraySegment<byte> data, CancellationToken cancellationToken)
{
    var sizeLine = BeginChunkBytes(data.Count);

    await SocketOutput.WriteAsync(sizeLine, immediate: false, cancellationToken: cancellationToken);
    await SocketOutput.WriteAsync(data, immediate: false, cancellationToken: cancellationToken);
    await SocketOutput.WriteAsync(_endChunkBytes, immediate: true, cancellationToken: cancellationToken);
}
/// <summary>
/// Builds the chunk size line: <paramref name="dataCount"/> as lower-case hexadecimal
/// without leading zeros (a single "0" for zero), followed by CRLF.
/// </summary>
/// <param name="dataCount">Number of payload bytes in the chunk.</param>
/// <returns>A segment over a freshly allocated 10-byte buffer covering the significant digits and the CRLF.</returns>
public static ArraySegment<byte> BeginChunkBytes(int dataCount)
{
    // Render all eight nibbles, most significant first, then the line terminator.
    var buffer = new byte[10];
    for (var position = 0; position < 8; position++)
    {
        var shiftBits = (7 - position) * 4;
        buffer[position] = _hex[(dataCount >> shiftBits) & 0x0f];
    }
    buffer[8] = (byte)'\r';
    buffer[9] = (byte)'\n';

    // Determine the most-significant non-zero nibble by narrowing 16, then 8, then 4 bits.
    var remaining = dataCount;
    var highBit = 0x00;
    if (remaining > 0xffff)
    {
        highBit = 0x10;
        remaining >>= 0x10;
    }
    if (remaining > 0x00ff)
    {
        highBit |= 0x08;
        remaining >>= 0x08;
    }
    if (remaining > 0x000f)
    {
        highBit |= 0x04;
    }

    // highBit / 4 is the index (from the right) of the first digit to keep.
    var offset = 7 - (highBit >> 2);
    return new ArraySegment<byte>(buffer, offset, 10 - offset);
}
/// <summary>
/// Writes the terminating zero-length chunk ("0\r\n\r\n") of a chunked response as an immediate write.
/// </summary>
private void WriteChunkedResponseSuffix()
{
    var terminator = _endChunkedResponseBytes;
    SocketOutput.Write(terminator, immediate: true);
}
/// <summary>
/// Disables keep-alive for the connection and produces the response head.
/// </summary>
/// <param name="options">Not used by this implementation.</param>
/// <param name="callback">Not used by this implementation; it is never invoked here.</param>
public void Upgrade(IDictionary<string, object> options, Func<object, Task> callback)
{
    // Must be cleared before the head is generated, because header generation reads it.
    _keepAlive = false;
    ProduceStart(immediate: true, appCompleted: false);
}
/// <summary>
/// Encodes <paramref name="text"/> as ASCII and wraps the whole resulting array in a segment.
/// </summary>
/// <param name="text">Text to encode.</param>
/// <returns>A segment covering every encoded byte.</returns>
private static ArraySegment<byte> CreateAsciiByteArraySegment(string text)
{
    var encoded = Encoding.ASCII.GetBytes(text);
    return new ArraySegment<byte>(encoded, 0, encoded.Length);
}
/// <summary>
/// Sends an interim "100 Continue" response when the response has not started yet, the
/// request is HTTP/1.1, and its first <c>Expect</c> header value is "100-continue"
/// (compared case-insensitively).
/// </summary>
public void ProduceContinue()
{
    if (_responseStarted)
    {
        return;
    }

    StringValues expect;

    // Static string.Equals is null-safe: HttpVersion is null between requests (see Reset),
    // and the first Expect value may be null when the header is empty.
    if (string.Equals(HttpVersion, "HTTP/1.1", StringComparison.Ordinal) &&
        RequestHeaders.TryGetValue("Expect", out expect) &&
        string.Equals(expect.FirstOrDefault(), "100-continue", StringComparison.OrdinalIgnoreCase))
    {
        SocketOutput.Write(_continueBytes);
    }
}
/// <summary>
/// Produces the response head (status line and headers) once per request. Later calls
/// are no-ops after the head has been produced.
/// </summary>
/// <param name="immediate">Passed through to the socket write of the head.</param>
/// <param name="appCompleted">
/// True when the application has already finished; header generation uses it to decide
/// between "Content-Length: 0" and chunked encoding.
/// </param>
public void ProduceStart(bool immediate = true, bool appCompleted = false)
{
    // ProduceStart shouldn't no-op in the future just b/c FireOnStarting throws.
    if (_responseStarted)
    {
        return;
    }

    FireOnStarting();
    _responseStarted = true;

    var status = ReasonPhrases.ToStatus(StatusCode, ReasonPhrase);
    var responseHeader = CreateResponseHeader(status, appCompleted, ResponseHeaders);
    try
    {
        SocketOutput.Write(responseHeader.Item1, immediate: immediate);
    }
    finally
    {
        // Release the header writer even when the socket write throws.
        responseHeader.Item2.Dispose();
    }
}
/// <summary>
/// Finishes the response for the current request and tells the connection what to do next.
/// </summary>
/// <param name="ex">
/// Exception thrown while the application ran, or null on success. With an exception and
/// no response started yet, an empty 500 response is produced; with a response already
/// started, the connection is simply disconnected.
/// </param>
public void ProduceEnd(Exception ex)
{
    var failed = ex != null;

    if (failed && _responseStarted)
    {
        // We can no longer respond with a 500, so we simply close the connection.
        ConnectionControl.End(ProduceEndType.SocketDisconnect);
        return;
    }

    if (failed)
    {
        StatusCode = 500;
        ReasonPhrase = null;

        // If OnStarting hasn't been triggered yet, we don't want to trigger it now that
        // the app func has failed. https://github.com/aspnet/KestrelHttpServer/issues/43
        _onStarting = null;

        ResetResponseHeaders();
        _responseHeaders.HeaderContentLength = "0";
    }

    ProduceStart(immediate: true, appCompleted: true);

    // _autoChunk should be checked after we are sure ProduceStart() has been called
    // since ProduceStart() may set _autoChunk to true.
    if (_autoChunk)
    {
        WriteChunkedResponseSuffix();
    }

    if (!_keepAlive)
    {
        ConnectionControl.End(ProduceEndType.SocketShutdownSend);
    }

    //NOTE: must finish reading request body
    var finalState = _keepAlive
        ? ProduceEndType.ConnectionKeepAlive
        : ProduceEndType.SocketDisconnect;
    ConnectionControl.End(finalState);
}
/// <summary>
/// Serializes the status line and all response headers, and decides the framing of the
/// response (Content-Length: 0, chunked, or connection close). May clear
/// <c>_keepAlive</c> and may set <c>_autoChunk</c> as a side effect.
/// </summary>
/// <param name="status">Status text written after the HTTP version on the status line.</param>
/// <param name="appCompleted">True when the application has already finished running.</param>
/// <param name="headers">Headers to serialize; may be null.</param>
/// <returns>
/// The serialized head plus the writer that owns its memory; the caller must dispose
/// the second item after the bytes have been written.
/// </returns>
private Tuple<ArraySegment<byte>, IDisposable> CreateResponseHeader(
    string status,
    bool appCompleted,
    IEnumerable<KeyValuePair<string, StringValues>> headers)
{
    var writer = new MemoryPoolTextWriter(Memory);
    try
    {
        // Status line.
        writer.Write(HttpVersion);
        writer.Write(' ');
        writer.Write(status);
        writer.Write('\r');
        writer.Write('\n');

        var hasConnection = false;
        var hasTransferEncoding = false;
        var hasContentLength = false;

        if (headers != null)
        {
            foreach (var header in headers)
            {
                var isConnection = false;
                if (!hasConnection &&
                    string.Equals(header.Key, "Connection", StringComparison.OrdinalIgnoreCase))
                {
                    hasConnection = isConnection = true;
                }
                else if (!hasTransferEncoding &&
                    string.Equals(header.Key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase))
                {
                    hasTransferEncoding = true;
                }
                else if (!hasContentLength &&
                    string.Equals(header.Key, "Content-Length", StringComparison.OrdinalIgnoreCase))
                {
                    hasContentLength = true;
                }

                // Each value is emitted as its own "Name: value" line.
                foreach (var value in header.Value)
                {
                    writer.Write(header.Key);
                    writer.Write(':');
                    writer.Write(' ');
                    writer.Write(value);
                    writer.Write('\r');
                    writer.Write('\n');

                    // An application-supplied "Connection: close" turns keep-alive off.
                    if (isConnection && value.IndexOf("close", StringComparison.OrdinalIgnoreCase) != -1)
                    {
                        _keepAlive = false;
                    }
                }
            }
        }

        // The application chose no framing, so pick one that keeps the connection reusable.
        if (_keepAlive && !hasTransferEncoding && !hasContentLength)
        {
            if (appCompleted)
            {
                // Don't set the Content-Length or Transfer-Encoding headers
                // automatically for HEAD requests or 101, 204, 205, 304 responses.
                if (Method != "HEAD" && StatusCanHaveBody(StatusCode))
                {
                    // Since the app has completed and we are only now generating
                    // the headers we can safely set the Content-Length to 0.
                    writer.Write("Content-Length: 0\r\n");
                }
            }
            else
            {
                if (HttpVersion == "HTTP/1.1")
                {
                    _autoChunk = true;
                    writer.Write("Transfer-Encoding: chunked\r\n");
                }
                else
                {
                    // No chunking before HTTP/1.1: the body is delimited by closing the connection.
                    _keepAlive = false;
                }
            }
        }

        // Add a Connection header if the application supplied none, then end the head.
        if (!_keepAlive && !hasConnection && HttpVersion == "HTTP/1.1")
        {
            writer.Write("Connection: close\r\n\r\n");
        }
        else if (_keepAlive && !hasConnection && HttpVersion == "HTTP/1.0")
        {
            writer.Write("Connection: keep-alive\r\n\r\n");
        }
        else
        {
            writer.Write('\r');
            writer.Write('\n');
        }

        writer.Flush();
        return new Tuple<ArraySegment<byte>, IDisposable>(writer.Buffer, writer);
    }
    catch
    {
        // The caller never receives the writer on failure, so release it here.
        ((IDisposable)writer).Dispose();
        throw;
    }
}
/// <summary>
/// Tries to parse the request line ("METHOD target[?query] VERSION\r\n") from the input.
/// On success assigns <c>Method</c>, <c>RequestUri</c>, <c>Path</c>, <c>QueryString</c>
/// and <c>HttpVersion</c>, and advances the input past the line.
/// </summary>
/// <param name="input">Socket input to read from.</param>
/// <returns>
/// True when a complete request line was consumed. False when no data is available or
/// the '\r' is not followed by '\n'; in that case nothing is assigned or consumed.
/// </returns>
private bool TakeStartLine(SocketInput input)
{
    var begin = input.GetIterator();
    if (begin.IsDefault)
    {
        return false;
    }

    // NOTE(review): the results of IndexOf/IndexOfAny are not checked for "not found";
    // presumably the final Peek() rejects an incomplete line. Confirm against the iterator.
    var end = begin.IndexOf(' ');
    var method = begin.GetString(end);

    // Request target: runs up to the next space or the start of the query.
    char chFound;
    begin = end.Add(1);
    end = begin.IndexOfAny(' ', '?', out chFound);
    var requestUri = begin.GetString(end);

    // Query string: starts at the '?' when one was found, otherwise empty.
    begin = end;
    end = chFound == '?' ? begin.IndexOf(' ') : begin;
    var queryString = begin.GetString(end);

    begin = end.Add(1);
    end = begin.IndexOf('\r');
    var httpVersion = begin.GetString(end);

    end = end.Add(1);
    if (end.Peek() != '\n')
    {
        return false;
    }

    Method = method;
    RequestUri = requestUri;
    // Reset() clears Path and nothing else in this class assigns it; the target parsed
    // above already excludes the query, so it is used as the path.
    Path = requestUri;
    QueryString = queryString;
    HttpVersion = httpVersion;

    input.JumpTo(end.Add(1));
    return true;
}
/// <summary>
/// Decodes the bytes of <paramref name="range"/> between two segment-relative indexes as UTF-8.
/// </summary>
/// <param name="range">Segment containing the bytes.</param>
/// <param name="startIndex">Inclusive start, relative to the segment offset.</param>
/// <param name="endIndex">Exclusive end, relative to the segment offset.</param>
/// <returns>The decoded text.</returns>
static string GetString(ArraySegment<byte> range, int startIndex, int endIndex)
{
    var absoluteStart = range.Offset + startIndex;
    var byteCount = endIndex - startIndex;
    return Encoding.UTF8.GetString(range.Array, absoluteStart, byteCount);
}
// Scans request header lines from the input iterator and appends every complete header
// to _requestHeaders. Returns true once the blank line ending the header block has been
// found (the input is then advanced past it); returns false otherwise, in which case
// ProcessFraming waits for more input and calls this method again.
//
// NOTE(review): baton.JumpTo is only called when the end of the headers is reached, so
// headers appended during a pass that ends with "false" look like they would be appended
// again on the retry. Confirm against SocketInput.GetIterator / FrameRequestHeaders.Append.
private bool TakeMessageHeader2(SocketInput baton)
{
char chFirst;
char chSecond;
var scan = baton.GetIterator();
while (!scan.IsDefault)
{
// Header name runs from the start of the line to the first ':'; a '\r' found first
// means the line has no name.
var beginName = scan;
scan = scan.IndexOfAny(':', '\r', out chFirst);
var endName = scan;
chSecond = scan.MoveNext();
// "\r\n" where a name was expected is the blank line that terminates the headers.
if (chFirst == '\r' && chSecond == '\n')
{
baton.JumpTo(scan.Add(1));
return true;
}
// Presumably IndexOfAny reports "neither character found" as char.MinValue, i.e. the
// line is not complete yet - TODO confirm.
if (chFirst == char.MinValue)
{
return false;
}
// Skip whitespace between the ':' and the value.
// NOTE(review): '\r' and '\n' are skipped too, so a header with an empty value would
// appear to run on into the following line - confirm.
while (
chSecond == ' ' ||
chSecond == '\t' ||
chSecond == '\r' ||
chSecond == '\n')
{
chSecond = scan.MoveNext();
}
var beginValue = scan;
var wrapping = false;
while (!scan.IsDefault)
{
// Look for the "\r\n" that ends the value; a lone '\r' is treated as part of the value.
var endValue = scan = scan.IndexOf('\r');
chFirst = scan.MoveNext();
if (chFirst != '\n')
{
continue;
}
// A line break followed by space or tab is a folded (wrapped) header: keep scanning.
chSecond = scan.MoveNext();
if (chSecond == ' ' || chSecond == '\t')
{
wrapping = true;
continue;
}
var name = beginName.GetArraySegment(endName);
var value = beginValue.GetString(endValue);
if (wrapping)
{
// Folded values are joined by replacing each embedded line break with one space.
value = value.Replace("\r\n", " ");
}
_requestHeaders.Append(name.Array, name.Offset, name.Count, value);
// Header complete; go back to the outer loop for the next line.
break;
}
}
// Ran out of data before the end of the header block.
return false;
}
// Buffer-based header parser: tries to consume exactly one header line (or the blank
// line that ends the headers) from baton.Buffer.
//   - returns true with endOfHeaders = true when the buffer starts with "\r\n";
//   - returns true with endOfHeaders = false after adding one header and skipping it;
//   - returns false when the buffer does not yet hold a complete line.
// NOTE(review): nothing in this class as shown calls this method (ProcessFraming uses
// TakeMessageHeader2) - it looks like dead code; confirm before relying on it.
private bool TakeMessageHeader(SocketInput baton, out bool endOfHeaders)
{
var remaining = baton.Buffer;
endOfHeaders = false;
if (remaining.Count < 2)
{
return false;
}
// ch0/ch1/ch2 form a sliding three-byte window over the buffer.
var ch0 = remaining.Array[remaining.Offset];
var ch1 = remaining.Array[remaining.Offset + 1];
if (ch0 == '\r' && ch1 == '\n')
{
endOfHeaders = true;
baton.Skip(2);
return true;
}
if (remaining.Count < 3)
{
return false;
}
var wrappedHeaders = false;
// Indexes are relative to remaining.Offset; -1 means "not seen yet".
var colonIndex = -1;
var valueStartIndex = -1;
var valueEndIndex = -1;
for (var index = 0; index != remaining.Count - 2; ++index)
{
var ch2 = remaining.Array[remaining.Offset + index + 2];
// "\r\n" NOT followed by space/tab ends the header line. The third byte is needed to
// rule out a folded continuation, so a line is only recognised once one more byte
// has arrived.
if (ch0 == '\r' &&
ch1 == '\n' &&
ch2 != ' ' &&
ch2 != '\t')
{
var value = "";
if (valueEndIndex != -1)
{
value = _ascii.GetString(
remaining.Array, remaining.Offset + valueStartIndex, valueEndIndex - valueStartIndex);
}
if (wrappedHeaders)
{
value = value.Replace("\r\n", " ");
}
// NOTE(review): colonIndex is still -1 here when the line contained no ':', which
// passes a negative key length on - confirm how Append handles that.
AddRequestHeader(remaining.Array, remaining.Offset, colonIndex, value);
baton.Skip(index + 2);
return true;
}
// The first ':' ends the header name.
if (colonIndex == -1 && ch0 == ':')
{
colonIndex = index;
}
// After the colon, track the first and one-past-the-last non-whitespace byte so the
// value is trimmed on both ends.
else if (colonIndex != -1 &&
ch0 != ' ' &&
ch0 != '\t' &&
ch0 != '\r' &&
ch0 != '\n')
{
if (valueStartIndex == -1)
{
valueStartIndex = index;
}
valueEndIndex = index + 1;
}
// "\r\n" followed by space/tab: the header continues on the next line.
else if (!wrappedHeaders &&
ch0 == '\r' &&
ch1 == '\n' &&
(ch2 == ' ' ||
ch2 == '\t'))
{
wrappedHeaders = true;
}
// Slide the window forward by one byte.
ch0 = ch1;
ch1 = ch2;
}
return false;
}
/// <summary>
/// Appends one header to the request header collection.
/// </summary>
/// <param name="keyBytes">Buffer containing the header name bytes.</param>
/// <param name="keyOffset">Offset of the first name byte in <paramref name="keyBytes"/>.</param>
/// <param name="keyLength">Number of name bytes.</param>
/// <param name="value">Header value text.</param>
private void AddRequestHeader(byte[] keyBytes, int keyOffset, int keyLength, string value)
{
    // Straight pass-through to the header collection.
    _requestHeaders.Append(
        keyBytes,
        keyOffset,
        keyLength,
        value);
}
/// <summary>
/// Tells whether a response with the given status code may carry a message body.
/// </summary>
/// <param name="statusCode">HTTP status code.</param>
/// <returns>False for 101, 204, 205 and 304; true for every other value.</returns>
public bool StatusCanHaveBody(int statusCode)
{
    // List of status codes taken from Microsoft.Net.Http.Server.Response
    switch (statusCode)
    {
        case 101:
        case 204:
        case 205:
        case 304:
            return false;
        default:
            return true;
    }
}
}
}