commit 251630590d1e666da57443bf12fdcbba6c8e653f Author: Chris Ross Date: Fri Feb 7 17:01:08 2014 -0800 Initial port. diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..bdaa5ba982 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,50 @@ +*.doc diff=astextplain +*.DOC diff=astextplain +*.docx diff=astextplain +*.DOCX diff=astextplain +*.dot diff=astextplain +*.DOT diff=astextplain +*.pdf diff=astextplain +*.PDF diff=astextplain +*.rtf diff=astextplain +*.RTF diff=astextplain + +*.jpg binary +*.png binary +*.gif binary + +*.cs text=auto diff=csharp +*.vb text=auto +*.resx text=auto +*.c text=auto +*.cpp text=auto +*.cxx text=auto +*.h text=auto +*.hxx text=auto +*.py text=auto +*.rb text=auto +*.java text=auto +*.html text=auto +*.htm text=auto +*.css text=auto +*.scss text=auto +*.sass text=auto +*.less text=auto +*.js text=auto +*.lisp text=auto +*.clj text=auto +*.sql text=auto +*.php text=auto +*.lua text=auto +*.m text=auto +*.asm text=auto +*.erl text=auto +*.fs text=auto +*.fsx text=auto +*.hs text=auto + +*.csproj text=auto +*.vbproj text=auto +*.fsproj text=auto +*.dbproj text=auto +*.sln text=auto eol=crlf diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..2554a1fc23 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +[Oo]bj/ +[Bb]in/ +TestResults/ +.nuget/ +_ReSharper.*/ +packages/ +artifacts/ +PublishProfiles/ +*.user +*.suo +*.cache +*.docstates +_ReSharper.* +nuget.exe +*net45.csproj +*k10.csproj +*.psess +*.vsp +*.pidb +*.userprefs +*DS_Store +*.ncrunchsolution diff --git a/KatanaInternal.sln b/KatanaInternal.sln new file mode 100644 index 0000000000..9f18ac98c5 --- /dev/null +++ b/KatanaInternal.sln @@ -0,0 +1,94 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 2013 +VisualStudioVersion = 12.0.30110.0 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestClient", "samples\TestClient\TestClient.csproj", "{8B828433-B333-4C19-96AE-00BFFF9D8841}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Server.WebListener.k10", "src\Microsoft.AspNet.Server.WebListener\Microsoft.AspNet.Server.WebListener.k10.csproj", "{6D9D3023-3ED7-4C95-80F0-347843ABD759}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Server.WebListener.net45", "src\Microsoft.AspNet.Server.WebListener\Microsoft.AspNet.Server.WebListener.net45.csproj", "{253B9134-B6EB-4E59-8725-D983FD941A21}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.WebSockets.net45", "src\Microsoft.AspNet.WebSockets\Microsoft.AspNet.WebSockets.net45.csproj", "{00C6A882-1FE2-4769-901C-023D8DC175C4}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{99D5E5F3-88F5-4CCF-8D8C-717C8925DF09}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{E183C826-1360-4DFF-9994-F33CED5C8525}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "samples", "samples", "{3A1E31E3-2794-4CA3-B8E2-253E96BDE514}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloWorld.net45", "samples\HelloWorld\HelloWorld.net45.csproj", "{BF335732-BB09-49A1-8676-F074047E7DB2}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SelfHostServer.net45", "samples\SelfHostServer\SelfHostServer.net45.csproj", "{96C67B2F-9913-4E8D-B2E8-969BE66B71B6}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Server.WebListener.Test.net45", "test\Microsoft.AspNet.Server.WebListener.Test\Microsoft.AspNet.Server.WebListener.Test.net45.csproj", "{485DAC59-A1F1-4D47-98BF-B448C994E05B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloWorld.k10", "samples\HelloWorld\HelloWorld.k10.csproj", "{A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Security.Windows.net45", "src\Microsoft.AspNet.Security.Windows\Microsoft.AspNet.Security.Windows.net45.csproj", "{8B4EF749-251D-4222-AD18-DE5A1E7D321A}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Security.Windows.Test.net45", "test\Microsoft.AspNet.Security.Windows.Test\Microsoft.AspNet.Security.Windows.Test.net45.csproj", "{3EC418D5-C8FD-47AA-BFED-F524358EC3DD}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {8B828433-B333-4C19-96AE-00BFFF9D8841}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8B828433-B333-4C19-96AE-00BFFF9D8841}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8B828433-B333-4C19-96AE-00BFFF9D8841}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8B828433-B333-4C19-96AE-00BFFF9D8841}.Release|Any CPU.Build.0 = Release|Any CPU + {6D9D3023-3ED7-4C95-80F0-347843ABD759}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6D9D3023-3ED7-4C95-80F0-347843ABD759}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6D9D3023-3ED7-4C95-80F0-347843ABD759}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6D9D3023-3ED7-4C95-80F0-347843ABD759}.Release|Any CPU.Build.0 = Release|Any CPU + {253B9134-B6EB-4E59-8725-D983FD941A21}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {253B9134-B6EB-4E59-8725-D983FD941A21}.Debug|Any CPU.Build.0 = Debug|Any CPU + {253B9134-B6EB-4E59-8725-D983FD941A21}.Release|Any CPU.ActiveCfg = Release|Any CPU + {253B9134-B6EB-4E59-8725-D983FD941A21}.Release|Any CPU.Build.0 = Release|Any CPU + {00C6A882-1FE2-4769-901C-023D8DC175C4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {00C6A882-1FE2-4769-901C-023D8DC175C4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {00C6A882-1FE2-4769-901C-023D8DC175C4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {00C6A882-1FE2-4769-901C-023D8DC175C4}.Release|Any CPU.Build.0 = Release|Any CPU + {BF335732-BB09-49A1-8676-F074047E7DB2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BF335732-BB09-49A1-8676-F074047E7DB2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BF335732-BB09-49A1-8676-F074047E7DB2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BF335732-BB09-49A1-8676-F074047E7DB2}.Release|Any CPU.Build.0 = Release|Any CPU + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6}.Release|Any CPU.Build.0 = Release|Any CPU + {485DAC59-A1F1-4D47-98BF-B448C994E05B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {485DAC59-A1F1-4D47-98BF-B448C994E05B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {485DAC59-A1F1-4D47-98BF-B448C994E05B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {485DAC59-A1F1-4D47-98BF-B448C994E05B}.Release|Any CPU.Build.0 = Release|Any CPU + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}.Release|Any CPU.Build.0 = Release|Any CPU + {8B4EF749-251D-4222-AD18-DE5A1E7D321A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8B4EF749-251D-4222-AD18-DE5A1E7D321A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8B4EF749-251D-4222-AD18-DE5A1E7D321A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8B4EF749-251D-4222-AD18-DE5A1E7D321A}.Release|Any CPU.Build.0 = Release|Any CPU + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {485DAC59-A1F1-4D47-98BF-B448C994E05B} = {E183C826-1360-4DFF-9994-F33CED5C8525} + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD} = {E183C826-1360-4DFF-9994-F33CED5C8525} + {8B828433-B333-4C19-96AE-00BFFF9D8841} = {3A1E31E3-2794-4CA3-B8E2-253E96BDE514} + {BF335732-BB09-49A1-8676-F074047E7DB2} = {3A1E31E3-2794-4CA3-B8E2-253E96BDE514} + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6} = {3A1E31E3-2794-4CA3-B8E2-253E96BDE514} + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B} = {3A1E31E3-2794-4CA3-B8E2-253E96BDE514} + {6D9D3023-3ED7-4C95-80F0-347843ABD759} = {99D5E5F3-88F5-4CCF-8D8C-717C8925DF09} + {253B9134-B6EB-4E59-8725-D983FD941A21} = {99D5E5F3-88F5-4CCF-8D8C-717C8925DF09} + {00C6A882-1FE2-4769-901C-023D8DC175C4} = {99D5E5F3-88F5-4CCF-8D8C-717C8925DF09} + {8B4EF749-251D-4222-AD18-DE5A1E7D321A} = {99D5E5F3-88F5-4CCF-8D8C-717C8925DF09} + EndGlobalSection +EndGlobal diff --git a/NuGet.Config b/NuGet.Config new file mode 100644 index 0000000000..ab583b0ff7 --- /dev/null +++ b/NuGet.Config @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/build.cmd b/build.cmd new file mode 100644 index 0000000000..7045ee1f84 --- /dev/null +++ b/build.cmd @@ -0,0 +1,23 @@ +@echo off +cd %~dp0 + +SETLOCAL +SET CACHED_NUGET=%LocalAppData%\NuGet\NuGet.exe + +IF EXIST %CACHED_NUGET% goto copynuget +echo Downloading latest version of NuGet.exe... +IF NOT EXIST %LocalAppData%\NuGet md %LocalAppData%\NuGet +@powershell -NoProfile -ExecutionPolicy unrestricted -Command "$ProgressPreference = 'SilentlyContinue'; Invoke-WebRequest 'https://www.nuget.org/nuget.exe' -OutFile '%CACHED_NUGET%'" + +:copynuget +IF EXIST .nuget\nuget.exe goto restore +md .nuget +copy %CACHED_NUGET% .nuget\nuget.exe > nul + +:restore +IF EXIST packages\KoreBuild goto run +.nuget\NuGet.exe install KoreBuild -ExcludeVersion -o packages -nocache -pre +.nuget\NuGet.exe install Sake -version 0.2 -o packages -ExcludeVersion + +:run +packages\Sake\tools\Sake.exe -I packages\KoreBuild\build -f makefile.shade %* diff --git a/global.json b/global.json new file mode 100644 index 0000000000..840c36f6ad --- /dev/null +++ b/global.json @@ -0,0 +1,3 @@ +{ + "sources": ["src"] +} \ No newline at end of file diff --git a/makefile.shade b/makefile.shade new file mode 100644 index 0000000000..6357ea2841 --- /dev/null +++ b/makefile.shade @@ -0,0 +1,7 @@ + +var VERSION='0.1' +var FULL_VERSION='0.1' +var AUTHORS='Microsoft' + +use-standard-lifecycle +k-standard-goals diff --git a/samples/HelloWorld/Program.cs b/samples/HelloWorld/Program.cs new file mode 100644 index 0000000000..5e176ff7ee --- /dev/null +++ b/samples/HelloWorld/Program.cs @@ -0,0 +1,55 @@ + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNet.Server.WebListener; + +using AppFunc = System.Func, System.Threading.Tasks.Task>; + +public class Program +{ + public static void Main(string[] args) + { + using (CreateServer(new AppFunc(HelloWorldApp))) + { + Console.WriteLine("Running, press enter to exit..."); + Console.ReadLine(); + } + } + + private static IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + public static Task HelloWorldApp(IDictionary environment) + { + string responseText = "Hello World"; + byte[] responseBytes = Encoding.UTF8.GetBytes(responseText); + + // See http://owin.org/spec/owin-1.0.0.html for standard environment keys. + Stream responseStream = (Stream)environment["owin.ResponseBody"]; + IDictionary responseHeaders = + (IDictionary)environment["owin.ResponseHeaders"]; + + responseHeaders["Content-Length"] = new string[] { responseBytes.Length.ToString(CultureInfo.InvariantCulture) }; + responseHeaders["Content-Type"] = new string[] { "text/plain" }; + + return responseStream.WriteAsync(responseBytes, 0, responseBytes.Length); + } +} diff --git a/samples/HelloWorld/project.json b/samples/HelloWorld/project.json new file mode 100644 index 0000000000..63b0cd1595 --- /dev/null +++ b/samples/HelloWorld/project.json @@ -0,0 +1,10 @@ +{ + "version" : "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Server.WebListener" : "" + }, + "configurations": { + "net45": { }, + "k10" : { } + } +} diff --git a/samples/SelfHostServer/App.config b/samples/SelfHostServer/App.config new file mode 100644 index 0000000000..73c9881fb6 --- /dev/null +++ b/samples/SelfHostServer/App.config @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/samples/SelfHostServer/Program.cs b/samples/SelfHostServer/Program.cs new file mode 100644 index 0000000000..f01efdd682 --- /dev/null +++ b/samples/SelfHostServer/Program.cs @@ -0,0 +1,138 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Owin; +using Microsoft.AspNet.Server.WebListener; +using Microsoft.Owin.Hosting; +using Owin; + +namespace SelfHostServer +{ + // http://owin.org/extensions/owin-WebSocket-Extension-v0.4.0.htm + using WebSocketAccept = Action, // options + Func, Task>>; // callback + using WebSocketCloseAsync = + Func; + using WebSocketReceiveAsync = + Func /* data */, + CancellationToken /* cancel */, + Task>>; + using WebSocketReceiveResult = Tuple; // count + using WebSocketSendAsync = + Func /* data */, + int /* messageType */, + bool /* endOfMessage */, + CancellationToken /* cancel */, + Task>; + + public class Program + { + private static byte[] Data = new byte[1024]; + + public static void Main(string[] args) + { + using (WebApp.Start(new StartOptions( + // "http://localhost:5000/" + "https://localhost:9090/" + ) + { + ServerFactory = "Microsoft.AspNet.Server.WebListener" + })) + { + Console.WriteLine("Running, press any key to exit"); + // System.Diagnostics.Process.Start("http://localhost:5000/"); + Console.ReadKey(); + } + } + + public void Configuration(IAppBuilder app) + { + OwinWebListener listener = (OwinWebListener)app.Properties["Microsoft.AspNet.Server.WebListener.OwinWebListener"]; + listener.AuthenticationManager.AuthenticationTypes = + AuthenticationType.Basic | + AuthenticationType.Digest | + AuthenticationType.Negotiate | + AuthenticationType.Ntlm | + AuthenticationType.Kerberos; + + app.Use((context, next) => + { + Console.WriteLine("Request: " + context.Request.Uri); + return next(); + }); + app.Use((context, next) => + { + if (context.Request.User == null) + { + context.Response.StatusCode = 401; + return Task.FromResult(0); + } + else + { + Console.WriteLine(context.Request.User.Identity.AuthenticationType); + } + return next(); + }); + app.UseWebSockets(); + app.Use(UpgradeToWebSockets); + app.Run(Invoke); + } + + public Task Invoke(IOwinContext context) + { + context.Response.ContentLength = Data.Length; + return context.Response.WriteAsync(Data); + } + + // Run once per request + private Task UpgradeToWebSockets(IOwinContext context, Func next) + { + WebSocketAccept accept = context.Get("websocket.Accept"); + if (accept == null) + { + // Not a websocket request + return next(); + } + + accept(null, WebSocketEcho); + + return Task.FromResult(null); + } + + private async Task WebSocketEcho(IDictionary websocketContext) + { + var sendAsync = (WebSocketSendAsync)websocketContext["websocket.SendAsync"]; + var receiveAsync = (WebSocketReceiveAsync)websocketContext["websocket.ReceiveAsync"]; + var closeAsync = (WebSocketCloseAsync)websocketContext["websocket.CloseAsync"]; + var callCancelled = (CancellationToken)websocketContext["websocket.CallCancelled"]; + + byte[] buffer = new byte[1024]; + WebSocketReceiveResult received = await receiveAsync(new ArraySegment(buffer), callCancelled); + + object status; + while (!websocketContext.TryGetValue("websocket.ClientCloseStatus", out status) || (int)status == 0) + { + // Echo anything we receive + await sendAsync(new ArraySegment(buffer, 0, received.Item3), received.Item1, received.Item2, callCancelled); + + received = await receiveAsync(new ArraySegment(buffer), callCancelled); + } + + await closeAsync((int)websocketContext["websocket.ClientCloseStatus"], (string)websocketContext["websocket.ClientCloseDescription"], callCancelled); + } + } +} diff --git a/samples/SelfHostServer/Properties/AssemblyInfo.cs b/samples/SelfHostServer/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..86b415a1b3 --- /dev/null +++ b/samples/SelfHostServer/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("SelfHostServer")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("SelfHostServer")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("ec50ddb4-9ec6-4cbd-96ac-15de948247cc")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/samples/SelfHostServer/Public/1kb.txt b/samples/SelfHostServer/Public/1kb.txt new file mode 100644 index 0000000000..1d43866603 --- /dev/null +++ b/samples/SelfHostServer/Public/1kb.txt @@ -0,0 +1 @@ +asdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfasdfasdfasdfasdfasd \ No newline at end of file diff --git a/samples/SelfHostServer/packages.config b/samples/SelfHostServer/packages.config new file mode 100644 index 0000000000..b452be1749 --- /dev/null +++ b/samples/SelfHostServer/packages.config @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/samples/SelfHostServer/project.json b/samples/SelfHostServer/project.json new file mode 100644 index 0000000000..ed203d2aeb --- /dev/null +++ b/samples/SelfHostServer/project.json @@ -0,0 +1,19 @@ +{ + "version" : "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Server.WebListener" : "", + "Microsoft.AspNet.WebSockets" : "" + }, + "configurations": { + "net45": { + "dependencies": { + "Owin": "1.0", + "Microsoft.Owin": "2.1.0", + "Microsoft.Owin.Diagnostics": "2.1.0", + "Microsoft.Owin.Hosting": "2.1.0", + "Microsoft.Owin.Host.HttpListener": "2.1.0", + "Microsoft.AspNet.AppBuilderSupport": "0.1-alpha-*" + } + } + } +} diff --git a/samples/TestClient/App.config b/samples/TestClient/App.config new file mode 100644 index 0000000000..8e15646352 --- /dev/null +++ b/samples/TestClient/App.config @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/samples/TestClient/Program.cs b/samples/TestClient/Program.cs new file mode 100644 index 0000000000..b1f93759aa --- /dev/null +++ b/samples/TestClient/Program.cs @@ -0,0 +1,78 @@ +using System; +using System.Net; +using System.Net.Http; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace TestClient +{ + public class Program + { + private const string Address = + // "http://localhost:5000/public/1kb.txt"; + "https://localhost:9090/public/1kb.txt"; + + public static void Main(string[] args) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.ServerCertificateValidationCallback = (_, __, ___, ____) => true; + // handler.UseDefaultCredentials = true; + handler.Credentials = new NetworkCredential(@"redmond\chrross", "passwird"); + HttpClient client = new HttpClient(handler); + + /* + int completionCount = 0; + int itterations = 30000; + for (int i = 0; i < itterations; i++) + { + client.GetAsync(Address) + .ContinueWith(t => Interlocked.Increment(ref completionCount)); + } + + while (completionCount < itterations) + { + Thread.Sleep(10); + }*/ + + while (true) + { + Console.WriteLine("Press any key to send request"); + Console.ReadKey(); + var result = client.GetAsync(Address).Result; + Console.WriteLine(result); + } + + // RunWebSocketClient().Wait(); + Console.WriteLine("Done"); + Console.ReadKey(); + } + + public static async Task RunWebSocketClient() + { + ClientWebSocket websocket = new ClientWebSocket(); + + string url = "ws://localhost:5000/"; + Console.WriteLine("Connecting to: " + url); + await websocket.ConnectAsync(new Uri(url), CancellationToken.None); + + string message = "Hello World"; + Console.WriteLine("Sending message: " + message); + byte[] messageBytes = Encoding.UTF8.GetBytes(message); + await websocket.SendAsync(new ArraySegment(messageBytes), WebSocketMessageType.Text, true, CancellationToken.None); + + byte[] incomingData = new byte[1024]; + WebSocketReceiveResult result = await websocket.ReceiveAsync(new ArraySegment(incomingData), CancellationToken.None); + + if (result.CloseStatus.HasValue) + { + Console.WriteLine("Closed; Status: " + result.CloseStatus + ", " + result.CloseStatusDescription); + } + else + { + Console.WriteLine("Received message: " + Encoding.UTF8.GetString(incomingData, 0, result.Count)); + } + } + } +} diff --git a/samples/TestClient/Properties/AssemblyInfo.cs b/samples/TestClient/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..6919b6d3ce --- /dev/null +++ b/samples/TestClient/Properties/AssemblyInfo.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("TestClient")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("TestClient")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("8db62eb3-48c0-4049-b33e-271c738140a0")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/samples/TestClient/TestClient.csproj b/samples/TestClient/TestClient.csproj new file mode 100644 index 0000000000..eb38a21555 --- /dev/null +++ b/samples/TestClient/TestClient.csproj @@ -0,0 +1,62 @@ + + + + + Debug + AnyCPU + {8B828433-B333-4C19-96AE-00BFFF9D8841} + Exe + Properties + TestClient + TestClient + v4.5 + 512 + ..\..\ + true + + + AnyCPU + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + + + AnyCPU + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Microsoft.AspNet.Security.Windows/AuthTypes.cs b/src/Microsoft.AspNet.Security.Windows/AuthTypes.cs new file mode 100644 index 0000000000..90a063b5fa --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/AuthTypes.cs @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Security.Windows +{ + /// + /// Types of Windows Authentication supported. + /// + [Flags] + public enum AuthTypes + { + /// + /// Default + /// + None = 0, + + /// + /// Digest authentication using Windows credentials + /// + Digest = 1, + + /// + /// Negotiates Kerberos or NTLM + /// + Negotiate = 2, + + /// + /// NTLM Windows authentication + /// + Ntlm = 4, + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/ComNetOS.cs b/src/Microsoft.AspNet.Security.Windows/ComNetOS.cs new file mode 100644 index 0000000000..ca5902dbd6 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/ComNetOS.cs @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Diagnostics; +using System.Runtime.Versioning; +using System.Security.Permissions; + +namespace Microsoft.AspNet.Security.Windows +{ + internal static class ComNetOS + { + // Minimum support for Windows 2008 is assumed. + internal static readonly bool IsWin7orLater; // Is Windows 7 or later + internal static readonly bool IsWin8orLater; // Is Windows 8 or later + + // We use it safe so assert + [EnvironmentPermission(SecurityAction.Assert, Unrestricted = true)] + [ResourceExposure(ResourceScope.None)] + [ResourceConsumption(ResourceScope.AppDomain, ResourceScope.AppDomain)] + static ComNetOS() + { + OperatingSystem operatingSystem = Environment.OSVersion; + + GlobalLog.Print("ComNetOS::.ctor(): " + operatingSystem.ToString()); + + Debug.Assert(operatingSystem.Platform != PlatformID.Win32Windows, "Windows 9x is not supported"); + + var Win7Version = new Version(6, 1); + var Win8Version = new Version(6, 2); + IsWin7orLater = (operatingSystem.Version >= Win7Version); + IsWin8orLater = (operatingSystem.Version >= Win8Version); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Constants.cs b/src/Microsoft.AspNet.Security.Windows/Constants.cs new file mode 100644 index 0000000000..28e47e83fc --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Constants.cs @@ -0,0 +1,44 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + internal static class Constants + { + internal const string VersionKey = "owin.Version"; + internal const string OwinVersion = "1.0"; + internal const string CallCancelledKey = "owin.CallCancelled"; + + internal const string RequestBodyKey = "owin.RequestBody"; + internal const string RequestHeadersKey = "owin.RequestHeaders"; + internal const string RequestSchemeKey = "owin.RequestScheme"; + internal const string RequestMethodKey = "owin.RequestMethod"; + internal const string RequestPathBaseKey = "owin.RequestPathBase"; + internal const string RequestPathKey = "owin.RequestPath"; + internal const string RequestQueryStringKey = "owin.RequestQueryString"; + internal const string HttpRequestProtocolKey = "owin.RequestProtocol"; + + internal const string HttpResponseProtocolKey = "owin.ResponseProtocol"; + internal const string ResponseStatusCodeKey = "owin.ResponseStatusCode"; + internal const string ResponseReasonPhraseKey = "owin.ResponseReasonPhrase"; + internal const string ResponseHeadersKey = "owin.ResponseHeaders"; + internal const string ResponseBodyKey = "owin.ResponseBody"; + + internal const string ClientCertifiateKey = "ssl.ClientCertificate"; + internal const string SslSpnKey = "ssl.Spn"; + internal const string SslChannelBindingKey = "ssl.ChannelBinding"; + + internal const string RemoteIpAddressKey = "server.RemoteIpAddress"; + internal const string RemotePortKey = "server.RemotePort"; + internal const string LocalIpAddressKey = "server.LocalIpAddress"; + internal const string LocalPortKey = "server.LocalPort"; + internal const string IsLocalKey = "server.IsLocal"; + internal const string ServerOnSendingHeadersKey = "server.OnSendingHeaders"; + internal const string ServerUserKey = "server.User"; + internal const string ServerConnectionIdKey = "server.ConnectionId"; + internal const string ServerConnectionDisconnectKey = "server.ConnectionDisconnect"; + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/DictionaryExtensions.cs b/src/Microsoft.AspNet.Security.Windows/DictionaryExtensions.cs new file mode 100644 index 0000000000..2d1b86ad04 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/DictionaryExtensions.cs @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Linq; +using System.Text; + +namespace System.Collections.Generic +{ + internal static class DictionaryExtensions + { + internal static void Append(this IDictionary dictionary, string key, string value) + { + string[] orriginalValues; + if (dictionary.TryGetValue(key, out orriginalValues)) + { + string[] newValues = new string[orriginalValues.Length + 1]; + orriginalValues.CopyTo(newValues, 0); + newValues[newValues.Length - 1] = value; + dictionary[key] = newValues; + } + else + { + dictionary[key] = new string[] { value }; + } + } + + internal static void Append(this IDictionary dictionary, string key, IList values) + { + string[] orriginalValues; + if (dictionary.TryGetValue(key, out orriginalValues)) + { + string[] newValues = new string[orriginalValues.Length + values.Count]; + orriginalValues.CopyTo(newValues, 0); + values.CopyTo(newValues, orriginalValues.Length); + dictionary[key] = newValues; + } + else + { + dictionary[key] = values.ToArray(); + } + } + + internal static string Get(this IDictionary dictionary, string key) + { + string[] values; + if (dictionary.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + internal static T Get(this IDictionary dictionary, string key, T fallback = default(T)) + { + object values; + if (dictionary.TryGetValue(key, out values)) + { + return (T)values; + } + return fallback; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/DigestCache.cs b/src/Microsoft.AspNet.Security.Windows/DigestCache.cs new file mode 100644 index 0000000000..0de3d84475 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/DigestCache.cs @@ -0,0 +1,153 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + // Saves generated digest challenges so that they are still valid when the authenticated request arrives. + internal class DigestCache : IDisposable + { + private const int DigestLifetimeSeconds = 300; + private const int MaximumDigests = 1024; // Must be a power of two. + private const int MinimumDigestLifetimeSeconds = 10; + + private DigestContext[] _savedDigests; + private ArrayList _extraSavedDigests; + private ArrayList _extraSavedDigestsBaking; + private int _extraSavedDigestsTimestamp; + private int _newestContext; + private int _oldestContext; + + internal DigestCache() + { + } + + internal void SaveDigestContext(NTAuthentication digestContext) + { + if (_savedDigests == null) + { + Interlocked.CompareExchange(ref _savedDigests, new DigestContext[MaximumDigests], null); + } + + // We want to actually close the contexts outside the lock. + NTAuthentication oldContext = null; + ArrayList digestsToClose = null; + lock (_savedDigests) + { + int now = ((now = Environment.TickCount) == 0 ? 1 : now); + + _newestContext = (_newestContext + 1) & (MaximumDigests - 1); + + int oldTimestamp = _savedDigests[_newestContext].timestamp; + oldContext = _savedDigests[_newestContext].context; + _savedDigests[_newestContext].timestamp = now; + _savedDigests[_newestContext].context = digestContext; + + // May need to move this up. + if (_oldestContext == _newestContext) + { + _oldestContext = (_newestContext + 1) & (MaximumDigests - 1); + } + + // Delete additional contexts older than five minutes. + while (unchecked(now - _savedDigests[_oldestContext].timestamp) >= DigestLifetimeSeconds && _savedDigests[_oldestContext].context != null) + { + if (digestsToClose == null) + { + digestsToClose = new ArrayList(); + } + digestsToClose.Add(_savedDigests[_oldestContext].context); + _savedDigests[_oldestContext].context = null; + _oldestContext = (_oldestContext + 1) & (MaximumDigests - 1); + } + + // If the old context is younger than 10 seconds, put it in the backup pile. + if (oldContext != null && unchecked(now - oldTimestamp) <= MinimumDigestLifetimeSeconds * 1000) + { + // Use a two-tier ArrayList system to guarantee each entry lives at least 10 seconds. + if (_extraSavedDigests == null || + unchecked(now - _extraSavedDigestsTimestamp) > MinimumDigestLifetimeSeconds * 1000) + { + digestsToClose = _extraSavedDigestsBaking; + _extraSavedDigestsBaking = _extraSavedDigests; + _extraSavedDigestsTimestamp = now; + _extraSavedDigests = new ArrayList(); + } + _extraSavedDigests.Add(oldContext); + oldContext = null; + } + } + + if (oldContext != null) + { + oldContext.CloseContext(); + } + if (digestsToClose != null) + { + for (int i = 0; i < digestsToClose.Count; i++) + { + ((NTAuthentication)digestsToClose[i]).CloseContext(); + } + } + } + + private void ClearDigestCache() + { + if (_savedDigests == null) + { + return; + } + + ArrayList[] toClose = new ArrayList[3]; + lock (_savedDigests) + { + toClose[0] = _extraSavedDigestsBaking; + _extraSavedDigestsBaking = null; + toClose[1] = _extraSavedDigests; + _extraSavedDigests = null; + + _newestContext = 0; + _oldestContext = 0; + + toClose[2] = new ArrayList(); + for (int i = 0; i < MaximumDigests; i++) + { + if (_savedDigests[i].context != null) + { + toClose[2].Add(_savedDigests[i].context); + _savedDigests[i].context = null; + } + _savedDigests[i].timestamp = 0; + } + } + + for (int j = 0; j < toClose.Length; j++) + { + if (toClose[j] != null) + { + for (int k = 0; k < toClose[j].Count; k++) + { + ((NTAuthentication)toClose[j][k]).CloseContext(); + } + } + } + } + + public void Dispose() + { + ClearDigestCache(); + } + + private struct DigestContext + { + internal NTAuthentication context; + internal int timestamp; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/DisconnectAsyncResult.cs b/src/Microsoft.AspNet.Security.Windows/DisconnectAsyncResult.cs new file mode 100644 index 0000000000..6064872060 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/DisconnectAsyncResult.cs @@ -0,0 +1,93 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Security.Principal; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + // Keeps NTLM/Negotiate auth contexts alive until the connection is broken. + internal class DisconnectAsyncResult + { + private const string NTLM = "NTLM"; + + private object _connectionId; + private WindowsAuthMiddleware _winAuth; + private CancellationTokenRegistration _disconnectRegistration; + + private WindowsPrincipal _authenticatedUser; + private NTAuthentication _session; + + internal DisconnectAsyncResult(WindowsAuthMiddleware winAuth, object connectionId, CancellationToken connectionDisconnect) + { + GlobalLog.Print("DisconnectAsyncResult#" + ValidationHelper.HashString(this) + "::.ctor() httpListener#" + ValidationHelper.HashString(winAuth) + " connectionId:" + connectionId); + _winAuth = winAuth; + _connectionId = connectionId; + _winAuth.DisconnectResults[_connectionId] = this; + + // Register with a connection specific CancellationToken. Without this notification, the contexts will leak indefinitely. + // Alternatively we could attempt some kind of LRU storage, but this will either have to be larger than your expected connection limit, + // or will fail at unexpected moments under stress. + try + { + _disconnectRegistration = connectionDisconnect.Register(HandleDisconnect); + } + catch (ObjectDisposedException) + { + _winAuth.DisconnectResults.Remove(_connectionId); + } + } + + internal WindowsPrincipal AuthenticatedUser + { + get + { + return _authenticatedUser; + } + set + { + // The previous value can't be disposed because it may be in use by the app. + _authenticatedUser = value; + } + } + + internal NTAuthentication Session + { + get + { + return _session; + } + set + { + _session = value; + } + } + + private void HandleDisconnect() + { + GlobalLog.Print("DisconnectAsyncResult#" + ValidationHelper.HashString(this) + "::HandleDisconnect() DisconnectResults#" + ValidationHelper.HashString(_winAuth.DisconnectResults) + " removing for m_ConnectionId:" + _connectionId); + _winAuth.DisconnectResults.Remove(_connectionId); + if (_session != null) + { + _session.CloseContext(); + } + + // Clean up the identity. This is for scenarios where identity was not cleaned up before due to + // identity caching for unsafe ntlm authentication + + IDisposable identity = _authenticatedUser == null ? null : _authenticatedUser.Identity as IDisposable; + if ((identity != null) && + (NTLM.Equals(_authenticatedUser.Identity.AuthenticationType, StringComparison.OrdinalIgnoreCase)) && + (_winAuth.UnsafeConnectionNtlmAuthentication)) + { + identity.Dispose(); + } + + _disconnectRegistration.Dispose(); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/HeaderEncoding.cs b/src/Microsoft.AspNet.Security.Windows/HeaderEncoding.cs new file mode 100644 index 0000000000..7c41535d5e --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/HeaderEncoding.cs @@ -0,0 +1,139 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Text; + +namespace Microsoft.AspNet.Security.Windows +{ + // we use this static class as a helper class to encode/decode HTTP headers. + // what we need is a 1-1 correspondence between a char in the range U+0000-U+00FF + // and a byte in the range 0x00-0xFF (which is the range that can hit the network). + // The Latin-1 encoding (ISO-88591-1) (GetEncoding(28591)) works for byte[] to string, but is a little slow. + // It doesn't work for string -> byte[] because of best-fit-mapping problems. + internal static class HeaderEncoding + { + internal static unsafe string GetString(byte[] bytes, int byteIndex, int byteCount) + { + fixed (byte* pBytes = bytes) + return GetString(pBytes + byteIndex, byteCount); + } + + internal static unsafe string GetString(byte* pBytes, int byteCount) + { + if (byteCount < 1) + { + return string.Empty; + } + + string s = new String('\0', byteCount); + + fixed (char* pStr = s) + { + char* pString = pStr; + while (byteCount >= 8) + { + pString[0] = (char)pBytes[0]; + pString[1] = (char)pBytes[1]; + pString[2] = (char)pBytes[2]; + pString[3] = (char)pBytes[3]; + pString[4] = (char)pBytes[4]; + pString[5] = (char)pBytes[5]; + pString[6] = (char)pBytes[6]; + pString[7] = (char)pBytes[7]; + pString += 8; + pBytes += 8; + byteCount -= 8; + } + for (int i = 0; i < byteCount; i++) + { + pString[i] = (char)pBytes[i]; + } + } + + return s; + } + + internal static int GetByteCount(string myString) + { + return myString.Length; + } + internal static unsafe void GetBytes(string myString, int charIndex, int charCount, byte[] bytes, int byteIndex) + { + if (myString.Length == 0) + { + return; + } + fixed (byte* bufferPointer = bytes) + { + byte* newBufferPointer = bufferPointer + byteIndex; + int finalIndex = charIndex + charCount; + while (charIndex < finalIndex) + { + *newBufferPointer++ = (byte)myString[charIndex++]; + } + } + } + internal static unsafe byte[] GetBytes(string myString) + { + byte[] bytes = new byte[myString.Length]; + if (myString.Length != 0) + { + GetBytes(myString, 0, myString.Length, bytes, 0); + } + return bytes; + } + + // The normal client header parser just casts bytes to chars (see GetString). + // Check if those bytes were actually utf-8 instead of ASCII. + // If not, just return the input value. + + internal static string DecodeUtf8FromString(string input) + { + if (string.IsNullOrWhiteSpace(input)) + { + return input; + } + + bool possibleUtf8 = false; + for (int i = 0; i < input.Length; i++) + { + if (input[i] > (char)255) + { + return input; // This couldn't have come from the wire, someone assigned it directly. + } + else if (input[i] > (char)127) + { + possibleUtf8 = true; + break; + } + } + if (possibleUtf8) + { + byte[] rawBytes = new byte[input.Length]; + for (int i = 0; i < input.Length; i++) + { + if (input[i] > (char)255) + { + return input; // This couldn't have come from the wire, someone assigned it directly. + } + rawBytes[i] = (byte)input[i]; + } + try + { + // We don't want '?' replacement characters, just fail. + Encoding decoder = Encoding.GetEncoding("utf-8", EncoderFallback.ExceptionFallback, + DecoderFallback.ExceptionFallback); + return decoder.GetString(rawBytes); + } + catch (ArgumentException) + { + } // Not actually Utf-8 + } + return input; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/HttpKnownHeaderNames.cs b/src/Microsoft.AspNet.Security.Windows/HttpKnownHeaderNames.cs new file mode 100644 index 0000000000..29ab395d03 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/HttpKnownHeaderNames.cs @@ -0,0 +1,75 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + internal static class HttpKnownHeaderNames + { + public const string CacheControl = "Cache-Control"; + public const string Connection = "Connection"; + public const string Date = "Date"; + public const string KeepAlive = "Keep-Alive"; + public const string Pragma = "Pragma"; + public const string ProxyConnection = "Proxy-Connection"; + public const string Trailer = "Trailer"; + public const string TransferEncoding = "Transfer-Encoding"; + public const string Upgrade = "Upgrade"; + public const string Via = "Via"; + public const string Warning = "Warning"; + public const string ContentLength = "Content-Length"; + public const string ContentType = "Content-Type"; + public const string ContentDisposition = "Content-Disposition"; + public const string ContentEncoding = "Content-Encoding"; + public const string ContentLanguage = "Content-Language"; + public const string ContentLocation = "Content-Location"; + public const string ContentRange = "Content-Range"; + public const string Expires = "Expires"; + public const string LastModified = "Last-Modified"; + public const string Age = "Age"; + public const string Location = "Location"; + public const string ProxyAuthenticate = "Proxy-Authenticate"; + public const string RetryAfter = "Retry-After"; + public const string Server = "Server"; + public const string SetCookie = "Set-Cookie"; + public const string SetCookie2 = "Set-Cookie2"; + public const string Vary = "Vary"; + public const string WWWAuthenticate = "WWW-Authenticate"; + public const string Accept = "Accept"; + public const string AcceptCharset = "Accept-Charset"; + public const string AcceptEncoding = "Accept-Encoding"; + public const string AcceptLanguage = "Accept-Language"; + public const string Authorization = "Authorization"; + public const string Cookie = "Cookie"; + public const string Cookie2 = "Cookie2"; + public const string Expect = "Expect"; + public const string From = "From"; + public const string Host = "Host"; + public const string IfMatch = "If-Match"; + public const string IfModifiedSince = "If-Modified-Since"; + public const string IfNoneMatch = "If-None-Match"; + public const string IfRange = "If-Range"; + public const string IfUnmodifiedSince = "If-Unmodified-Since"; + public const string MaxForwards = "Max-Forwards"; + public const string ProxyAuthorization = "Proxy-Authorization"; + public const string Referer = "Referer"; + public const string Range = "Range"; + public const string UserAgent = "User-Agent"; + public const string ContentMD5 = "Content-MD5"; + public const string ETag = "ETag"; + public const string TE = "TE"; + public const string Allow = "Allow"; + public const string AcceptRanges = "Accept-Ranges"; + public const string P3P = "P3P"; + public const string XPoweredBy = "X-Powered-By"; + public const string XAspNetVersion = "X-AspNet-Version"; + public const string SecWebSocketKey = "Sec-WebSocket-Key"; + public const string SecWebSocketExtensions = "Sec-WebSocket-Extensions"; + public const string SecWebSocketAccept = "Sec-WebSocket-Accept"; + public const string Origin = "Origin"; + public const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; + public const string SecWebSocketVersion = "Sec-WebSocket-Version"; + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/HttpStatusCode.cs b/src/Microsoft.AspNet.Security.Windows/HttpStatusCode.cs new file mode 100644 index 0000000000..d60d640f2d --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/HttpStatusCode.cs @@ -0,0 +1,314 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Security.Windows +{ + // Redirect Status code numbers that need to be defined. + + /// + /// Contains the values of status + /// codes defined for the HTTP protocol. + /// + // UEUE : Any int can be cast to a HttpStatusCode to allow checking for non http1.1 codes. + internal enum HttpStatusCode + { + // Informational 1xx + + /// + /// [To be supplied.] + /// + Continue = 100, + + /// + /// [To be supplied.] + /// + SwitchingProtocols = 101, + + // Successful 2xx + + /// + /// [To be supplied.] + /// + OK = 200, + + /// + /// [To be supplied.] + /// + Created = 201, + + /// + /// [To be supplied.] + /// + Accepted = 202, + + /// + /// [To be supplied.] + /// + NonAuthoritativeInformation = 203, + + /// + /// [To be supplied.] + /// + NoContent = 204, + + /// + /// [To be supplied.] + /// + ResetContent = 205, + + /// + /// [To be supplied.] + /// + PartialContent = 206, + + // Redirection 3xx + + /// + /// [To be supplied.] + /// + MultipleChoices = 300, + + /// + /// [To be supplied.] + /// + Ambiguous = 300, + + /// + /// [To be supplied.] + /// + MovedPermanently = 301, + + /// + /// [To be supplied.] + /// + Moved = 301, + + /// + /// [To be supplied.] + /// + Found = 302, + + /// + /// [To be supplied.] + /// + Redirect = 302, + + /// + /// [To be supplied.] + /// + SeeOther = 303, + + /// + /// [To be supplied.] + /// + RedirectMethod = 303, + + /// + /// [To be supplied.] + /// + NotModified = 304, + + /// + /// [To be supplied.] + /// + UseProxy = 305, + + /// + /// [To be supplied.] + /// + Unused = 306, + + /// + /// [To be supplied.] + /// + TemporaryRedirect = 307, + + /// + /// [To be supplied.] + /// + RedirectKeepVerb = 307, + + // Client Error 4xx + + /// + /// [To be supplied.] + /// + BadRequest = 400, + + /// + /// [To be supplied.] + /// + Unauthorized = 401, + + /// + /// [To be supplied.] + /// + PaymentRequired = 402, + + /// + /// [To be supplied.] + /// + Forbidden = 403, + + /// + /// [To be supplied.] + /// + NotFound = 404, + + /// + /// [To be supplied.] + /// + MethodNotAllowed = 405, + + /// + /// [To be supplied.] + /// + NotAcceptable = 406, + + /// + /// [To be supplied.] + /// + ProxyAuthenticationRequired = 407, + + /// + /// [To be supplied.] + /// + RequestTimeout = 408, + + /// + /// [To be supplied.] + /// + Conflict = 409, + + /// + /// [To be supplied.] + /// + Gone = 410, + + /// + /// [To be supplied.] + /// + LengthRequired = 411, + + /// + /// [To be supplied.] + /// + PreconditionFailed = 412, + + /// + /// [To be supplied.] + /// + RequestEntityTooLarge = 413, + + /// + /// [To be supplied.] + /// + RequestUriTooLong = 414, + + /// + /// [To be supplied.] + /// + UnsupportedMediaType = 415, + + /// + /// [To be supplied.] + /// + RequestedRangeNotSatisfiable = 416, + + /// + /// [To be supplied.] + /// + ExpectationFailed = 417, + + UpgradeRequired = 426, + + // Server Error 5xx + + /// + /// [To be supplied.] + /// + InternalServerError = 500, + + /// + /// [To be supplied.] + /// + NotImplemented = 501, + + /// + /// [To be supplied.] + /// + BadGateway = 502, + + /// + /// [To be supplied.] + /// + ServiceUnavailable = 503, + + /// + /// [To be supplied.] + /// + GatewayTimeout = 504, + + /// + /// [To be supplied.] + /// + HttpVersionNotSupported = 505, + } // enum HttpStatusCode + +/* +Fielding, et al. Standards Track [Page 3] + +RFC 2616 HTTP/1.1 June 1999 + + + 10.1 Informational 1xx ...........................................57 + 10.1.1 100 Continue .............................................58 + 10.1.2 101 Switching Protocols ..................................58 + 10.2 Successful 2xx ..............................................58 + 10.2.1 200 OK ...................................................58 + 10.2.2 201 Created ..............................................59 + 10.2.3 202 Accepted .............................................59 + 10.2.4 203 Non-Authoritative Information ........................59 + 10.2.5 204 No Content ...........................................60 + 10.2.6 205 Reset Content ........................................60 + 10.2.7 206 Partial Content ......................................60 + 10.3 Redirection 3xx .............................................61 + 10.3.1 300 Multiple Choices .....................................61 + 10.3.2 301 Moved Permanently ....................................62 + 10.3.3 302 Found ................................................62 + 10.3.4 303 See Other ............................................63 + 10.3.5 304 Not Modified .........................................63 + 10.3.6 305 Use Proxy ............................................64 + 10.3.7 306 (Unused) .............................................64 + 10.3.8 307 Temporary Redirect ...................................65 + 10.4 Client Error 4xx ............................................65 + 10.4.1 400 Bad Request .........................................65 + 10.4.2 401 Unauthorized ........................................66 + 10.4.3 402 Payment Required ....................................66 + 10.4.4 403 Forbidden ...........................................66 + 10.4.5 404 Not Found ...........................................66 + 10.4.6 405 Method Not Allowed ..................................66 + 10.4.7 406 Not Acceptable ......................................67 + 10.4.8 407 Proxy Authentication Required .......................67 + 10.4.9 408 Request Timeout .....................................67 + 10.4.10 409 Conflict ............................................67 + 10.4.11 410 Gone ................................................68 + 10.4.12 411 Length Required .....................................68 + 10.4.13 412 Precondition Failed .................................68 + 10.4.14 413 Request Entity Too Large ............................69 + 10.4.15 414 Request-URI Too Long ................................69 + 10.4.16 415 Unsupported Media Type ..............................69 + 10.4.17 416 Requested Range Not Satisfiable .....................69 + 10.4.18 417 Expectation Failed ..................................70 + 10.5 Server Error 5xx ............................................70 + 10.5.1 500 Internal Server Error ................................70 + 10.5.2 501 Not Implemented ......................................70 + 10.5.3 502 Bad Gateway ..........................................70 + 10.5.4 503 Service Unavailable ..................................70 + 10.5.5 504 Gateway Timeout ......................................71 + 10.5.6 505 HTTP Version Not Supported ...........................71 +*/ +} // namespace System.Net diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/CaseInsinsitiveAscii.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/CaseInsinsitiveAscii.cs new file mode 100644 index 0000000000..8818694ff8 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/CaseInsinsitiveAscii.cs @@ -0,0 +1,133 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Collections; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class CaseInsensitiveAscii : IEqualityComparer, IComparer + { + // ASCII char ToLower table + internal static readonly byte[] AsciiToLower = new byte[] + { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, // 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, + 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, // 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, // 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, + 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, + 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, + 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, + 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, + 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, + 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, + 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, + 250, 251, 252, 253, 254, 255 + }; + + // ASCII string case insensitive hash function + public int GetHashCode(object myObject) + { + string myString = myObject as string; + if (myObject == null) + { + return 0; + } + int myHashCode = myString.Length; + if (myHashCode == 0) + { + return 0; + } + myHashCode ^= AsciiToLower[(byte)myString[0]] << 24 ^ AsciiToLower[(byte)myString[myHashCode - 1]] << 16; + return myHashCode; + } + + // ASCII string case insensitive comparer + public int Compare(object firstObject, object secondObject) + { + string firstString = firstObject as string; + string secondString = secondObject as string; + if (firstString == null) + { + return secondString == null ? 0 : -1; + } + if (secondString == null) + { + return 1; + } + int result = firstString.Length - secondString.Length; + int comparisons = result > 0 ? secondString.Length : firstString.Length; + int difference, index = 0; + while (index < comparisons) + { + difference = (int)(AsciiToLower[firstString[index]] - AsciiToLower[secondString[index]]); + if (difference != 0) + { + result = difference; + break; + } + index++; + } + return result; + } + + // ASCII string case insensitive hash function + private int FastGetHashCode(string myString) + { + int myHashCode = myString.Length; + if (myHashCode != 0) + { + myHashCode ^= AsciiToLower[(byte)myString[0]] << 24 ^ AsciiToLower[(byte)myString[myHashCode - 1]] << 16; + } + return myHashCode; + } + + // ASCII string case insensitive comparer + public new bool Equals(object firstObject, object secondObject) + { + string firstString = firstObject as string; + string secondString = secondObject as string; + if (firstString == null) + { + return secondString == null; + } + if (secondString != null) + { + int index = firstString.Length; + if (index == secondString.Length) + { + if (FastGetHashCode(firstString) == FastGetHashCode(secondString)) + { + int comparisons = firstString.Length; + while (index > 0) + { + index--; + if (AsciiToLower[firstString[index]] != AsciiToLower[secondString[index]]) + { + return false; + } + } + return true; + } + } + } + return false; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/GlobalLog.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/GlobalLog.cs new file mode 100644 index 0000000000..259dc13382 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/GlobalLog.cs @@ -0,0 +1,689 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Configuration; + using System.Diagnostics; + using System.Globalization; + using System.Net; + using System.Runtime.ConstrainedExecution; + using System.Security.Permissions; + using System.Threading; + + /// + /// + /// + internal static class GlobalLog + { + // Logging Initalization - I need to disable Logging code, and limit + // the effect it has when it is dissabled, so I use a bool here. + // + // This can only be set when the logging code is built and enabled. + // By specifing the "CSC_DEFINES=/D:TRAVE" in the build environment, + // this code will be built and then checks against an enviroment variable + // and a BooleanSwitch to see if any of the two have enabled logging. + + private static BaseLoggingObject Logobject = GlobalLog.LoggingInitialize(); +#if TRAVE + internal static LocalDataStoreSlot s_ThreadIdSlot; + internal static bool s_UseThreadId; + internal static bool s_UseTimeSpan; + internal static bool s_DumpWebData; + internal static bool s_UsePerfCounter; + internal static bool s_DebugCallNesting; + internal static bool s_DumpToConsole; + internal static int s_MaxDumpSize; + internal static string s_RootDirectory; + + // + // Logging Config Variables - below are list of consts that can be used to config + // the logging, + // + + // Max number of lines written into a buffer, before a save is invoked + // s_DumpToConsole disables. + public const int MaxLinesBeforeSave = 0; + +#endif + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + private static BaseLoggingObject LoggingInitialize() + { +#if DEBUG + if (GetSwitchValue("SystemNetLogging", "System.Net logging module", false) && + GetSwitchValue("SystemNetLog_ConnectionMonitor", "System.Net connection monitor thread", false)) + { + InitConnectionMonitor(); + } +#endif // DEBUG +#if TRAVE + // by default we'll log to c:\temp\ so that non interactive services (like w3wp.exe) that don't have environment + // variables can easily be debugged, note that the ACLs of the directory might need to be adjusted + if (!GetSwitchValue("SystemNetLog_OverrideDefaults", "System.Net log override default settings", false)) { + s_ThreadIdSlot = Thread.AllocateDataSlot(); + s_UseThreadId = true; + s_UseTimeSpan = true; + s_DumpWebData = true; + s_MaxDumpSize = 256; + s_UsePerfCounter = true; + s_DebugCallNesting = true; + s_DumpToConsole = false; + s_RootDirectory = "C:\\Temp\\"; + return new LoggingObject(); + } + if (GetSwitchValue("SystemNetLogging", "System.Net logging module", false)) { + s_ThreadIdSlot = Thread.AllocateDataSlot(); + s_UseThreadId = GetSwitchValue("SystemNetLog_UseThreadId", "System.Net log display system thread id", false); + s_UseTimeSpan = GetSwitchValue("SystemNetLog_UseTimeSpan", "System.Net log display ticks as TimeSpan", false); + s_DumpWebData = GetSwitchValue("SystemNetLog_DumpWebData", "System.Net log display HTTP send/receive data", false); + s_MaxDumpSize = GetSwitchValue("SystemNetLog_MaxDumpSize", "System.Net log max size of display data", 256); + s_UsePerfCounter = GetSwitchValue("SystemNetLog_UsePerfCounter", "System.Net log use QueryPerformanceCounter() to get ticks ", false); + s_DebugCallNesting = GetSwitchValue("SystemNetLog_DebugCallNesting", "System.Net used to debug call nesting", false); + s_DumpToConsole = GetSwitchValue("SystemNetLog_DumpToConsole", "System.Net log to console", false); + s_RootDirectory = GetSwitchValue("SystemNetLog_RootDirectory", "System.Net root directory of log file", string.Empty); + return new LoggingObject(); + } +#endif // TRAVE + return new BaseLoggingObject(); + } + +#if TRAVE + private static string GetSwitchValue(string switchName, string switchDescription, string defaultValue) { + new EnvironmentPermission(PermissionState.Unrestricted).Assert(); + try { + defaultValue = Environment.GetEnvironmentVariable(switchName); + } + finally { + EnvironmentPermission.RevertAssert(); + } + return defaultValue; + } + + private static int GetSwitchValue(string switchName, string switchDescription, int defaultValue) { + IntegerSwitch theSwitch = new IntegerSwitch(switchName, switchDescription); + if (theSwitch.Enabled) { + return theSwitch.Value; + } + new EnvironmentPermission(PermissionState.Unrestricted).Assert(); + try { + string environmentVar = Environment.GetEnvironmentVariable(switchName); + if (environmentVar!=null) { + defaultValue = Int32.Parse(environmentVar.Trim(), CultureInfo.InvariantCulture); + } + } + finally { + EnvironmentPermission.RevertAssert(); + } + return defaultValue; + } + +#endif + +#if TRAVE || DEBUG + private static bool GetSwitchValue(string switchName, string switchDescription, bool defaultValue) + { + BooleanSwitch theSwitch = new BooleanSwitch(switchName, switchDescription); + new EnvironmentPermission(PermissionState.Unrestricted).Assert(); + try + { + if (theSwitch.Enabled) + { + return true; + } + string environmentVar = Environment.GetEnvironmentVariable(switchName); + defaultValue = environmentVar != null && environmentVar.Trim() == "1"; + } + catch (ConfigurationException) + { + } + finally + { + EnvironmentPermission.RevertAssert(); + } + return defaultValue; + } +#endif // TRAVE || DEBUG + + // Enables thread tracing, detects mis-use of threads. +#if DEBUG + [ThreadStatic] + private static Stack t_ThreadKindStack; + + private static Stack ThreadKindStack + { + get + { + if (t_ThreadKindStack == null) + { + t_ThreadKindStack = new Stack(); + } + return t_ThreadKindStack; + } + } +#endif + + internal static ThreadKinds CurrentThreadKind + { + get + { +#if DEBUG + return ThreadKindStack.Count > 0 ? ThreadKindStack.Peek() : ThreadKinds.Other; +#else + return ThreadKinds.Unknown; +#endif + } + } + + private static bool HasShutdownStarted + { + get + { + return Environment.HasShutdownStarted || AppDomain.CurrentDomain.IsFinalizingForUnload(); + } + } + +#if DEBUG + // ifdef'd instead of conditional since people are forced to handle the return value. + // [Conditional("DEBUG")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static IDisposable SetThreadKind(ThreadKinds kind) + { + if ((kind & ThreadKinds.SourceMask) != ThreadKinds.Unknown) + { + throw new InvalidOperationException(); + } + + // Ignore during shutdown. + if (HasShutdownStarted) + { + return null; + } + + ThreadKinds threadKind = CurrentThreadKind; + ThreadKinds source = threadKind & ThreadKinds.SourceMask; + +#if TRAVE + // Special warnings when doing dangerous things on a thread. + if ((threadKind & ThreadKinds.User) != 0 && (kind & ThreadKinds.System) != 0) + { + Print("WARNING: Thread changed from User to System; user's thread shouldn't be hijacked."); + } + + if ((threadKind & ThreadKinds.Async) != 0 && (kind & ThreadKinds.Sync) != 0) + { + Print("WARNING: Thread changed from Async to Sync, may block an Async thread."); + } + else if ((threadKind & (ThreadKinds.Other | ThreadKinds.CompletionPort)) == 0 && (kind & ThreadKinds.Sync) != 0) + { + Print("WARNING: Thread from a limited resource changed to Sync, may deadlock or bottleneck."); + } +#endif + + ThreadKindStack.Push( + (((kind & ThreadKinds.OwnerMask) == 0 ? threadKind : kind) & ThreadKinds.OwnerMask) | + (((kind & ThreadKinds.SyncMask) == 0 ? threadKind : kind) & ThreadKinds.SyncMask) | + (kind & ~(ThreadKinds.OwnerMask | ThreadKinds.SyncMask)) | + source); + +#if TRAVE + if (CurrentThreadKind != threadKind) + { + Print("Thread becomes:(" + CurrentThreadKind.ToString() + ")"); + } +#endif + + return new ThreadKindFrame(); + } + + private class ThreadKindFrame : IDisposable + { + private int m_FrameNumber; + + internal ThreadKindFrame() + { + m_FrameNumber = ThreadKindStack.Count; + } + + void IDisposable.Dispose() + { + // Ignore during shutdown. + if (GlobalLog.HasShutdownStarted) + { + return; + } + + if (m_FrameNumber != ThreadKindStack.Count) + { + throw new InvalidOperationException(); + } + + ThreadKinds previous = ThreadKindStack.Pop(); + +#if TRAVE + if (CurrentThreadKind != previous) + { + Print("Thread reverts:(" + CurrentThreadKind.ToString() + ")"); + } +#endif + } + } +#endif + + [Conditional("DEBUG")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static void SetThreadSource(ThreadKinds source) + { +#if DEBUG + if ((source & ThreadKinds.SourceMask) != source || source == ThreadKinds.Unknown) + { + throw new ArgumentException("Must specify the thread source.", "source"); + } + + if (ThreadKindStack.Count == 0) + { + ThreadKindStack.Push(source); + return; + } + + if (ThreadKindStack.Count > 1) + { + Print("WARNING: SetThreadSource must be called at the base of the stack, or the stack has been corrupted."); + while (ThreadKindStack.Count > 1) + { + ThreadKindStack.Pop(); + } + } + + if (ThreadKindStack.Peek() != source) + { + // SQL can fail to clean up the stack, leaving the default Other at the bottom. Replace it. + Print("WARNING: The stack has been corrupted."); + ThreadKinds last = ThreadKindStack.Pop() & ThreadKinds.SourceMask; + Assert(last == source || last == ThreadKinds.Other, "Thread source changed.|Was:({0}) Now:({1})", last, source); + ThreadKindStack.Push(source); + } +#endif + } + + [Conditional("DEBUG")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static void ThreadContract(ThreadKinds kind, string errorMsg) + { + ThreadContract(kind, ThreadKinds.SafeSources, errorMsg); + } + + [Conditional("DEBUG")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static void ThreadContract(ThreadKinds kind, ThreadKinds allowedSources, string errorMsg) + { + if ((kind & ThreadKinds.SourceMask) != ThreadKinds.Unknown || (allowedSources & ThreadKinds.SourceMask) != allowedSources) + { + throw new InvalidOperationException(); + } + + ThreadKinds threadKind = CurrentThreadKind; + Assert((threadKind & allowedSources) != 0, errorMsg, "Thread Contract Violation.|Expected source:({0}) Actual source:({1})", allowedSources, threadKind & ThreadKinds.SourceMask); + Assert((threadKind & kind) == kind, errorMsg, "Thread Contract Violation.|Expected kind:({0}) Actual kind:({1})", kind, threadKind & ~ThreadKinds.SourceMask); + } + +#if DEBUG + // Enables auto-hang detection, which will "snap" a log on hang + internal static bool EnableMonitorThread = false; + + // Default value for hang timer +#if FEATURE_PAL // ROTORTODO - after speedups (like real JIT and GC) remove this + public const int DefaultTickValue = 1000*60*5; // 5 minutes +#else + public const int DefaultTickValue = 1000 * 60; // 60 secs +#endif // FEATURE_PAL +#endif // DEBUG + + [System.Diagnostics.Conditional("TRAVE")] + public static void AddToArray(string msg) + { +#if TRAVE + GlobalLog.Logobject.PrintLine(msg); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Ignore(object msg) + { + } + + [System.Diagnostics.Conditional("TRAVE")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + public static void Print(string msg) + { +#if TRAVE + GlobalLog.Logobject.PrintLine(msg); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void PrintHex(string msg, object value) + { +#if TRAVE + GlobalLog.Logobject.PrintLine(msg+TraveHelper.ToHex(value)); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Enter(string func) + { +#if TRAVE + GlobalLog.Logobject.EnterFunc(func + "(*none*)"); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Enter(string func, string parms) + { +#if TRAVE + GlobalLog.Logobject.EnterFunc(func + "(" + parms + ")"); +#endif + } + + [Conditional("DEBUG")] + [Conditional("_FORCE_ASSERTS")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + public static void Assert(bool condition, string messageFormat, params object[] data) + { + if (!condition) + { + string fullMessage = string.Format(CultureInfo.InvariantCulture, messageFormat, data); + int pipeIndex = fullMessage.IndexOf('|'); + if (pipeIndex == -1) + { + Assert(fullMessage); + } + else + { + int detailLength = fullMessage.Length - pipeIndex - 1; + Assert(fullMessage.Substring(0, pipeIndex), detailLength > 0 ? fullMessage.Substring(pipeIndex + 1, detailLength) : null); + } + } + } + + [Conditional("DEBUG")] + [Conditional("_FORCE_ASSERTS")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + public static void Assert(string message) + { + Assert(message, null); + } + + [Conditional("DEBUG")] + [Conditional("_FORCE_ASSERTS")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + public static void Assert(string message, string detailMessage) + { + try + { + Print("Assert: " + message + (!string.IsNullOrEmpty(detailMessage) ? ": " + detailMessage : string.Empty)); + Print("*******"); + Logobject.DumpArray(false); + } + finally + { +#if DEBUG && !STRESS + Debug.Assert(false, message, detailMessage); +#endif + } + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void LeaveException(string func, Exception exception) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " exception " + ((exception!=null) ? exception.Message : String.Empty)); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Leave(string func) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " returns "); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Leave(string func, string result) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " returns " + result); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Leave(string func, int returnval) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " returns " + returnval.ToString()); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Leave(string func, bool returnval) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " returns " + returnval.ToString()); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void DumpArray() + { +#if TRAVE + GlobalLog.Logobject.DumpArray(true); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Dump(byte[] buffer) + { +#if TRAVE + Logobject.Dump(buffer, 0, buffer!=null ? buffer.Length : -1); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Dump(byte[] buffer, int length) + { +#if TRAVE + Logobject.Dump(buffer, 0, length); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Dump(byte[] buffer, int offset, int length) + { +#if TRAVE + Logobject.Dump(buffer, offset, length); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Dump(IntPtr buffer, int offset, int length) + { +#if TRAVE + Logobject.Dump(buffer, offset, length); +#endif + } + +#if DEBUG + private class HttpWebRequestComparer : IComparer + { + public int Compare( + object x1, + object y1) + { + HttpWebRequest x = (HttpWebRequest)x1; + HttpWebRequest y = (HttpWebRequest)y1; + + if (x.GetHashCode() == y.GetHashCode()) + { + return 0; + } + else if (x.GetHashCode() < y.GetHashCode()) + { + return -1; + } + else if (x.GetHashCode() > y.GetHashCode()) + { + return 1; + } + + return 0; + } + } + /* + private class ConnectionMonitorEntry { + public HttpWebRequest m_Request; + public int m_Flags; + public DateTime m_TimeAdded; + public Connection m_Connection; + + public ConnectionMonitorEntry(HttpWebRequest request, Connection connection, int flags) { + m_Request = request; + m_Connection = connection; + m_Flags = flags; + m_TimeAdded = DateTime.Now; + } + } + */ + private static volatile ManualResetEvent s_ShutdownEvent; + private static volatile SortedList s_RequestList; + + internal const int WaitingForReadDoneFlag = 0x1; +#endif + /* +#if DEBUG + private static void ConnectionMonitor() { + while(! s_ShutdownEvent.WaitOne(DefaultTickValue, false)) { + if (GlobalLog.EnableMonitorThread) { +#if TRAVE + GlobalLog.Logobject.LoggingMonitorTick(); +#endif + } + + int hungCount = 0; + lock (s_RequestList) { + DateTime dateNow = DateTime.Now; + DateTime dateExpired = dateNow.AddSeconds(-DefaultTickValue); + foreach (ConnectionMonitorEntry monitorEntry in s_RequestList.GetValueList() ) { + if (monitorEntry != null && + (dateExpired > monitorEntry.m_TimeAdded)) + { + hungCount++; +#if TRAVE + GlobalLog.Print("delay:" + (dateNow - monitorEntry.m_TimeAdded).TotalSeconds + + " req#" + monitorEntry.m_Request.GetHashCode() + + " cnt#" + monitorEntry.m_Connection.GetHashCode() + + " flags:" + monitorEntry.m_Flags); + +#endif + monitorEntry.m_Connection.Debug(monitorEntry.m_Request.GetHashCode()); + } + } + } + Assert(hungCount == 0, "Warning: Hang Detected on Connection(s) of greater than {0} ms. {1} request(s) hung.|Please Dump System.Net.GlobalLog.s_RequestList for pending requests, make sure your streams are calling Close(), and that your destination server is up.", DefaultTickValue, hungCount); + } + } +#endif // DEBUG + **/ +#if DEBUG + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static void AppDomainUnloadEvent(object sender, EventArgs e) + { + s_ShutdownEvent.Set(); + } +#endif + +#if DEBUG + [System.Diagnostics.Conditional("DEBUG")] + private static void InitConnectionMonitor() + { + s_RequestList = new SortedList(new HttpWebRequestComparer(), 10); + s_ShutdownEvent = new ManualResetEvent(false); + AppDomain.CurrentDomain.DomainUnload += new EventHandler(AppDomainUnloadEvent); + AppDomain.CurrentDomain.ProcessExit += new EventHandler(AppDomainUnloadEvent); + // Thread threadMonitor = new Thread(new ThreadStart(ConnectionMonitor)); + // threadMonitor.IsBackground = true; + // threadMonitor.Start(); + } +#endif + /* + [System.Diagnostics.Conditional("DEBUG")] + internal static void DebugAddRequest(HttpWebRequest request, Connection connection, int flags) { +#if DEBUG + // null if the connection monitor is off + if(s_RequestList == null) + return; + + lock(s_RequestList) { + Assert(!s_RequestList.ContainsKey(request), "s_RequestList.ContainsKey(request)|A HttpWebRequest should not be submitted twice."); + + ConnectionMonitorEntry requestEntry = + new ConnectionMonitorEntry(request, connection, flags); + + try { + s_RequestList.Add(request, requestEntry); + } catch { + } + } +#endif + } +*/ + /* + [System.Diagnostics.Conditional("DEBUG")] + internal static void DebugRemoveRequest(HttpWebRequest request) { + #if DEBUG + // null if the connection monitor is off + if(s_RequestList == null) + return; + + lock(s_RequestList) { + Assert(s_RequestList.ContainsKey(request), "!s_RequestList.ContainsKey(request)|A HttpWebRequest should not be removed twice."); + + try { + s_RequestList.Remove(request); + } catch { + } + } + #endif + } + */ + /* + [System.Diagnostics.Conditional("DEBUG")] + internal static void DebugUpdateRequest(HttpWebRequest request, Connection connection, int flags) { +#if DEBUG + // null if the connection monitor is off + if(s_RequestList == null) + return; + + lock(s_RequestList) { + if(!s_RequestList.ContainsKey(request)) { + return; + } + + ConnectionMonitorEntry requestEntry = + new ConnectionMonitorEntry(request, connection, flags); + + try { + s_RequestList.Remove(request); + s_RequestList.Add(request, requestEntry); + } catch { + } + } +#endif + }*/ + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/HttpListenerContext.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/HttpListenerContext.cs new file mode 100644 index 0000000000..9b9335906f --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/HttpListenerContext.cs @@ -0,0 +1,67 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Security.Principal; + +namespace Microsoft.AspNet.Security.Windows +{ + // TODO: At what point does a user need to be cleaned up? + internal sealed class HttpListenerContext + { + private WindowsAuthMiddleware _winAuth; + private IPrincipal _user = null; + + internal const string NTLM = "NTLM"; + + internal HttpListenerContext(WindowsAuthMiddleware httpListener) + { + _winAuth = httpListener; + } + + internal void Close() + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "Close()", string.Empty); + } + + IDisposable user = _user == null ? null : _user.Identity as IDisposable; + + // TODO: At what point does a user need to be cleaned up? + + // For unsafe connection ntlm auth we dont dispose this identity as yet since its cached + if ((user != null) && + (_user.Identity.AuthenticationType != NTLM) && + (!_winAuth.UnsafeConnectionNtlmAuthentication)) + { + user.Dispose(); + } + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "Close", string.Empty); + } + } + + internal void Abort() + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "Abort", string.Empty); + } + + IDisposable user = _user == null ? null : _user.Identity as IDisposable; + if (user != null) + { + user.Dispose(); + } + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "Abort", string.Empty); + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/Internal.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/Internal.cs new file mode 100644 index 0000000000..4d922591dd --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/Internal.cs @@ -0,0 +1,286 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Net.Security; +using System.Runtime.InteropServices; +using System.Runtime.Versioning; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Cryptography.X509Certificates; +using System.Security.Permissions; + +namespace Microsoft.AspNet.Security.Windows +{ + internal enum SecurityStatus + { + // Success / Informational + OK = 0x00000000, + ContinueNeeded = unchecked((int)0x00090312), + CompleteNeeded = unchecked((int)0x00090313), + CompAndContinue = unchecked((int)0x00090314), + ContextExpired = unchecked((int)0x00090317), + CredentialsNeeded = unchecked((int)0x00090320), + Renegotiate = unchecked((int)0x00090321), + + // Errors + OutOfMemory = unchecked((int)0x80090300), + InvalidHandle = unchecked((int)0x80090301), + Unsupported = unchecked((int)0x80090302), + TargetUnknown = unchecked((int)0x80090303), + InternalError = unchecked((int)0x80090304), + PackageNotFound = unchecked((int)0x80090305), + NotOwner = unchecked((int)0x80090306), + CannotInstall = unchecked((int)0x80090307), + InvalidToken = unchecked((int)0x80090308), + CannotPack = unchecked((int)0x80090309), + QopNotSupported = unchecked((int)0x8009030A), + NoImpersonation = unchecked((int)0x8009030B), + LogonDenied = unchecked((int)0x8009030C), + UnknownCredentials = unchecked((int)0x8009030D), + NoCredentials = unchecked((int)0x8009030E), + MessageAltered = unchecked((int)0x8009030F), + OutOfSequence = unchecked((int)0x80090310), + NoAuthenticatingAuthority = unchecked((int)0x80090311), + IncompleteMessage = unchecked((int)0x80090318), + IncompleteCredentials = unchecked((int)0x80090320), + BufferNotEnough = unchecked((int)0x80090321), + WrongPrincipal = unchecked((int)0x80090322), + TimeSkew = unchecked((int)0x80090324), + UntrustedRoot = unchecked((int)0x80090325), + IllegalMessage = unchecked((int)0x80090326), + CertUnknown = unchecked((int)0x80090327), + CertExpired = unchecked((int)0x80090328), + AlgorithmMismatch = unchecked((int)0x80090331), + SecurityQosFailed = unchecked((int)0x80090332), + SmartcardLogonRequired = unchecked((int)0x8009033E), + UnsupportedPreauth = unchecked((int)0x80090343), + BadBinding = unchecked((int)0x80090346) + } + + internal enum ContextAttribute + { + // look into and + Sizes = 0x00, + Names = 0x01, + Lifespan = 0x02, + DceInfo = 0x03, + StreamSizes = 0x04, + // KeyInfo = 0x05, must not be used, see ConnectionInfo instead + Authority = 0x06, + // SECPKG_ATTR_PROTO_INFO = 7, + // SECPKG_ATTR_PASSWORD_EXPIRY = 8, + // SECPKG_ATTR_SESSION_KEY = 9, + PackageInfo = 0x0A, + // SECPKG_ATTR_USER_FLAGS = 11, + NegotiationInfo = 0x0C, + // SECPKG_ATTR_NATIVE_NAMES = 13, + // SECPKG_ATTR_FLAGS = 14, + // SECPKG_ATTR_USE_VALIDATED = 15, + // SECPKG_ATTR_CREDENTIAL_NAME = 16, + // SECPKG_ATTR_TARGET_INFORMATION = 17, + // SECPKG_ATTR_ACCESS_TOKEN = 18, + // SECPKG_ATTR_TARGET = 19, + // SECPKG_ATTR_AUTHENTICATION_ID = 20, + UniqueBindings = 0x19, + EndpointBindings = 0x1A, + ClientSpecifiedSpn = 0x1B, // SECPKG_ATTR_CLIENT_SPECIFIED_TARGET = 27 + RemoteCertificate = 0x53, + LocalCertificate = 0x54, + RootStore = 0x55, + IssuerListInfoEx = 0x59, + ConnectionInfo = 0x5A, + // SECPKG_ATTR_EAP_KEY_BLOCK 0x5b // returns SecPkgContext_EapKeyBlock + // SECPKG_ATTR_MAPPED_CRED_ATTR 0x5c // returns SecPkgContext_MappedCredAttr + // SECPKG_ATTR_SESSION_INFO 0x5d // returns SecPkgContext_SessionInfo + // SECPKG_ATTR_APP_DATA 0x5e // sets/returns SecPkgContext_SessionAppData + // SECPKG_ATTR_REMOTE_CERTIFICATES 0x5F // returns SecPkgContext_Certificates + // SECPKG_ATTR_CLIENT_CERT_POLICY 0x60 // sets SecPkgCred_ClientCertCtlPolicy + // SECPKG_ATTR_CC_POLICY_RESULT 0x61 // returns SecPkgContext_ClientCertPolicyResult + // SECPKG_ATTR_USE_NCRYPT 0x62 // Sets the CRED_FLAG_USE_NCRYPT_PROVIDER FLAG on cred group + // SECPKG_ATTR_LOCAL_CERT_INFO 0x63 // returns SecPkgContext_CertInfo + // SECPKG_ATTR_CIPHER_INFO 0x64 // returns new CNG SecPkgContext_CipherInfo + // SECPKG_ATTR_EAP_PRF_INFO 0x65 // sets SecPkgContext_EapPrfInfo + // SECPKG_ATTR_SUPPORTED_SIGNATURES 0x66 // returns SecPkgContext_SupportedSignatures + // SECPKG_ATTR_REMOTE_CERT_CHAIN 0x67 // returns PCCERT_CONTEXT + UiInfo = 0x68, // sets SEcPkgContext_UiInfo + } + + internal enum Endianness + { + Network = 0x00, + Native = 0x10, + } + + internal enum CredentialUse + { + Inbound = 0x1, + Outbound = 0x2, + Both = 0x3, + } + + internal enum BufferType + { + Empty = 0x00, + Data = 0x01, + Token = 0x02, + Parameters = 0x03, + Missing = 0x04, + Extra = 0x05, + Trailer = 0x06, + Header = 0x07, + Padding = 0x09, // non-data padding + Stream = 0x0A, + ChannelBindings = 0x0E, + TargetHost = 0x10, + ReadOnlyFlag = unchecked((int)0x80000000), + ReadOnlyWithChecksum = 0x10000000 + } + + // SecPkgContext_IssuerListInfoEx + [StructLayout(LayoutKind.Sequential)] + internal struct IssuerListInfoEx + { + public SafeHandle aIssuers; + public uint cIssuers; + + public unsafe IssuerListInfoEx(SafeHandle handle, byte[] nativeBuffer) + { + aIssuers = handle; + fixed (byte* voidPtr = nativeBuffer) + { + // if this breaks on 64 bit, do the sizeof(IntPtr) trick + cIssuers = *((uint*)(voidPtr + IntPtr.Size)); + } + } + } + + [StructLayout(LayoutKind.Sequential)] + internal struct SecureCredential + { + /* + typedef struct _SCHANNEL_CRED + { + DWORD dwVersion; // always SCHANNEL_CRED_VERSION + DWORD cCreds; + PCCERT_CONTEXT *paCred; + HCERTSTORE hRootStore; + + DWORD cMappers; + struct _HMAPPER **aphMappers; + + DWORD cSupportedAlgs; + ALG_ID * palgSupportedAlgs; + + DWORD grbitEnabledProtocols; + DWORD dwMinimumCipherStrength; + DWORD dwMaximumCipherStrength; + DWORD dwSessionLifespan; + DWORD dwFlags; + DWORD reserved; + } SCHANNEL_CRED, *PSCHANNEL_CRED; + */ + + public const int CurrentVersion = 0x4; + + public int version; + public int cCreds; + + // ptr to an array of pointers + // There is a hack done with this field. AcquireCredentialsHandle requires an array of + // certificate handles; we only ever use one. In order to avoid pinning a one element array, + // we copy this value onto the stack, create a pointer on the stack to the copied value, + // and replace this field with the pointer, during the call to AcquireCredentialsHandle. + // Then we fix it up afterwards. Fine as long as all the SSPI credentials are not + // supposed to be threadsafe. + public IntPtr certContextArray; + + private readonly IntPtr rootStore; // == always null, OTHERWISE NOT RELIABLE + public int cMappers; + private readonly IntPtr phMappers; // == always null, OTHERWISE NOT RELIABLE + public int cSupportedAlgs; + private readonly IntPtr palgSupportedAlgs; // == always null, OTHERWISE NOT RELIABLE + public SchProtocols grbitEnabledProtocols; + public int dwMinimumCipherStrength; + public int dwMaximumCipherStrength; + public int dwSessionLifespan; + public SecureCredential.Flags dwFlags; + public int reserved; + + public SecureCredential(int version, X509Certificate certificate, SecureCredential.Flags flags, SchProtocols protocols, EncryptionPolicy policy) + { + // default values required for a struct + rootStore = phMappers = palgSupportedAlgs = certContextArray = IntPtr.Zero; + cCreds = cMappers = cSupportedAlgs = 0; + + if (policy == EncryptionPolicy.RequireEncryption) + { + // Prohibit null encryption cipher + dwMinimumCipherStrength = 0; + dwMaximumCipherStrength = 0; + } + else if (policy == EncryptionPolicy.AllowNoEncryption) + { + // Allow null encryption cipher in addition to other ciphers + dwMinimumCipherStrength = -1; + dwMaximumCipherStrength = 0; + } + else if (policy == EncryptionPolicy.NoEncryption) + { + // Suppress all encryption and require null encryption cipher only + dwMinimumCipherStrength = -1; + dwMaximumCipherStrength = -1; + } + else + { + throw new ArgumentException(SR.GetString(SR.net_invalid_enum, "EncryptionPolicy"), "policy"); + } + + dwSessionLifespan = reserved = 0; + this.version = version; + dwFlags = flags; + grbitEnabledProtocols = protocols; + if (certificate != null) + { + certContextArray = certificate.Handle; + cCreds = 1; + } + } + + [Flags] + public enum Flags + { + Zero = 0, + NoSystemMapper = 0x02, + NoNameCheck = 0x04, + ValidateManual = 0x08, + NoDefaultCred = 0x10, + ValidateAuto = 0x20 + } + } // SecureCredential + + [SuppressMessage("Microsoft.Design", "CA1049:TypesThatOwnNativeResourcesShouldBeDisposable", + Justification = "This structure does not own the native resource.")] + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct SecurityBufferStruct + { + public int count; + public BufferType type; + public IntPtr token; + + public static readonly int Size = sizeof(SecurityBufferStruct); + } + + internal static class IntPtrHelper + { + internal static IntPtr Add(IntPtr a, int b) + { + return (IntPtr)((long)a + (long)b); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/Logging.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/Logging.cs new file mode 100644 index 0000000000..81e66a3c30 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/Logging.cs @@ -0,0 +1,657 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Runtime.InteropServices; +using System.Security; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class Logging + { + private const string AttributeNameMaxSize = "maxdatasize"; + private const string AttributeNameTraceMode = "tracemode"; + private const string AttributeValueProtocolOnly = "protocolonly"; + // private const string AttributeValueIncludeHex = "includehex"; + + private const int DefaultMaxDumpSize = 1024; + private const bool DefaultUseProtocolTextOnly = false; + + private const string TraceSourceWebName = "System.Net"; + private const string TraceSourceHttpListenerName = "System.Net.HttpListener"; + + private static readonly string[] SupportedAttributes = new string[] { AttributeNameMaxSize, AttributeNameTraceMode }; + + private static volatile bool s_LoggingEnabled = true; + private static volatile bool s_LoggingInitialized; + private static volatile bool s_AppDomainShutdown; + + private static TraceSource s_WebTraceSource; + private static TraceSource s_HttpListenerTraceSource; + + private static object s_InternalSyncObject; + + private Logging() + { + } + + private static object InternalSyncObject + { + get + { + if (s_InternalSyncObject == null) + { + object o = new Object(); + Interlocked.CompareExchange(ref s_InternalSyncObject, o, null); + } + return s_InternalSyncObject; + } + } + + internal static bool On + { + get + { + if (!s_LoggingInitialized) + { + InitializeLogging(); + } + return s_LoggingEnabled; + } + } + + internal static bool IsVerbose(TraceSource traceSource) + { + return ValidateSettings(traceSource, TraceEventType.Verbose); + } + + internal static TraceSource Web + { + get + { + if (!s_LoggingInitialized) + { + InitializeLogging(); + } + if (!s_LoggingEnabled) + { + return null; + } + return s_WebTraceSource; + } + } + + internal static TraceSource HttpListener + { + get + { + if (!s_LoggingInitialized) + { + InitializeLogging(); + } + if (!s_LoggingEnabled) + { + return null; + } + return s_HttpListenerTraceSource; + } + } + + private static bool GetUseProtocolTextSetting(TraceSource traceSource) + { + bool useProtocolTextOnly = DefaultUseProtocolTextOnly; + if (traceSource.Attributes[AttributeNameTraceMode] == AttributeValueProtocolOnly) + { + useProtocolTextOnly = true; + } + return useProtocolTextOnly; + } + + private static int GetMaxDumpSizeSetting(TraceSource traceSource) + { + int maxDumpSize = DefaultMaxDumpSize; + if (traceSource.Attributes.ContainsKey(AttributeNameMaxSize)) + { + try + { + maxDumpSize = Int32.Parse(traceSource.Attributes[AttributeNameMaxSize], NumberFormatInfo.InvariantInfo); + } + catch (Exception exception) + { + if (exception is ThreadAbortException || exception is StackOverflowException || exception is OutOfMemoryException) + { + throw; + } + traceSource.Attributes[AttributeNameMaxSize] = maxDumpSize.ToString(NumberFormatInfo.InvariantInfo); + } + } + return maxDumpSize; + } + + /// + /// Sets up internal config settings for logging. (MUST be called under critsec) + /// + private static void InitializeLogging() + { + lock (InternalSyncObject) + { + if (!s_LoggingInitialized) + { + bool loggingEnabled = false; + s_WebTraceSource = new NclTraceSource(TraceSourceWebName); + s_HttpListenerTraceSource = new NclTraceSource(TraceSourceHttpListenerName); + + GlobalLog.Print("Initalizating tracing"); + + try + { + loggingEnabled = (s_WebTraceSource.Switch.ShouldTrace(TraceEventType.Critical) || + s_HttpListenerTraceSource.Switch.ShouldTrace(TraceEventType.Critical)); + } + catch (SecurityException) + { + // These may throw if the caller does not have permission to hook up trace listeners. + // We treat this case as though logging were disabled. + Close(); + loggingEnabled = false; + } + if (loggingEnabled) + { + AppDomain currentDomain = AppDomain.CurrentDomain; + currentDomain.UnhandledException += new UnhandledExceptionEventHandler(UnhandledExceptionHandler); + currentDomain.DomainUnload += new EventHandler(AppDomainUnloadEvent); + currentDomain.ProcessExit += new EventHandler(ProcessExitEvent); + } + s_LoggingEnabled = loggingEnabled; + s_LoggingInitialized = true; + } + } + } + + [SuppressMessage("Microsoft.Security", "CA2122:DoNotIndirectlyExposeMethodsWithLinkDemands", Justification = "Logging functions must work in partial trust mode")] + private static void Close() + { + if (s_WebTraceSource != null) + { + s_WebTraceSource.Close(); + } + if (s_HttpListenerTraceSource != null) + { + s_HttpListenerTraceSource.Close(); + } + } + + /// + /// Logs any unhandled exception through this event handler + /// + private static void UnhandledExceptionHandler(object sender, UnhandledExceptionEventArgs args) + { + Exception e = (Exception)args.ExceptionObject; + Exception(Web, sender, "UnhandledExceptionHandler", e); + } + + private static void ProcessExitEvent(object sender, EventArgs e) + { + Close(); + s_AppDomainShutdown = true; + } + + /// + /// Called when the system is shutting down, used to prevent additional logging post-shutdown + /// + private static void AppDomainUnloadEvent(object sender, EventArgs e) + { + Close(); + s_AppDomainShutdown = true; + } + + /// + /// Confirms logging is enabled, given current logging settings + /// + private static bool ValidateSettings(TraceSource traceSource, TraceEventType traceLevel) + { + if (!s_LoggingEnabled) + { + return false; + } + if (!s_LoggingInitialized) + { + InitializeLogging(); + } + if (traceSource == null || !traceSource.Switch.ShouldTrace(traceLevel)) + { + return false; + } + if (s_AppDomainShutdown) + { + return false; + } + return true; + } + + /// + /// Converts an object to a normalized string that can be printed + /// takes System.Net.ObjectNamedFoo and coverts to ObjectNamedFoo, + /// except IPAddress, IPEndPoint, and Uri, which return ToString() + /// + /// + private static string GetObjectName(object obj) + { + if (obj is Uri || obj is System.Net.IPAddress || obj is System.Net.IPEndPoint) + { + return obj.ToString(); + } + else + { + return obj.GetType().Name; + } + } + + internal static uint GetThreadId() + { + uint threadId = UnsafeNclNativeMethods.GetCurrentThreadId(); + if (threadId == 0) + { + threadId = (uint)Thread.CurrentThread.GetHashCode(); + } + return threadId; + } + + internal static void PrintLine(TraceSource traceSource, TraceEventType eventType, int id, string msg) + { + string logHeader = "[" + GetThreadId().ToString("d4", CultureInfo.InvariantCulture) + "] "; + traceSource.TraceEvent(eventType, id, logHeader + msg); + } + + /// + /// Indicates that two objects are getting used with one another + /// + internal static void Associate(TraceSource traceSource, object objA, object objB) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + + string lineA = GetObjectName(objA) + "#" + ValidationHelper.HashString(objA); + string lineB = GetObjectName(objB) + "#" + ValidationHelper.HashString(objB); + + PrintLine(traceSource, TraceEventType.Information, 0, "Associating " + lineA + " with " + lineB); + } + + /// + /// Logs entrance to a function + /// + internal static void Enter(TraceSource traceSource, object obj, string method, string param) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Enter(traceSource, GetObjectName(obj) + "#" + ValidationHelper.HashString(obj), method, param); + } + + /// + /// Logs entrance to a function + /// + internal static void Enter(TraceSource traceSource, object obj, string method, object paramObject) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Enter(traceSource, GetObjectName(obj) + "#" + ValidationHelper.HashString(obj), method, paramObject); + } + + /// + /// Logs entrance to a function + /// + internal static void Enter(TraceSource traceSource, string obj, string method, string param) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Enter(traceSource, obj + "::" + method + "(" + param + ")"); + } + + /// + /// Logs entrance to a function + /// + internal static void Enter(TraceSource traceSource, string obj, string method, object paramObject) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + string paramObjectValue = string.Empty; + if (paramObject != null) + { + paramObjectValue = GetObjectName(paramObject) + "#" + ValidationHelper.HashString(paramObject); + } + Enter(traceSource, obj + "::" + method + "(" + paramObjectValue + ")"); + } + + /// + /// Logs entrance to a function, indents and points that out + /// + internal static void Enter(TraceSource traceSource, string method, string parameters) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Enter(traceSource, method + "(" + parameters + ")"); + } + + /// + /// Logs entrance to a function, indents and points that out + /// + internal static void Enter(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + // Trace.CorrelationManager.StartLogicalOperation(); + PrintLine(traceSource, TraceEventType.Verbose, 0, msg); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, object obj, string method, object retObject) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + string retValue = string.Empty; + if (retObject != null) + { + retValue = GetObjectName(retObject) + "#" + ValidationHelper.HashString(retObject); + } + Exit(traceSource, obj, method, retValue); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, string obj, string method, object retObject) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + string retValue = string.Empty; + if (retObject != null) + { + retValue = GetObjectName(retObject) + "#" + ValidationHelper.HashString(retObject); + } + Exit(traceSource, obj, method, retValue); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, object obj, string method, string retValue) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Exit(traceSource, GetObjectName(obj) + "#" + ValidationHelper.HashString(obj), method, retValue); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, string obj, string method, string retValue) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + if (!string.IsNullOrEmpty(retValue)) + { + retValue = "\t-> " + retValue; + } + Exit(traceSource, obj + "::" + method + "() " + retValue); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, string method, string parameters) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Exit(traceSource, method + "() " + parameters); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + PrintLine(traceSource, TraceEventType.Verbose, 0, "Exiting " + msg); + // Trace.CorrelationManager.StopLogicalOperation(); + } + + /// + /// Logs Exception, restores indenting + /// + internal static void Exception(TraceSource traceSource, object obj, string method, Exception e) + { + if (!ValidateSettings(traceSource, TraceEventType.Error)) + { + return; + } + + string infoLine = SR.GetString(SR.net_log_exception, GetObjectLogHash(obj), method, e.Message); + if (!string.IsNullOrEmpty(e.StackTrace)) + { + infoLine += "\r\n" + e.StackTrace; + } + PrintLine(traceSource, TraceEventType.Error, 0, infoLine); + } + + /// + /// Logs an Info line + /// + internal static void PrintInfo(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + PrintLine(traceSource, TraceEventType.Information, 0, msg); + } + + /// + /// Logs an Info line + /// + internal static void PrintInfo(TraceSource traceSource, object obj, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + PrintLine(traceSource, TraceEventType.Information, 0, + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + + " - " + msg); + } + + /// + /// Logs an Info line + /// + internal static void PrintInfo(TraceSource traceSource, object obj, string method, string param) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + PrintLine(traceSource, TraceEventType.Information, 0, + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + + "::" + method + "(" + param + ")"); + } + + /// + /// Logs a Warning line + /// + internal static void PrintWarning(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Warning)) + { + return; + } + PrintLine(traceSource, TraceEventType.Warning, 0, msg); + } + + /// + /// Logs a Warning line + /// + internal static void PrintWarning(TraceSource traceSource, object obj, string method, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Warning)) + { + return; + } + PrintLine(traceSource, TraceEventType.Warning, 0, + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + + "::" + method + "() - " + msg); + } + + /// + /// Logs an Error line + /// + internal static void PrintError(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Error)) + { + return; + } + PrintLine(traceSource, TraceEventType.Error, 0, msg); + } + + /// + /// Logs an Error line + /// + internal static void PrintError(TraceSource traceSource, object obj, string method, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Error)) + { + return; + } + PrintLine(traceSource, TraceEventType.Error, 0, + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + + "::" + method + "() - " + msg); + } + + internal static string GetObjectLogHash(object obj) + { + return GetObjectName(obj) + "#" + ValidationHelper.HashString(obj); + } + + /// + /// Marhsalls a buffer ptr to an array and then dumps the byte array to the log + /// + internal static void Dump(TraceSource traceSource, object obj, string method, IntPtr bufferPtr, int length) + { + if (!ValidateSettings(traceSource, TraceEventType.Verbose) || bufferPtr == IntPtr.Zero || length < 0) + { + return; + } + byte[] buffer = new byte[length]; + Marshal.Copy(bufferPtr, buffer, 0, length); + Dump(traceSource, obj, method, buffer, 0, length); + } + + /// + /// Dumps a byte array to the log + /// + internal static void Dump(TraceSource traceSource, object obj, string method, byte[] buffer, int offset, int length) + { + if (!ValidateSettings(traceSource, TraceEventType.Verbose)) + { + return; + } + if (buffer == null) + { + PrintLine(traceSource, TraceEventType.Verbose, 0, "(null)"); + return; + } + if (offset > buffer.Length) + { + PrintLine(traceSource, TraceEventType.Verbose, 0, "(offset out of range)"); + return; + } + PrintLine(traceSource, TraceEventType.Verbose, 0, "Data from " + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + "::" + method); + int maxDumpSize = GetMaxDumpSizeSetting(traceSource); + if (length > maxDumpSize) + { + PrintLine(traceSource, TraceEventType.Verbose, 0, "(printing " + maxDumpSize.ToString(NumberFormatInfo.InvariantInfo) + " out of " + length.ToString(NumberFormatInfo.InvariantInfo) + ")"); + length = maxDumpSize; + } + if ((length < 0) || (length > buffer.Length - offset)) + { + length = buffer.Length - offset; + } + if (GetUseProtocolTextSetting(traceSource)) + { + string output = "<<" + HeaderEncoding.GetString(buffer, offset, length) + ">>"; + PrintLine(traceSource, TraceEventType.Verbose, 0, output); + return; + } + do + { + int n = Math.Min(length, 16); + string disp = String.Format(CultureInfo.CurrentCulture, "{0:X8} : ", offset); + for (int i = 0; i < n; ++i) + { + disp += String.Format(CultureInfo.CurrentCulture, "{0:X2}", buffer[offset + i]) + ((i == 7) ? '-' : ' '); + } + for (int i = n; i < 16; ++i) + { + disp += " "; + } + disp += ": "; + for (int i = 0; i < n; ++i) + { + disp += ((buffer[offset + i] < 0x20) || (buffer[offset + i] > 0x7e)) + ? '.' + : (char)(buffer[offset + i]); + } + PrintLine(traceSource, TraceEventType.Verbose, 0, disp); + offset += n; + length -= n; + } + while (length > 0); + } + + private class NclTraceSource : TraceSource + { + internal NclTraceSource(string name) : base(name) + { + } + /* + protected internal override string[] GetSupportedAttributes() + { + return Logging.SupportedAttributes; + }*/ + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/LoggingObject.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/LoggingObject.cs new file mode 100644 index 0000000000..dbbd6d5415 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/LoggingObject.cs @@ -0,0 +1,580 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +// We have function based stack and thread based logging of basic behavior. We +// also now have the ability to run a "watch thread" which does basic hang detection +// and error-event based logging. The logging code buffers the callstack/picture +// of all COMNET threads, and upon error from an assert or a hang, it will open a file +// and dump the snapsnot. Future work will allow this to be configed by registry and +// to use Runtime based logging. We'd also like to have different levels of logging. + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + + // BaseLoggingObject - used to disable logging, + // this is just a base class that does nothing. + + [Flags] + internal enum ThreadKinds + { + Unknown = 0x0000, + + // Mutually exclusive. + User = 0x0001, // Thread has entered via an API. + System = 0x0002, // Thread has entered via a system callback (e.g. completion port) or is our own thread. + + // Mutually exclusive. + Sync = 0x0004, // Thread should block. + Async = 0x0008, // Thread should not block. + + // Mutually exclusive, not always known for a user thread. Never changes. + Timer = 0x0010, // Thread is the timer thread. (Can't call user code.) + CompletionPort = 0x0020, // Thread is a ThreadPool completion-port thread. + Worker = 0x0040, // Thread is a ThreadPool worker thread. + Finalization = 0x0080, // Thread is the finalization thread. + Other = 0x0100, // Unknown source. + + OwnerMask = User | System, + SyncMask = Sync | Async, + SourceMask = Timer | CompletionPort | Worker | Finalization | Other, + + // Useful "macros" + SafeSources = SourceMask & ~(Timer | Finalization), // Methods that "unsafe" sources can call must be explicitly marked. + ThreadPool = CompletionPort | Worker, // Like Thread.CurrentThread.IsThreadPoolThread + } + + internal class BaseLoggingObject + { + internal BaseLoggingObject() + { + } + + internal virtual void EnterFunc(string funcname) + { + } + + internal virtual void LeaveFunc(string funcname) + { + } + + internal virtual void DumpArrayToConsole() + { + } + + internal virtual void PrintLine(string msg) + { + } + + internal virtual void DumpArray(bool shouldClose) + { + } + + internal virtual void DumpArrayToFile(bool shouldClose) + { + } + + internal virtual void Flush() + { + } + + internal virtual void Flush(bool close) + { + } + + internal virtual void LoggingMonitorTick() + { + } + + internal virtual void Dump(byte[] buffer) + { + } + + internal virtual void Dump(byte[] buffer, int length) + { + } + + internal virtual void Dump(byte[] buffer, int offset, int length) + { + } + + internal virtual void Dump(IntPtr pBuffer, int offset, int length) + { + } + } // class BaseLoggingObject + +#if TRAVE + /// + /// + /// + internal class LoggingObject : BaseLoggingObject { + public ArrayList _Logarray; + private Hashtable _ThreadNesting; + private int _AddCount; + private StreamWriter _Stream; + private int _IamAlive; + private int _LastIamAlive; + private bool _Finalized = false; + private double _NanosecondsPerTick; + private int _StartMilliseconds; + private long _StartTicks; + + internal LoggingObject() : base() { + _Logarray = new ArrayList(); + _ThreadNesting = new Hashtable(); + _AddCount = 0; + _IamAlive = 0; + _LastIamAlive = -1; + + if (GlobalLog.s_UsePerfCounter) { + long ticksPerSecond; + SafeNativeMethods.QueryPerformanceFrequency(out ticksPerSecond); + _NanosecondsPerTick = 10000000.0/(double)ticksPerSecond; + SafeNativeMethods.QueryPerformanceCounter(out _StartTicks); + } else { + _StartMilliseconds = Environment.TickCount; + } + } + + // + // LoggingMonitorTick - this function is run from the monitor thread, + // and used to check to see if there any hangs, ie no logging + // activitity + // + + internal override void LoggingMonitorTick() { + if ( _LastIamAlive == _IamAlive ) { + PrintLine("================= Error TIMEOUT - HANG DETECTED ================="); + DumpArray(true); + } + _LastIamAlive = _IamAlive; + } + + internal override void EnterFunc(string funcname) { + if (_Finalized) { + return; + } + IncNestingCount(); + ValidatePush(funcname); + PrintLine(funcname); + } + + internal override void LeaveFunc(string funcname) { + if (_Finalized) { + return; + } + PrintLine(funcname); + DecNestingCount(); + ValidatePop(funcname); + } + + internal override void DumpArrayToConsole() { + for (int i=0; i < _Logarray.Count; i++) { + Console.WriteLine((string) _Logarray[i]); + } + } + + internal override void PrintLine(string msg) { + if (_Finalized) { + return; + } + string spc = ""; + + _IamAlive++; + + spc = GetNestingString(); + + string tickString = ""; + + if (GlobalLog.s_UsePerfCounter) { + long nowTicks; + SafeNativeMethods.QueryPerformanceCounter(out nowTicks); + if (_StartTicks>nowTicks) { // counter reset, restart from 0 + _StartTicks = nowTicks; + } + nowTicks -= _StartTicks; + if (GlobalLog.s_UseTimeSpan) { + tickString = new TimeSpan((long)(nowTicks*_NanosecondsPerTick)).ToString(); + // note: TimeSpan().ToString() doesn't return the uSec part + // if its 0. .ToString() returns [H*]HH:MM:SS:uuuuuuu, hence 16 + if (tickString.Length < 16) { + tickString += ".0000000"; + } + } + else { + tickString = ((double)nowTicks*_NanosecondsPerTick/10000).ToString("f3"); + } + } + else { + int nowMilliseconds = Environment.TickCount; + if (_StartMilliseconds>nowMilliseconds) { + _StartMilliseconds = nowMilliseconds; + } + nowMilliseconds -= _StartMilliseconds; + if (GlobalLog.s_UseTimeSpan) { + tickString = new TimeSpan(nowMilliseconds*10000).ToString(); + // note: TimeSpan().ToString() doesn't return the uSec part + // if its 0. .ToString() returns [H*]HH:MM:SS:uuuuuuu, hence 16 + if (tickString.Length < 16) { + tickString += ".0000000"; + } + } + else { + tickString = nowMilliseconds.ToString(); + } + } + + uint threadId = 0; + + if (GlobalLog.s_UseThreadId) { + try { + object threadData = Thread.GetData(GlobalLog.s_ThreadIdSlot); + if (threadData!= null) { + threadId = (uint)threadData; + } + + } + catch(Exception exception) { + if (exception is ThreadAbortException || exception is StackOverflowException || exception is OutOfMemoryException) { + throw; + } + } + if (threadId == 0) { + threadId = UnsafeNclNativeMethods.GetCurrentThreadId(); + Thread.SetData(GlobalLog.s_ThreadIdSlot, threadId); + } + } + if (threadId == 0) { + threadId = (uint)Thread.CurrentThread.GetHashCode(); + } + + string str = "[" + threadId.ToString("x8") + "]" + " (" +tickString+ ") " + spc + msg; + + lock(this) { + _AddCount++; + _Logarray.Add(str); + int MaxLines = GlobalLog.s_DumpToConsole ? 0 : GlobalLog.MaxLinesBeforeSave; + if (_AddCount > MaxLines) { + _AddCount = 0; + DumpArray(false); + _Logarray = new ArrayList(); + } + } + } + + internal override void DumpArray(bool shouldClose) { + if ( GlobalLog.s_DumpToConsole ) { + DumpArrayToConsole(); + } else { + DumpArrayToFile(shouldClose); + } + } + + internal unsafe override void Dump(byte[] buffer, int offset, int length) { + //if (!GlobalLog.s_DumpWebData) { + // return; + //} + if (buffer==null) { + PrintLine("(null)"); + return; + } + if (offset > buffer.Length) { + PrintLine("(offset out of range)"); + return; + } + if (length > GlobalLog.s_MaxDumpSize) { + PrintLine("(printing " + GlobalLog.s_MaxDumpSize.ToString() + " out of " + length.ToString() + ")"); + length = GlobalLog.s_MaxDumpSize; + } + if ((length < 0) || (length > buffer.Length - offset)) { + length = buffer.Length - offset; + } + fixed (byte* pBuffer = buffer) { + Dump((IntPtr)pBuffer, offset, length); + } + } + + internal unsafe override void Dump(IntPtr pBuffer, int offset, int length) { + //if (!GlobalLog.s_DumpWebData) { + // return; + //} + if (pBuffer==IntPtr.Zero || length<0) { + PrintLine("(null)"); + return; + } + if (length > GlobalLog.s_MaxDumpSize) { + PrintLine("(printing " + GlobalLog.s_MaxDumpSize.ToString() + " out of " + length.ToString() + ")"); + length = GlobalLog.s_MaxDumpSize; + } + byte* buffer = (byte*)pBuffer + offset; + Dump(buffer, length); + } + + unsafe void Dump(byte* buffer, int length) { + do { + int offset = 0; + int n = Math.Min(length, 16); + string disp = ((IntPtr)buffer).ToString("X8") + " : " + offset.ToString("X8") + " : "; + byte current; + for (int i = 0; i < n; ++i) { + current = *(buffer + i); + disp += current.ToString("X2") + ((i == 7) ? '-' : ' '); + } + for (int i = n; i < 16; ++i) { + disp += " "; + } + disp += ": "; + for (int i = 0; i < n; ++i) { + current = *(buffer + i); + disp += ((current < 0x20) || (current > 0x7e)) ? '.' : (char)current; + } + PrintLine(disp); + offset += n; + buffer += n; + length -= n; + } while (length > 0); + } + + // SECURITY: This is dev-debugging class and we need some permissions + // to use it under trust-restricted environment as well. + [PermissionSet(SecurityAction.Assert, Name="FullTrust")] + internal override void DumpArrayToFile(bool shouldClose) { + lock (this) { + if (!shouldClose) { + if (_Stream==null) { + string mainLogFileRoot = GlobalLog.s_RootDirectory + "System.Net"; + string mainLogFile = mainLogFileRoot; + for (int k=0; k<20; k++) { + if (k>0) { + mainLogFile = mainLogFileRoot + "." + k.ToString(); + } + string fileName = mainLogFile + ".log"; + if (!File.Exists(fileName)) { + try { + _Stream = new StreamWriter(fileName); + break; + } + catch (Exception exception) { + if (exception is ThreadAbortException || exception is StackOverflowException || exception is OutOfMemoryException) { + throw; + } + if (exception is SecurityException || exception is UnauthorizedAccessException) { + // can't be CAS (we assert) this is an ACL issue + break; + } + } + } + } + if (_Stream==null) { + _Stream = StreamWriter.Null; + } + // write a header with information about the Process and the AppDomain + _Stream.Write("# MachineName: " + Environment.MachineName + "\r\n"); + _Stream.Write("# ProcessName: " + Process.GetCurrentProcess().ProcessName + " (pid: " + Process.GetCurrentProcess().Id + ")\r\n"); + _Stream.Write("# AppDomainId: " + AppDomain.CurrentDomain.Id + "\r\n"); + _Stream.Write("# CurrentIdentity: " + WindowsIdentity.GetCurrent().Name + "\r\n"); + _Stream.Write("# CommandLine: " + Environment.CommandLine + "\r\n"); + _Stream.Write("# ClrVersion: " + Environment.Version + "\r\n"); + _Stream.Write("# CreationDate: " + DateTime.Now.ToString("g") + "\r\n"); + } + } + try { + if (_Logarray!=null) { + for (int i=0; i<_Logarray.Count; i++) { + _Stream.Write((string)_Logarray[i]); + _Stream.Write("\r\n"); + } + + if (_Logarray.Count > 0 && _Stream != null) + _Stream.Flush(); + } + } + catch (Exception exception) { + if (exception is ThreadAbortException || exception is StackOverflowException || exception is OutOfMemoryException) { + throw; + } + } + if (shouldClose && _Stream!=null) { + try { + _Stream.Close(); + } + catch (ObjectDisposedException) { } + _Stream = null; + } + } + } + + internal override void Flush() { + Flush(false); + } + + internal override void Flush(bool close) { + lock (this) { + if (!GlobalLog.s_DumpToConsole) { + DumpArrayToFile(close); + _AddCount = 0; + } + } + } + + private class ThreadInfoData { + public ThreadInfoData(string indent) { + Indent = indent; + NestingStack = new Stack(); + } + public string Indent; + public Stack NestingStack; + }; + + string IndentString { + get { + string indent = " "; + Object obj = _ThreadNesting[Thread.CurrentThread.GetHashCode()]; + if (!GlobalLog.s_DebugCallNesting) { + if (obj == null) { + _ThreadNesting[Thread.CurrentThread.GetHashCode()] = indent; + } else { + indent = (String) obj; + } + } else { + ThreadInfoData threadInfo = obj as ThreadInfoData; + if (threadInfo == null) { + threadInfo = new ThreadInfoData(indent); + _ThreadNesting[Thread.CurrentThread.GetHashCode()] = threadInfo; + } + indent = threadInfo.Indent; + } + return indent; + } + set { + Object obj = _ThreadNesting[Thread.CurrentThread.GetHashCode()]; + if (obj == null) { + return; + } + if (!GlobalLog.s_DebugCallNesting) { + _ThreadNesting[Thread.CurrentThread.GetHashCode()] = value; + } else { + ThreadInfoData threadInfo = obj as ThreadInfoData; + if (threadInfo == null) { + threadInfo = new ThreadInfoData(value); + _ThreadNesting[Thread.CurrentThread.GetHashCode()] = threadInfo; + } + threadInfo.Indent = value; + } + } + } + + [System.Diagnostics.Conditional("TRAVE")] + private void IncNestingCount() { + IndentString = IndentString + " "; + } + + [System.Diagnostics.Conditional("TRAVE")] + private void DecNestingCount() { + string indent = IndentString; + if (indent.Length>1) { + try { + indent = indent.Substring(1); + } + catch { + indent = string.Empty; + } + } + if (indent.Length==0) { + indent = "< "; + } + IndentString = indent; + } + + private string GetNestingString() { + return IndentString; + } + + [System.Diagnostics.Conditional("TRAVE")] + private void ValidatePush(string name) { + if (GlobalLog.s_DebugCallNesting) { + Object obj = _ThreadNesting[Thread.CurrentThread.GetHashCode()]; + ThreadInfoData threadInfo = obj as ThreadInfoData; + if (threadInfo == null) { + return; + } + threadInfo.NestingStack.Push(name); + } + } + + [System.Diagnostics.Conditional("TRAVE")] + private void ValidatePop(string name) { + if (GlobalLog.s_DebugCallNesting) { + try { + Object obj = _ThreadNesting[Thread.CurrentThread.GetHashCode()]; + ThreadInfoData threadInfo = obj as ThreadInfoData; + if (threadInfo == null) { + return; + } + if (threadInfo.NestingStack.Count == 0) { + PrintLine("++++====" + "Poped Empty Stack for :"+name); + } + string popedName = (string) threadInfo.NestingStack.Pop(); + string [] parsedList = popedName.Split(new char [] {'(',')',' ','.',':',',','#'}); + foreach (string element in parsedList) { + if (element != null && element.Length > 1 && name.IndexOf(element) != -1) { + return; + } + } + PrintLine("++++====" + "Expected:" + popedName + ": got :" + name + ": StackSize:"+threadInfo.NestingStack.Count); + // relevel the stack + while(threadInfo.NestingStack.Count>0) { + string popedName2 = (string) threadInfo.NestingStack.Pop(); + string [] parsedList2 = popedName2.Split(new char [] {'(',')',' ','.',':',',','#'}); + foreach (string element2 in parsedList2) { + if (element2 != null && element2.Length > 1 && name.IndexOf(element2) != -1) { + return; + } + } + } + } + catch { + PrintLine("++++====" + "ValidatePop failed for: "+name); + } + } + } + + + ~LoggingObject() { + if(!_Finalized) { + _Finalized = true; + lock(this) { + DumpArray(true); + } + } + } + + + } // class LoggingObject + + internal static class TraveHelper { + private static readonly string Hexizer = "0x{0:x}"; + internal static string ToHex(object value) { + return String.Format(Hexizer, value); + } + } +#endif // TRAVE + +#if TRAVE + internal class IntegerSwitch : BooleanSwitch { + public IntegerSwitch(string switchName, string switchDescription) : base(switchName, switchDescription) { + } + public new int Value { + get { + return base.SwitchSetting; + } + } + } + +#endif + + // class GlobalLog +} // namespace System.Net diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/SR.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/SR.cs new file mode 100644 index 0000000000..5486b33602 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/SR.cs @@ -0,0 +1,630 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace System +{ + using System; + using System.Reflection; + using System.Globalization; + using System.Resources; + using System.Text; + using System.Threading; + using System.Security.Permissions; + using System.ComponentModel; + + /// + /// AutoGenerated resource class. Usage: + /// string s = SR.GetString(SR.MyIdenfitier); + /// + internal sealed class SR + { + internal const string security_ExtendedProtection_NoOSSupport = "security_ExtendedProtection_NoOSSupport"; + internal const string net_nonClsCompliantException = "net_nonClsCompliantException"; + internal const string net_illegalConfigWith = "net_illegalConfigWith"; + internal const string net_illegalConfigWithout = "net_illegalConfigWithout"; + internal const string net_baddate = "net_baddate"; + internal const string net_writestarted = "net_writestarted"; + internal const string net_clsmall = "net_clsmall"; + internal const string net_reqsubmitted = "net_reqsubmitted"; + internal const string net_rspsubmitted = "net_rspsubmitted"; + internal const string net_ftp_no_http_cmd = "net_ftp_no_http_cmd"; + internal const string net_ftp_invalid_method_name = "net_ftp_invalid_method_name"; + internal const string net_ftp_invalid_renameto = "net_ftp_invalid_renameto"; + internal const string net_ftp_no_defaultcreds = "net_ftp_no_defaultcreds"; + internal const string net_ftpnoresponse = "net_ftpnoresponse"; + internal const string net_ftp_response_invalid_format = "net_ftp_response_invalid_format"; + internal const string net_ftp_no_offsetforhttp = "net_ftp_no_offsetforhttp"; + internal const string net_ftp_invalid_uri = "net_ftp_invalid_uri"; + internal const string net_ftp_invalid_status_response = "net_ftp_invalid_status_response"; + internal const string net_ftp_server_failed_passive = "net_ftp_server_failed_passive"; + internal const string net_ftp_active_address_different = "net_ftp_active_address_different"; + internal const string net_ftp_proxy_does_not_support_ssl = "net_ftp_proxy_does_not_support_ssl"; + internal const string net_ftp_invalid_response_filename = "net_ftp_invalid_response_filename"; + internal const string net_ftp_unsupported_method = "net_ftp_unsupported_method"; + internal const string net_resubmitcanceled = "net_resubmitcanceled"; + internal const string net_redirect_perm = "net_redirect_perm"; + internal const string net_resubmitprotofailed = "net_resubmitprotofailed"; + internal const string net_needchunked = "net_needchunked"; + internal const string net_nochunked = "net_nochunked"; + internal const string net_nochunkuploadonhttp10 = "net_nochunkuploadonhttp10"; + internal const string net_connarg = "net_connarg"; + internal const string net_no100 = "net_no100"; + internal const string net_fromto = "net_fromto"; + internal const string net_rangetoosmall = "net_rangetoosmall"; + internal const string net_entitytoobig = "net_entitytoobig"; + internal const string net_invalidversion = "net_invalidversion"; + internal const string net_invalidstatus = "net_invalidstatus"; + internal const string net_toosmall = "net_toosmall"; + internal const string net_toolong = "net_toolong"; + internal const string net_connclosed = "net_connclosed"; + internal const string net_noseek = "net_noseek"; + internal const string net_servererror = "net_servererror"; + internal const string net_nouploadonget = "net_nouploadonget"; + internal const string net_mutualauthfailed = "net_mutualauthfailed"; + internal const string net_invasync = "net_invasync"; + internal const string net_inasync = "net_inasync"; + internal const string net_mustbeuri = "net_mustbeuri"; + internal const string net_format_shexp = "net_format_shexp"; + internal const string net_cannot_load_proxy_helper = "net_cannot_load_proxy_helper"; + internal const string net_invalid_host = "net_invalid_host"; + internal const string net_repcall = "net_repcall"; + internal const string net_wrongversion = "net_wrongversion"; + internal const string net_badmethod = "net_badmethod"; + internal const string net_io_notenoughbyteswritten = "net_io_notenoughbyteswritten"; + internal const string net_io_timeout_use_ge_zero = "net_io_timeout_use_ge_zero"; + internal const string net_io_timeout_use_gt_zero = "net_io_timeout_use_gt_zero"; + internal const string net_io_no_0timeouts = "net_io_no_0timeouts"; + internal const string net_requestaborted = "net_requestaborted"; + internal const string net_tooManyRedirections = "net_tooManyRedirections"; + internal const string net_authmodulenotregistered = "net_authmodulenotregistered"; + internal const string net_authschemenotregistered = "net_authschemenotregistered"; + internal const string net_proxyschemenotsupported = "net_proxyschemenotsupported"; + internal const string net_maxsrvpoints = "net_maxsrvpoints"; + internal const string net_unknown_prefix = "net_unknown_prefix"; + internal const string net_notconnected = "net_notconnected"; + internal const string net_notstream = "net_notstream"; + internal const string net_timeout = "net_timeout"; + internal const string net_nocontentlengthonget = "net_nocontentlengthonget"; + internal const string net_contentlengthmissing = "net_contentlengthmissing"; + internal const string net_nonhttpproxynotallowed = "net_nonhttpproxynotallowed"; + internal const string net_nottoken = "net_nottoken"; + internal const string net_rangetype = "net_rangetype"; + internal const string net_need_writebuffering = "net_need_writebuffering"; + internal const string net_securitypackagesupport = "net_securitypackagesupport"; + internal const string net_securityprotocolnotsupported = "net_securityprotocolnotsupported"; + internal const string net_nodefaultcreds = "net_nodefaultcreds"; + internal const string net_stopped = "net_stopped"; + internal const string net_udpconnected = "net_udpconnected"; + internal const string net_readonlystream = "net_readonlystream"; + internal const string net_writeonlystream = "net_writeonlystream"; + internal const string net_no_concurrent_io_allowed = "net_no_concurrent_io_allowed"; + internal const string net_needmorethreads = "net_needmorethreads"; + internal const string net_MethodNotImplementedException = "net_MethodNotImplementedException"; + internal const string net_PropertyNotImplementedException = "net_PropertyNotImplementedException"; + internal const string net_MethodNotSupportedException = "net_MethodNotSupportedException"; + internal const string net_PropertyNotSupportedException = "net_PropertyNotSupportedException"; + internal const string net_ProtocolNotSupportedException = "net_ProtocolNotSupportedException"; + internal const string net_SelectModeNotSupportedException = "net_SelectModeNotSupportedException"; + internal const string net_InvalidSocketHandle = "net_InvalidSocketHandle"; + internal const string net_InvalidAddressFamily = "net_InvalidAddressFamily"; + internal const string net_InvalidEndPointAddressFamily = "net_InvalidEndPointAddressFamily"; + internal const string net_InvalidSocketAddressSize = "net_InvalidSocketAddressSize"; + internal const string net_invalidAddressList = "net_invalidAddressList"; + internal const string net_invalidPingBufferSize = "net_invalidPingBufferSize"; + internal const string net_cant_perform_during_shutdown = "net_cant_perform_during_shutdown"; + internal const string net_cant_create_environment = "net_cant_create_environment"; + internal const string net_completed_result = "net_completed_result"; + internal const string net_protocol_invalid_family = "net_protocol_invalid_family"; + internal const string net_protocol_invalid_multicast_family = "net_protocol_invalid_multicast_family"; + internal const string net_empty_osinstalltype = "net_empty_osinstalltype"; + internal const string net_unknown_osinstalltype = "net_unknown_osinstalltype"; + internal const string net_cant_determine_osinstalltype = "net_cant_determine_osinstalltype"; + internal const string net_osinstalltype = "net_osinstalltype"; + internal const string net_entire_body_not_written = "net_entire_body_not_written"; + internal const string net_must_provide_request_body = "net_must_provide_request_body"; + internal const string net_ssp_dont_support_cbt = "net_ssp_dont_support_cbt"; + internal const string net_sockets_zerolist = "net_sockets_zerolist"; + internal const string net_sockets_blocking = "net_sockets_blocking"; + internal const string net_sockets_useblocking = "net_sockets_useblocking"; + internal const string net_sockets_select = "net_sockets_select"; + internal const string net_sockets_toolarge_select = "net_sockets_toolarge_select"; + internal const string net_sockets_empty_select = "net_sockets_empty_select"; + internal const string net_sockets_mustbind = "net_sockets_mustbind"; + internal const string net_sockets_mustlisten = "net_sockets_mustlisten"; + internal const string net_sockets_mustnotlisten = "net_sockets_mustnotlisten"; + internal const string net_sockets_mustnotbebound = "net_sockets_mustnotbebound"; + internal const string net_sockets_namedmustnotbebound = "net_sockets_namedmustnotbebound"; + internal const string net_sockets_invalid_socketinformation = "net_sockets_invalid_socketinformation"; + internal const string net_sockets_invalid_ipaddress_length = "net_sockets_invalid_ipaddress_length"; + internal const string net_sockets_invalid_optionValue = "net_sockets_invalid_optionValue"; + internal const string net_sockets_invalid_optionValue_all = "net_sockets_invalid_optionValue_all"; + internal const string net_sockets_invalid_dnsendpoint = "net_sockets_invalid_dnsendpoint"; + internal const string net_sockets_disconnectedConnect = "net_sockets_disconnectedConnect"; + internal const string net_sockets_disconnectedAccept = "net_sockets_disconnectedAccept"; + internal const string net_tcplistener_mustbestopped = "net_tcplistener_mustbestopped"; + internal const string net_sockets_no_duplicate_async = "net_sockets_no_duplicate_async"; + internal const string net_socketopinprogress = "net_socketopinprogress"; + internal const string net_buffercounttoosmall = "net_buffercounttoosmall"; + internal const string net_multibuffernotsupported = "net_multibuffernotsupported"; + internal const string net_ambiguousbuffers = "net_ambiguousbuffers"; + internal const string net_sockets_ipv6only = "net_sockets_ipv6only"; + internal const string net_perfcounter_initialized_success = "net_perfcounter_initialized_success"; + internal const string net_perfcounter_initialized_error = "net_perfcounter_initialized_error"; + internal const string net_perfcounter_nocategory = "net_perfcounter_nocategory"; + internal const string net_perfcounter_initialization_started = "net_perfcounter_initialization_started"; + internal const string net_perfcounter_cant_queue_workitem = "net_perfcounter_cant_queue_workitem"; + internal const string net_config_proxy = "net_config_proxy"; + internal const string net_config_proxy_module_not_public = "net_config_proxy_module_not_public"; + internal const string net_config_authenticationmodules = "net_config_authenticationmodules"; + internal const string net_config_webrequestmodules = "net_config_webrequestmodules"; + internal const string net_config_requestcaching = "net_config_requestcaching"; + internal const string net_config_section_permission = "net_config_section_permission"; + internal const string net_config_element_permission = "net_config_element_permission"; + internal const string net_config_property_permission = "net_config_property_permission"; + internal const string net_WebResponseParseError_InvalidHeaderName = "net_WebResponseParseError_InvalidHeaderName"; + internal const string net_WebResponseParseError_InvalidContentLength = "net_WebResponseParseError_InvalidContentLength"; + internal const string net_WebResponseParseError_IncompleteHeaderLine = "net_WebResponseParseError_IncompleteHeaderLine"; + internal const string net_WebResponseParseError_CrLfError = "net_WebResponseParseError_CrLfError"; + internal const string net_WebResponseParseError_InvalidChunkFormat = "net_WebResponseParseError_InvalidChunkFormat"; + internal const string net_WebResponseParseError_UnexpectedServerResponse = "net_WebResponseParseError_UnexpectedServerResponse"; + internal const string net_webstatus_Success = "net_webstatus_Success"; + internal const string net_webstatus_NameResolutionFailure = "net_webstatus_NameResolutionFailure"; + internal const string net_webstatus_ConnectFailure = "net_webstatus_ConnectFailure"; + internal const string net_webstatus_ReceiveFailure = "net_webstatus_ReceiveFailure"; + internal const string net_webstatus_SendFailure = "net_webstatus_SendFailure"; + internal const string net_webstatus_PipelineFailure = "net_webstatus_PipelineFailure"; + internal const string net_webstatus_RequestCanceled = "net_webstatus_RequestCanceled"; + internal const string net_webstatus_ConnectionClosed = "net_webstatus_ConnectionClosed"; + internal const string net_webstatus_TrustFailure = "net_webstatus_TrustFailure"; + internal const string net_webstatus_SecureChannelFailure = "net_webstatus_SecureChannelFailure"; + internal const string net_webstatus_ServerProtocolViolation = "net_webstatus_ServerProtocolViolation"; + internal const string net_webstatus_KeepAliveFailure = "net_webstatus_KeepAliveFailure"; + internal const string net_webstatus_ProxyNameResolutionFailure = "net_webstatus_ProxyNameResolutionFailure"; + internal const string net_webstatus_MessageLengthLimitExceeded = "net_webstatus_MessageLengthLimitExceeded"; + internal const string net_webstatus_CacheEntryNotFound = "net_webstatus_CacheEntryNotFound"; + internal const string net_webstatus_RequestProhibitedByCachePolicy = "net_webstatus_RequestProhibitedByCachePolicy"; + internal const string net_webstatus_Timeout = "net_webstatus_Timeout"; + internal const string net_webstatus_RequestProhibitedByProxy = "net_webstatus_RequestProhibitedByProxy"; + internal const string net_InvalidStatusCode = "net_InvalidStatusCode"; + internal const string net_ftpstatuscode_ServiceNotAvailable = "net_ftpstatuscode_ServiceNotAvailable"; + internal const string net_ftpstatuscode_CantOpenData = "net_ftpstatuscode_CantOpenData"; + internal const string net_ftpstatuscode_ConnectionClosed = "net_ftpstatuscode_ConnectionClosed"; + internal const string net_ftpstatuscode_ActionNotTakenFileUnavailableOrBusy = "net_ftpstatuscode_ActionNotTakenFileUnavailableOrBusy"; + internal const string net_ftpstatuscode_ActionAbortedLocalProcessingError = "net_ftpstatuscode_ActionAbortedLocalProcessingError"; + internal const string net_ftpstatuscode_ActionNotTakenInsufficentSpace = "net_ftpstatuscode_ActionNotTakenInsufficentSpace"; + internal const string net_ftpstatuscode_CommandSyntaxError = "net_ftpstatuscode_CommandSyntaxError"; + internal const string net_ftpstatuscode_ArgumentSyntaxError = "net_ftpstatuscode_ArgumentSyntaxError"; + internal const string net_ftpstatuscode_CommandNotImplemented = "net_ftpstatuscode_CommandNotImplemented"; + internal const string net_ftpstatuscode_BadCommandSequence = "net_ftpstatuscode_BadCommandSequence"; + internal const string net_ftpstatuscode_NotLoggedIn = "net_ftpstatuscode_NotLoggedIn"; + internal const string net_ftpstatuscode_AccountNeeded = "net_ftpstatuscode_AccountNeeded"; + internal const string net_ftpstatuscode_ActionNotTakenFileUnavailable = "net_ftpstatuscode_ActionNotTakenFileUnavailable"; + internal const string net_ftpstatuscode_ActionAbortedUnknownPageType = "net_ftpstatuscode_ActionAbortedUnknownPageType"; + internal const string net_ftpstatuscode_FileActionAborted = "net_ftpstatuscode_FileActionAborted"; + internal const string net_ftpstatuscode_ActionNotTakenFilenameNotAllowed = "net_ftpstatuscode_ActionNotTakenFilenameNotAllowed"; + internal const string net_httpstatuscode_NoContent = "net_httpstatuscode_NoContent"; + internal const string net_httpstatuscode_NonAuthoritativeInformation = "net_httpstatuscode_NonAuthoritativeInformation"; + internal const string net_httpstatuscode_ResetContent = "net_httpstatuscode_ResetContent"; + internal const string net_httpstatuscode_PartialContent = "net_httpstatuscode_PartialContent"; + internal const string net_httpstatuscode_MultipleChoices = "net_httpstatuscode_MultipleChoices"; + internal const string net_httpstatuscode_Ambiguous = "net_httpstatuscode_Ambiguous"; + internal const string net_httpstatuscode_MovedPermanently = "net_httpstatuscode_MovedPermanently"; + internal const string net_httpstatuscode_Moved = "net_httpstatuscode_Moved"; + internal const string net_httpstatuscode_Found = "net_httpstatuscode_Found"; + internal const string net_httpstatuscode_Redirect = "net_httpstatuscode_Redirect"; + internal const string net_httpstatuscode_SeeOther = "net_httpstatuscode_SeeOther"; + internal const string net_httpstatuscode_RedirectMethod = "net_httpstatuscode_RedirectMethod"; + internal const string net_httpstatuscode_NotModified = "net_httpstatuscode_NotModified"; + internal const string net_httpstatuscode_UseProxy = "net_httpstatuscode_UseProxy"; + internal const string net_httpstatuscode_TemporaryRedirect = "net_httpstatuscode_TemporaryRedirect"; + internal const string net_httpstatuscode_RedirectKeepVerb = "net_httpstatuscode_RedirectKeepVerb"; + internal const string net_httpstatuscode_BadRequest = "net_httpstatuscode_BadRequest"; + internal const string net_httpstatuscode_Unauthorized = "net_httpstatuscode_Unauthorized"; + internal const string net_httpstatuscode_PaymentRequired = "net_httpstatuscode_PaymentRequired"; + internal const string net_httpstatuscode_Forbidden = "net_httpstatuscode_Forbidden"; + internal const string net_httpstatuscode_NotFound = "net_httpstatuscode_NotFound"; + internal const string net_httpstatuscode_MethodNotAllowed = "net_httpstatuscode_MethodNotAllowed"; + internal const string net_httpstatuscode_NotAcceptable = "net_httpstatuscode_NotAcceptable"; + internal const string net_httpstatuscode_ProxyAuthenticationRequired = "net_httpstatuscode_ProxyAuthenticationRequired"; + internal const string net_httpstatuscode_RequestTimeout = "net_httpstatuscode_RequestTimeout"; + internal const string net_httpstatuscode_Conflict = "net_httpstatuscode_Conflict"; + internal const string net_httpstatuscode_Gone = "net_httpstatuscode_Gone"; + internal const string net_httpstatuscode_LengthRequired = "net_httpstatuscode_LengthRequired"; + internal const string net_httpstatuscode_InternalServerError = "net_httpstatuscode_InternalServerError"; + internal const string net_httpstatuscode_NotImplemented = "net_httpstatuscode_NotImplemented"; + internal const string net_httpstatuscode_BadGateway = "net_httpstatuscode_BadGateway"; + internal const string net_httpstatuscode_ServiceUnavailable = "net_httpstatuscode_ServiceUnavailable"; + internal const string net_httpstatuscode_GatewayTimeout = "net_httpstatuscode_GatewayTimeout"; + internal const string net_httpstatuscode_HttpVersionNotSupported = "net_httpstatuscode_HttpVersionNotSupported"; + internal const string net_uri_BadScheme = "net_uri_BadScheme"; + internal const string net_uri_BadFormat = "net_uri_BadFormat"; + internal const string net_uri_BadUserPassword = "net_uri_BadUserPassword"; + internal const string net_uri_BadHostName = "net_uri_BadHostName"; + internal const string net_uri_BadAuthority = "net_uri_BadAuthority"; + internal const string net_uri_BadAuthorityTerminator = "net_uri_BadAuthorityTerminator"; + internal const string net_uri_EmptyUri = "net_uri_EmptyUri"; + internal const string net_uri_BadString = "net_uri_BadString"; + internal const string net_uri_MustRootedPath = "net_uri_MustRootedPath"; + internal const string net_uri_BadPort = "net_uri_BadPort"; + internal const string net_uri_SizeLimit = "net_uri_SizeLimit"; + internal const string net_uri_SchemeLimit = "net_uri_SchemeLimit"; + internal const string net_uri_NotAbsolute = "net_uri_NotAbsolute"; + internal const string net_uri_PortOutOfRange = "net_uri_PortOutOfRange"; + internal const string net_uri_UserDrivenParsing = "net_uri_UserDrivenParsing"; + internal const string net_uri_AlreadyRegistered = "net_uri_AlreadyRegistered"; + internal const string net_uri_NeedFreshParser = "net_uri_NeedFreshParser"; + internal const string net_uri_CannotCreateRelative = "net_uri_CannotCreateRelative"; + internal const string net_uri_InvalidUriKind = "net_uri_InvalidUriKind"; + internal const string net_uri_BadUnicodeHostForIdn = "net_uri_BadUnicodeHostForIdn"; + internal const string net_uri_GenericAuthorityNotDnsSafe = "net_uri_GenericAuthorityNotDnsSafe"; + internal const string net_uri_NotJustSerialization = "net_uri_NotJustSerialization"; + internal const string net_emptystringset = "net_emptystringset"; + internal const string net_emptystringcall = "net_emptystringcall"; + internal const string net_headers_req = "net_headers_req"; + internal const string net_headers_rsp = "net_headers_rsp"; + internal const string net_headers_toolong = "net_headers_toolong"; + internal const string net_WebHeaderInvalidControlChars = "net_WebHeaderInvalidControlChars"; + internal const string net_WebHeaderInvalidCRLFChars = "net_WebHeaderInvalidCRLFChars"; + internal const string net_WebHeaderInvalidHeaderChars = "net_WebHeaderInvalidHeaderChars"; + internal const string net_WebHeaderInvalidNonAsciiChars = "net_WebHeaderInvalidNonAsciiChars"; + internal const string net_WebHeaderMissingColon = "net_WebHeaderMissingColon"; + internal const string net_headerrestrict = "net_headerrestrict"; + internal const string net_io_completionportwasbound = "net_io_completionportwasbound"; + internal const string net_io_writefailure = "net_io_writefailure"; + internal const string net_io_readfailure = "net_io_readfailure"; + internal const string net_io_connectionclosed = "net_io_connectionclosed"; + internal const string net_io_transportfailure = "net_io_transportfailure"; + internal const string net_io_internal_bind = "net_io_internal_bind"; + internal const string net_io_invalidasyncresult = "net_io_invalidasyncresult"; + internal const string net_io_invalidnestedcall = "net_io_invalidnestedcall"; + internal const string net_io_invalidendcall = "net_io_invalidendcall"; + internal const string net_io_must_be_rw_stream = "net_io_must_be_rw_stream"; + internal const string net_io_header_id = "net_io_header_id"; + internal const string net_io_out_range = "net_io_out_range"; + internal const string net_io_encrypt = "net_io_encrypt"; + internal const string net_io_decrypt = "net_io_decrypt"; + internal const string net_io_read = "net_io_read"; + internal const string net_io_write = "net_io_write"; + internal const string net_io_eof = "net_io_eof"; + internal const string net_io_async_result = "net_io_async_result"; + internal const string net_listener_mustcall = "net_listener_mustcall"; + internal const string net_listener_mustcompletecall = "net_listener_mustcompletecall"; + internal const string net_listener_callinprogress = "net_listener_callinprogress"; + internal const string net_listener_scheme = "net_listener_scheme"; + internal const string net_listener_host = "net_listener_host"; + internal const string net_listener_slash = "net_listener_slash"; + internal const string net_listener_repcall = "net_listener_repcall"; + internal const string net_listener_invalid_cbt_type = "net_listener_invalid_cbt_type"; + internal const string net_listener_no_spns = "net_listener_no_spns"; + internal const string net_listener_cannot_set_custom_cbt = "net_listener_cannot_set_custom_cbt"; + internal const string net_listener_cbt_not_supported = "net_listener_cbt_not_supported"; + internal const string net_listener_detach_error = "net_listener_detach_error"; + internal const string net_listener_close_urlgroup_error = "net_listener_close_urlgroup_error"; + internal const string net_tls_version = "net_tls_version"; + internal const string net_perm_target = "net_perm_target"; + internal const string net_perm_both_regex = "net_perm_both_regex"; + internal const string net_perm_none = "net_perm_none"; + internal const string net_perm_attrib_count = "net_perm_attrib_count"; + internal const string net_perm_invalid_val = "net_perm_invalid_val"; + internal const string net_perm_attrib_multi = "net_perm_attrib_multi"; + internal const string net_perm_epname = "net_perm_epname"; + internal const string net_perm_invalid_val_in_element = "net_perm_invalid_val_in_element"; + internal const string net_invalid_ip_addr = "net_invalid_ip_addr"; + internal const string dns_bad_ip_address = "dns_bad_ip_address"; + internal const string net_bad_mac_address = "net_bad_mac_address"; + internal const string net_ping = "net_ping"; + internal const string net_bad_ip_address_prefix = "net_bad_ip_address_prefix"; + internal const string net_max_ip_address_list_length_exceeded = "net_max_ip_address_list_length_exceeded"; + internal const string net_ipv4_not_installed = "net_ipv4_not_installed"; + internal const string net_ipv6_not_installed = "net_ipv6_not_installed"; + internal const string net_webclient = "net_webclient"; + internal const string net_webclient_ContentType = "net_webclient_ContentType"; + internal const string net_webclient_Multipart = "net_webclient_Multipart"; + internal const string net_webclient_no_concurrent_io_allowed = "net_webclient_no_concurrent_io_allowed"; + internal const string net_webclient_invalid_baseaddress = "net_webclient_invalid_baseaddress"; + internal const string net_container_add_cookie = "net_container_add_cookie"; + internal const string net_cookie_invalid = "net_cookie_invalid"; + internal const string net_cookie_size = "net_cookie_size"; + internal const string net_cookie_parse_header = "net_cookie_parse_header"; + internal const string net_cookie_attribute = "net_cookie_attribute"; + internal const string net_cookie_format = "net_cookie_format"; + internal const string net_cookie_exists = "net_cookie_exists"; + internal const string net_cookie_capacity_range = "net_cookie_capacity_range"; + internal const string net_set_token = "net_set_token"; + internal const string net_revert_token = "net_revert_token"; + internal const string net_ssl_io_async_context = "net_ssl_io_async_context"; + internal const string net_ssl_io_encrypt = "net_ssl_io_encrypt"; + internal const string net_ssl_io_decrypt = "net_ssl_io_decrypt"; + internal const string net_ssl_io_context_expired = "net_ssl_io_context_expired"; + internal const string net_ssl_io_handshake_start = "net_ssl_io_handshake_start"; + internal const string net_ssl_io_handshake = "net_ssl_io_handshake"; + internal const string net_ssl_io_frame = "net_ssl_io_frame"; + internal const string net_ssl_io_corrupted = "net_ssl_io_corrupted"; + internal const string net_ssl_io_cert_validation = "net_ssl_io_cert_validation"; + internal const string net_ssl_io_invalid_end_call = "net_ssl_io_invalid_end_call"; + internal const string net_ssl_io_invalid_begin_call = "net_ssl_io_invalid_begin_call"; + internal const string net_ssl_io_no_server_cert = "net_ssl_io_no_server_cert"; + internal const string net_auth_bad_client_creds = "net_auth_bad_client_creds"; + internal const string net_auth_bad_client_creds_or_target_mismatch = "net_auth_bad_client_creds_or_target_mismatch"; + internal const string net_auth_context_expectation = "net_auth_context_expectation"; + internal const string net_auth_context_expectation_remote = "net_auth_context_expectation_remote"; + internal const string net_auth_supported_impl_levels = "net_auth_supported_impl_levels"; + internal const string net_auth_no_anonymous_support = "net_auth_no_anonymous_support"; + internal const string net_auth_reauth = "net_auth_reauth"; + internal const string net_auth_noauth = "net_auth_noauth"; + internal const string net_auth_client_server = "net_auth_client_server"; + internal const string net_auth_noencryption = "net_auth_noencryption"; + internal const string net_auth_SSPI = "net_auth_SSPI"; + internal const string net_auth_failure = "net_auth_failure"; + internal const string net_auth_eof = "net_auth_eof"; + internal const string net_auth_alert = "net_auth_alert"; + internal const string net_auth_ignored_reauth = "net_auth_ignored_reauth"; + internal const string net_auth_empty_read = "net_auth_empty_read"; + internal const string net_auth_message_not_encrypted = "net_auth_message_not_encrypted"; + internal const string net_auth_must_specify_extended_protection_scheme = "net_auth_must_specify_extended_protection_scheme"; + internal const string net_frame_size = "net_frame_size"; + internal const string net_frame_read_io = "net_frame_read_io"; + internal const string net_frame_read_size = "net_frame_read_size"; + internal const string net_frame_max_size = "net_frame_max_size"; + internal const string net_jscript_load = "net_jscript_load"; + internal const string net_proxy_not_gmt = "net_proxy_not_gmt"; + internal const string net_proxy_invalid_dayofweek = "net_proxy_invalid_dayofweek"; + internal const string net_proxy_invalid_url_format = "net_proxy_invalid_url_format"; + internal const string net_param_not_string = "net_param_not_string"; + internal const string net_value_cannot_be_negative = "net_value_cannot_be_negative"; + internal const string net_invalid_offset = "net_invalid_offset"; + internal const string net_offset_plus_count = "net_offset_plus_count"; + internal const string net_cannot_be_false = "net_cannot_be_false"; + internal const string net_invalid_enum = "net_invalid_enum"; + internal const string net_listener_already = "net_listener_already"; + internal const string net_cache_shadowstream_not_writable = "net_cache_shadowstream_not_writable"; + internal const string net_cache_validator_fail = "net_cache_validator_fail"; + internal const string net_cache_access_denied = "net_cache_access_denied"; + internal const string net_cache_validator_result = "net_cache_validator_result"; + internal const string net_cache_retrieve_failure = "net_cache_retrieve_failure"; + internal const string net_cache_not_supported_body = "net_cache_not_supported_body"; + internal const string net_cache_not_supported_command = "net_cache_not_supported_command"; + internal const string net_cache_not_accept_response = "net_cache_not_accept_response"; + internal const string net_cache_method_failed = "net_cache_method_failed"; + internal const string net_cache_key_failed = "net_cache_key_failed"; + internal const string net_cache_no_stream = "net_cache_no_stream"; + internal const string net_cache_unsupported_partial_stream = "net_cache_unsupported_partial_stream"; + internal const string net_cache_not_configured = "net_cache_not_configured"; + internal const string net_cache_non_seekable_stream_not_supported = "net_cache_non_seekable_stream_not_supported"; + internal const string net_invalid_cast = "net_invalid_cast"; + internal const string net_collection_readonly = "net_collection_readonly"; + internal const string net_not_ipermission = "net_not_ipermission"; + internal const string net_no_classname = "net_no_classname"; + internal const string net_no_typename = "net_no_typename"; + internal const string net_array_too_small = "net_array_too_small"; + internal const string net_servicePointAddressNotSupportedInHostMode = "net_servicePointAddressNotSupportedInHostMode"; + internal const string net_Websockets_AlreadyOneOutstandingOperation = "net_Websockets_AlreadyOneOutstandingOperation"; + internal const string net_Websockets_WebSocketBaseFaulted = "net_Websockets_WebSocketBaseFaulted"; + internal const string net_WebSockets_NativeSendResponseHeaders = "net_WebSockets_NativeSendResponseHeaders"; + internal const string net_WebSockets_Generic = "net_WebSockets_Generic"; + internal const string net_WebSockets_NotAWebSocket_Generic = "net_WebSockets_NotAWebSocket_Generic"; + internal const string net_WebSockets_UnsupportedWebSocketVersion_Generic = "net_WebSockets_UnsupportedWebSocketVersion_Generic"; + internal const string net_WebSockets_HeaderError_Generic = "net_WebSockets_HeaderError_Generic"; + internal const string net_WebSockets_UnsupportedProtocol_Generic = "net_WebSockets_UnsupportedProtocol_Generic"; + internal const string net_WebSockets_UnsupportedPlatform = "net_WebSockets_UnsupportedPlatform"; + internal const string net_WebSockets_AcceptNotAWebSocket = "net_WebSockets_AcceptNotAWebSocket"; + internal const string net_WebSockets_AcceptUnsupportedWebSocketVersion = "net_WebSockets_AcceptUnsupportedWebSocketVersion"; + internal const string net_WebSockets_AcceptHeaderNotFound = "net_WebSockets_AcceptHeaderNotFound"; + internal const string net_WebSockets_AcceptUnsupportedProtocol = "net_WebSockets_AcceptUnsupportedProtocol"; + internal const string net_WebSockets_ClientAcceptingNoProtocols = "net_WebSockets_ClientAcceptingNoProtocols"; + internal const string net_WebSockets_ClientSecWebSocketProtocolsBlank = "net_WebSockets_ClientSecWebSocketProtocolsBlank"; + internal const string net_WebSockets_ArgumentOutOfRange_TooSmall = "net_WebSockets_ArgumentOutOfRange_TooSmall"; + internal const string net_WebSockets_ArgumentOutOfRange_InternalBuffer = "net_WebSockets_ArgumentOutOfRange_InternalBuffer"; + internal const string net_WebSockets_ArgumentOutOfRange_TooBig = "net_WebSockets_ArgumentOutOfRange_TooBig"; + internal const string net_WebSockets_InvalidState_Generic = "net_WebSockets_InvalidState_Generic"; + internal const string net_WebSockets_InvalidState_ClosedOrAborted = "net_WebSockets_InvalidState_ClosedOrAborted"; + internal const string net_WebSockets_InvalidState = "net_WebSockets_InvalidState"; + internal const string net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync = "net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync"; + internal const string net_WebSockets_InvalidMessageType = "net_WebSockets_InvalidMessageType"; + internal const string net_WebSockets_InvalidBufferType = "net_WebSockets_InvalidBufferType"; + internal const string net_WebSockets_InvalidMessageType_Generic = "net_WebSockets_InvalidMessageType_Generic"; + internal const string net_WebSockets_Argument_InvalidMessageType = "net_WebSockets_Argument_InvalidMessageType"; + internal const string net_WebSockets_ConnectionClosedPrematurely_Generic = "net_WebSockets_ConnectionClosedPrematurely_Generic"; + internal const string net_WebSockets_InvalidCharInProtocolString = "net_WebSockets_InvalidCharInProtocolString"; + internal const string net_WebSockets_InvalidEmptySubProtocol = "net_WebSockets_InvalidEmptySubProtocol"; + internal const string net_WebSockets_ReasonNotNull = "net_WebSockets_ReasonNotNull"; + internal const string net_WebSockets_InvalidCloseStatusCode = "net_WebSockets_InvalidCloseStatusCode"; + internal const string net_WebSockets_InvalidCloseStatusDescription = "net_WebSockets_InvalidCloseStatusDescription"; + internal const string net_WebSockets_Scheme = "net_WebSockets_Scheme"; + internal const string net_WebSockets_AlreadyStarted = "net_WebSockets_AlreadyStarted"; + internal const string net_WebSockets_Connect101Expected = "net_WebSockets_Connect101Expected"; + internal const string net_WebSockets_InvalidResponseHeader = "net_WebSockets_InvalidResponseHeader"; + internal const string net_WebSockets_NotConnected = "net_WebSockets_NotConnected"; + internal const string net_WebSockets_InvalidRegistration = "net_WebSockets_InvalidRegistration"; + internal const string net_WebSockets_NoDuplicateProtocol = "net_WebSockets_NoDuplicateProtocol"; + internal const string net_log_exception = "net_log_exception"; + internal const string net_log_listener_delegate_exception = "net_log_listener_delegate_exception"; + internal const string net_log_listener_unsupported_authentication_scheme = "net_log_listener_unsupported_authentication_scheme"; + internal const string net_log_listener_unmatched_authentication_scheme = "net_log_listener_unmatched_authentication_scheme"; + internal const string net_log_listener_create_valid_identity_failed = "net_log_listener_create_valid_identity_failed"; + internal const string net_log_listener_httpsys_registry_null = "net_log_listener_httpsys_registry_null"; + internal const string net_log_listener_httpsys_registry_error = "net_log_listener_httpsys_registry_error"; + internal const string net_log_listener_cant_convert_raw_path = "net_log_listener_cant_convert_raw_path"; + internal const string net_log_listener_cant_convert_percent_value = "net_log_listener_cant_convert_percent_value"; + internal const string net_log_listener_cant_convert_bytes = "net_log_listener_cant_convert_bytes"; + internal const string net_log_listener_cant_convert_to_utf8 = "net_log_listener_cant_convert_to_utf8"; + internal const string net_log_listener_cant_create_uri = "net_log_listener_cant_create_uri"; + internal const string net_log_listener_no_cbt_disabled = "net_log_listener_no_cbt_disabled"; + internal const string net_log_listener_no_cbt_http = "net_log_listener_no_cbt_http"; + internal const string net_log_listener_no_cbt_platform = "net_log_listener_no_cbt_platform"; + internal const string net_log_listener_no_cbt_trustedproxy = "net_log_listener_no_cbt_trustedproxy"; + internal const string net_log_listener_cbt = "net_log_listener_cbt"; + internal const string net_log_listener_no_spn_kerberos = "net_log_listener_no_spn_kerberos"; + internal const string net_log_listener_no_spn_disabled = "net_log_listener_no_spn_disabled"; + internal const string net_log_listener_no_spn_cbt = "net_log_listener_no_spn_cbt"; + internal const string net_log_listener_no_spn_platform = "net_log_listener_no_spn_platform"; + internal const string net_log_listener_no_spn_whensupported = "net_log_listener_no_spn_whensupported"; + internal const string net_log_listener_no_spn_loopback = "net_log_listener_no_spn_loopback"; + internal const string net_log_listener_spn = "net_log_listener_spn"; + internal const string net_log_listener_spn_passed = "net_log_listener_spn_passed"; + internal const string net_log_listener_spn_failed = "net_log_listener_spn_failed"; + internal const string net_log_listener_spn_failed_always = "net_log_listener_spn_failed_always"; + internal const string net_log_listener_spn_failed_empty = "net_log_listener_spn_failed_empty"; + internal const string net_log_listener_spn_failed_dump = "net_log_listener_spn_failed_dump"; + internal const string net_log_listener_spn_add = "net_log_listener_spn_add"; + internal const string net_log_listener_spn_not_add = "net_log_listener_spn_not_add"; + internal const string net_log_listener_spn_remove = "net_log_listener_spn_remove"; + internal const string net_log_listener_spn_not_remove = "net_log_listener_spn_not_remove"; + internal const string net_log_sspi_enumerating_security_packages = "net_log_sspi_enumerating_security_packages"; + internal const string net_log_sspi_security_package_not_found = "net_log_sspi_security_package_not_found"; + internal const string net_log_sspi_security_context_input_buffer = "net_log_sspi_security_context_input_buffer"; + internal const string net_log_sspi_security_context_input_buffers = "net_log_sspi_security_context_input_buffers"; + internal const string net_log_sspi_selected_cipher_suite = "net_log_sspi_selected_cipher_suite"; + internal const string net_log_remote_certificate = "net_log_remote_certificate"; + internal const string net_log_locating_private_key_for_certificate = "net_log_locating_private_key_for_certificate"; + internal const string net_log_cert_is_of_type_2 = "net_log_cert_is_of_type_2"; + internal const string net_log_found_cert_in_store = "net_log_found_cert_in_store"; + internal const string net_log_did_not_find_cert_in_store = "net_log_did_not_find_cert_in_store"; + internal const string net_log_open_store_failed = "net_log_open_store_failed"; + internal const string net_log_got_certificate_from_delegate = "net_log_got_certificate_from_delegate"; + internal const string net_log_no_delegate_and_have_no_client_cert = "net_log_no_delegate_and_have_no_client_cert"; + internal const string net_log_no_delegate_but_have_client_cert = "net_log_no_delegate_but_have_client_cert"; + internal const string net_log_attempting_restart_using_cert = "net_log_attempting_restart_using_cert"; + internal const string net_log_no_issuers_try_all_certs = "net_log_no_issuers_try_all_certs"; + internal const string net_log_server_issuers_look_for_matching_certs = "net_log_server_issuers_look_for_matching_certs"; + internal const string net_log_selected_cert = "net_log_selected_cert"; + internal const string net_log_n_certs_after_filtering = "net_log_n_certs_after_filtering"; + internal const string net_log_finding_matching_certs = "net_log_finding_matching_certs"; + internal const string net_log_using_cached_credential = "net_log_using_cached_credential"; + internal const string net_log_remote_cert_user_declared_valid = "net_log_remote_cert_user_declared_valid"; + internal const string net_log_remote_cert_user_declared_invalid = "net_log_remote_cert_user_declared_invalid"; + internal const string net_log_remote_cert_has_no_errors = "net_log_remote_cert_has_no_errors"; + internal const string net_log_remote_cert_has_errors = "net_log_remote_cert_has_errors"; + internal const string net_log_remote_cert_not_available = "net_log_remote_cert_not_available"; + internal const string net_log_remote_cert_name_mismatch = "net_log_remote_cert_name_mismatch"; + internal const string net_log_proxy_autodetect_script_location_parse_error = "net_log_proxy_autodetect_script_location_parse_error"; + internal const string net_log_proxy_autodetect_failed = "net_log_proxy_autodetect_failed"; + internal const string net_log_proxy_script_execution_error = "net_log_proxy_script_execution_error"; + internal const string net_log_proxy_script_download_compile_error = "net_log_proxy_script_download_compile_error"; + internal const string net_log_proxy_system_setting_update = "net_log_proxy_system_setting_update"; + internal const string net_log_proxy_update_due_to_ip_config_change = "net_log_proxy_update_due_to_ip_config_change"; + internal const string net_log_proxy_called_with_null_parameter = "net_log_proxy_called_with_null_parameter"; + internal const string net_log_proxy_called_with_invalid_parameter = "net_log_proxy_called_with_invalid_parameter"; + internal const string net_log_proxy_ras_supported = "net_log_proxy_ras_supported"; + internal const string net_log_proxy_ras_notsupported_exception = "net_log_proxy_ras_notsupported_exception"; + internal const string net_log_proxy_winhttp_cant_open_session = "net_log_proxy_winhttp_cant_open_session"; + internal const string net_log_proxy_winhttp_getproxy_failed = "net_log_proxy_winhttp_getproxy_failed"; + internal const string net_log_proxy_winhttp_timeout_error = "net_log_proxy_winhttp_timeout_error"; + internal const string net_log_digest_hash_algorithm_not_supported = "net_log_digest_hash_algorithm_not_supported"; + internal const string net_log_digest_qop_not_supported = "net_log_digest_qop_not_supported"; + internal const string net_log_digest_requires_nonce = "net_log_digest_requires_nonce"; + internal const string net_log_auth_invalid_challenge = "net_log_auth_invalid_challenge"; + internal const string net_log_unknown = "net_log_unknown"; + internal const string net_log_operation_returned_something = "net_log_operation_returned_something"; + internal const string net_log_operation_failed_with_error = "net_log_operation_failed_with_error"; + internal const string net_log_buffered_n_bytes = "net_log_buffered_n_bytes"; + internal const string net_log_method_equal = "net_log_method_equal"; + internal const string net_log_releasing_connection = "net_log_releasing_connection"; + internal const string net_log_unexpected_exception = "net_log_unexpected_exception"; + internal const string net_log_server_response_error_code = "net_log_server_response_error_code"; + internal const string net_log_resubmitting_request = "net_log_resubmitting_request"; + internal const string net_log_retrieving_localhost_exception = "net_log_retrieving_localhost_exception"; + internal const string net_log_resolved_servicepoint_may_not_be_remote_server = "net_log_resolved_servicepoint_may_not_be_remote_server"; + internal const string net_log_closed_idle = "net_log_closed_idle"; + internal const string net_log_received_status_line = "net_log_received_status_line"; + internal const string net_log_sending_headers = "net_log_sending_headers"; + internal const string net_log_received_headers = "net_log_received_headers"; + internal const string net_log_shell_expression_pattern_format_warning = "net_log_shell_expression_pattern_format_warning"; + internal const string net_log_exception_in_callback = "net_log_exception_in_callback"; + internal const string net_log_sending_command = "net_log_sending_command"; + internal const string net_log_received_response = "net_log_received_response"; + internal const string net_log_socket_connected = "net_log_socket_connected"; + internal const string net_log_socket_accepted = "net_log_socket_accepted"; + internal const string net_log_socket_not_logged_file = "net_log_socket_not_logged_file"; + internal const string net_log_socket_connect_dnsendpoint = "net_log_socket_connect_dnsendpoint"; + internal const string SSPIInvalidHandleType = "SSPIInvalidHandleType"; + + private static SR loader = null; + private ResourceManager resources; + + internal SR() + { + resources = new System.Resources.ResourceManager("System", this.GetType().Assembly); + } + + private static SR GetLoader() + { + if (loader == null) + { + SR sr = new SR(); + Interlocked.CompareExchange(ref loader, sr, null); + } + return loader; + } + + private static CultureInfo Culture + { + get { return null/*use ResourceManager default, CultureInfo.CurrentUICulture*/; } + } + + public static ResourceManager Resources + { + get + { + return GetLoader().resources; + } + } + + public static string GetString(string name, params object[] args) + { + SR sys = GetLoader(); + if (sys == null) + { + return null; + } + string res = sys.resources.GetString(name, SR.Culture); + + if (args != null && args.Length > 0) + { + for (int i = 0; i < args.Length; i++) + { + String value = args[i] as String; + if (value != null && value.Length > 1024) + { + args[i] = value.Substring(0, 1024 - 3) + "..."; + } + } + return String.Format(CultureInfo.CurrentCulture, res, args); + } + else + { + return res; + } + } + + public static string GetString(string name) + { + SR sys = GetLoader(); + if (sys == null) + { + return null; + } + return sys.resources.GetString(name, SR.Culture); + } + + public static string GetString(string name, out bool usedFallback) + { + // always false for this version of gensr + usedFallback = false; + return GetString(name); + } + + public static object GetObject(string name) + { + SR sys = GetLoader(); + if (sys == null) + { + return null; + } + return sys.resources.GetObject(name, SR.Culture); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/ValidationHelper.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/ValidationHelper.cs new file mode 100644 index 0000000000..ff7a09000b --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/ValidationHelper.cs @@ -0,0 +1,74 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.Diagnostics; + using System.Diagnostics.CodeAnalysis; + using System.Globalization; + using System.Net.Security; + using System.Runtime.InteropServices; + using System.Runtime.Versioning; + using System.Security.Authentication.ExtendedProtection; + using System.Security.Cryptography.X509Certificates; + using System.Security.Permissions; + + internal static class ValidationHelper + { + public static string ExceptionMessage(Exception exception) + { + if (exception == null) + { + return string.Empty; + } + if (exception.InnerException == null) + { + return exception.Message; + } + return exception.Message + " (" + ExceptionMessage(exception.InnerException) + ")"; + } + + public static string ToString(object objectValue) + { + if (objectValue == null) + { + return "(null)"; + } + else if (objectValue is string && ((string)objectValue).Length == 0) + { + return "(string.empty)"; + } + else if (objectValue is Exception) + { + return ExceptionMessage(objectValue as Exception); + } + else if (objectValue is IntPtr) + { + return "0x" + ((IntPtr)objectValue).ToString("x"); + } + else + { + return objectValue.ToString(); + } + } + public static string HashString(object objectValue) + { + if (objectValue == null) + { + return "(null)"; + } + else if (objectValue is string && ((string)objectValue).Length == 0) + { + return "(string.empty)"; + } + else + { + return objectValue.GetHashCode().ToString(NumberFormatInfo.InvariantInfo); + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NTAuthentication.cs b/src/Microsoft.AspNet.Security.Windows/NTAuthentication.cs new file mode 100644 index 0000000000..25cae8a500 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NTAuthentication.cs @@ -0,0 +1,721 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Globalization; +using System.Net; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Permissions; +using System.Security.Principal; +using System.Text; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class NTAuthentication + { + private static readonly ContextCallback Callback = new ContextCallback(InitializeCallback); + private static ISSPIInterface SSPIAuth = new SSPIAuthType(); + + private bool _isServer; + + private SafeFreeCredentials _credentialsHandle; + private SafeDeleteContext _securityContext; + private string _spn; + private string _clientSpecifiedSpn; + + private int _tokenSize; + private ContextFlags _requestedContextFlags; + private ContextFlags _contextFlags; + private string _uniqueUserId; + + private bool _isCompleted; + private string _protocolName; + private SecSizes _sizes; + private string _lastProtocolName; + private string _package; + + private ChannelBinding _channelBinding; + + // This overload does not attmept to impersonate because the caller either did it already or the original thread context is still preserved + + internal NTAuthentication(bool isServer, string package, NetworkCredential credential, string spn, ContextFlags requestedContextFlags, ChannelBinding channelBinding) + { + Initialize(isServer, package, credential, spn, requestedContextFlags, channelBinding); + } + + // This overload always uses the default credentials for the process. + + [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.ControlPrincipal)] + internal NTAuthentication(bool isServer, string package, string spn, ContextFlags requestedContextFlags, ChannelBinding channelBinding) + { + try + { + using (WindowsIdentity.Impersonate(IntPtr.Zero)) + { + Initialize(isServer, package, CredentialCache.DefaultNetworkCredentials, spn, requestedContextFlags, channelBinding); + } + } + catch + { + // Avoid exception filter attacks. + throw; + } + } + + // The semantic of this propoerty is "Don't call me again". + // It can be completed either with success or error + // The latest case is signalled by IsValidContext==false + internal bool IsCompleted + { + get + { + return _isCompleted; + } + } + + internal bool IsValidContext + { + get + { + return !(_securityContext == null || _securityContext.IsInvalid); + } + } + + internal string AssociatedName + { + get + { + if (!(IsValidContext && IsCompleted)) + { + throw new Win32Exception((int)SecurityStatus.InvalidHandle); + } + + string name = SSPIWrapper.QueryContextAttributes(SSPIAuth, _securityContext, ContextAttribute.Names) as string; + GlobalLog.Print("NTAuthentication: The context is associated with [" + name + "]"); + return name; + } + } + + internal bool IsConfidentialityFlag + { + get + { + return (_contextFlags & ContextFlags.Confidentiality) != 0; + } + } + + internal bool IsIntegrityFlag + { + get + { + return (_contextFlags & (_isServer ? ContextFlags.AcceptIntegrity : ContextFlags.InitIntegrity)) != 0; + } + } + + internal bool IsMutualAuthFlag + { + get + { + return (_contextFlags & ContextFlags.MutualAuth) != 0; + } + } + + internal bool IsDelegationFlag + { + get + { + return (_contextFlags & ContextFlags.Delegate) != 0; + } + } + + internal bool IsIdentifyFlag + { + get + { + return (_contextFlags & (_isServer ? ContextFlags.AcceptIdentify : ContextFlags.InitIdentify)) != 0; + } + } + + internal string Spn + { + get + { + return _spn; + } + } + + internal string ClientSpecifiedSpn + { + get + { + if (_clientSpecifiedSpn == null) + { + _clientSpecifiedSpn = GetClientSpecifiedSpn(); + } + return _clientSpecifiedSpn; + } + } + + // True indicates this instance is for Server and will use AcceptSecurityContext SSPI API + + internal bool IsServer + { + get + { + return _isServer; + } + } + + internal bool IsKerberos + { + get + { + if (_lastProtocolName == null) + { + _lastProtocolName = ProtocolName; + } + + return (object)_lastProtocolName == (object)NegotiationInfoClass.Kerberos; + } + } + internal bool IsNTLM + { + get + { + if (_lastProtocolName == null) + { + _lastProtocolName = ProtocolName; + } + + return (object)_lastProtocolName == (object)NegotiationInfoClass.NTLM; + } + } + + internal string Package + { + get + { + return _package; + } + } + + internal string ProtocolName + { + get + { + // NB: May return string.Empty if the auth is not done yet or failed + if (_protocolName == null) + { + NegotiationInfoClass negotiationInfo = null; + + if (IsValidContext) + { + negotiationInfo = SSPIWrapper.QueryContextAttributes(SSPIAuth, _securityContext, ContextAttribute.NegotiationInfo) as NegotiationInfoClass; + if (IsCompleted) + { + if (negotiationInfo != null) + { + // cache it only when it's completed + _protocolName = negotiationInfo.AuthenticationPackage; + } + } + } + return negotiationInfo == null ? string.Empty : negotiationInfo.AuthenticationPackage; + } + return _protocolName; + } + } + + internal SecSizes Sizes + { + get + { + GlobalLog.Assert(IsCompleted && IsValidContext, "NTAuthentication#{0}::MaxDataSize|The context is not completed or invalid.", ValidationHelper.HashString(this)); + if (_sizes == null) + { + _sizes = SSPIWrapper.QueryContextAttributes( + SSPIAuth, + _securityContext, + ContextAttribute.Sizes) as SecSizes; + } + return _sizes; + } + } + + internal ChannelBinding ChannelBinding + { + get { return _channelBinding; } + } + + private static void InitializeCallback(object state) + { + InitializeCallbackContext context = (InitializeCallbackContext)state; + context.ThisPtr.Initialize(context.IsServer, context.Package, context.Credential, context.Spn, context.RequestedContextFlags, context.ChannelBinding); + } + + private void Initialize(bool isServer, string package, NetworkCredential credential, string spn, ContextFlags requestedContextFlags, ChannelBinding channelBinding) + { + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::.ctor() package:" + ValidationHelper.ToString(package) + " spn:" + ValidationHelper.ToString(spn) + " flags :" + requestedContextFlags.ToString()); + _tokenSize = SSPIWrapper.GetVerifyPackageInfo(SSPIAuth, package, true).MaxToken; + _isServer = isServer; + _spn = spn; + _securityContext = null; + _requestedContextFlags = requestedContextFlags; + _package = package; + _channelBinding = channelBinding; + + GlobalLog.Print("Peer SPN-> '" + _spn + "'"); + + // check if we're using DefaultCredentials + + if (credential == CredentialCache.DefaultNetworkCredentials) + { + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::.ctor(): using DefaultCredentials"); + _credentialsHandle = SSPIWrapper.AcquireDefaultCredential( + SSPIAuth, + package, + (_isServer ? CredentialUse.Inbound : CredentialUse.Outbound)); + _uniqueUserId = "/S"; // save off for unique connection marking ONLY used by HTTP client + } + else if (ComNetOS.IsWin7orLater) + { + unsafe + { + SafeSspiAuthDataHandle authData = null; + try + { + SecurityStatus result = UnsafeNclNativeMethods.SspiHelper.SspiEncodeStringsAsAuthIdentity( + credential.UserName/*InternalGetUserName()*/, credential.Domain/*InternalGetDomain()*/, + credential.Password/*InternalGetPassword()*/, out authData); + + if (result != SecurityStatus.OK) + { + if (Logging.On) + { + Logging.PrintError(Logging.Web, SR.GetString(SR.net_log_operation_failed_with_error, "SspiEncodeStringsAsAuthIdentity()", String.Format(CultureInfo.CurrentCulture, "0x{0:X}", (int)result))); + } + throw new Win32Exception((int)result); + } + + _credentialsHandle = SSPIWrapper.AcquireCredentialsHandle(SSPIAuth, + package, (_isServer ? CredentialUse.Inbound : CredentialUse.Outbound), ref authData); + } + finally + { + if (authData != null) + { + authData.Dispose(); + } + } + } + } + else + { + // we're not using DefaultCredentials, we need a + // AuthIdentity struct to contain credentials + // SECREVIEW: + // we'll save username/domain in temp strings, to avoid decrypting multiple times. + // password is only used once + + string username = credential.UserName; // InternalGetUserName(); + + string domain = credential.Domain; // InternalGetDomain(); + // ATTN: + // NetworkCredential class does not differentiate between null and "" but SSPI packages treat these cases differently + // For NTLM we want to keep "" for Wdigest.Dll we should use null. + AuthIdentity authIdentity = new AuthIdentity(username, credential.Password/*InternalGetPassword()*/, (object)package == (object)NegotiationInfoClass.WDigest && (domain == null || domain.Length == 0) ? null : domain); + + _uniqueUserId = domain + "/" + username + "/U"; // save off for unique connection marking ONLY used by HTTP client + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::.ctor(): using authIdentity:" + authIdentity.ToString()); + + _credentialsHandle = SSPIWrapper.AcquireCredentialsHandle( + SSPIAuth, + package, + (_isServer ? CredentialUse.Inbound : CredentialUse.Outbound), + ref authIdentity); + } + } + + // This will return an client token when conducted authentication on server side' + // This token can be used ofr impersanation + // We use it to create a WindowsIdentity and hand it out to the server app. + internal SafeCloseHandle GetContextToken(out SecurityStatus status) + { + GlobalLog.Assert(IsCompleted && IsValidContext, "NTAuthentication#{0}::GetContextToken|Should be called only when completed with success, currently is not!", ValidationHelper.HashString(this)); + GlobalLog.Assert(IsServer, "NTAuthentication#{0}::GetContextToken|The method must not be called by the client side!", ValidationHelper.HashString(this)); + + if (!IsValidContext) + { + throw new Win32Exception((int)SecurityStatus.InvalidHandle); + } + + SafeCloseHandle token = null; + status = (SecurityStatus)SSPIWrapper.QuerySecurityContextToken( + SSPIAuth, + _securityContext, + out token); + + return token; + } + + internal SafeCloseHandle GetContextToken() + { + SecurityStatus status; + SafeCloseHandle token = GetContextToken(out status); + if (status != SecurityStatus.OK) + { + throw new Win32Exception((int)status); + } + return token; + } + + internal void CloseContext() + { + if (_securityContext != null && !_securityContext.IsClosed) + { + _securityContext.Dispose(); + } + } + + // NTAuth::GetOutgoingBlob() + // Created: 12-01-1999: L.M. + // Description: + // Accepts an incoming binary security blob and returns + // an outgoing binary security blob + internal byte[] GetOutgoingBlob(byte[] incomingBlob, bool throwOnError, out SecurityStatus statusCode) + { + GlobalLog.Enter("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob", ((incomingBlob == null) ? "0" : incomingBlob.Length.ToString(NumberFormatInfo.InvariantInfo)) + " bytes"); + + List list = new List(2); + + if (incomingBlob != null) + { + list.Add(new SecurityBuffer(incomingBlob, BufferType.Token)); + } + if (_channelBinding != null) + { + list.Add(new SecurityBuffer(_channelBinding)); + } + + SecurityBuffer[] inSecurityBufferArray = null; + if (list.Count > 0) + { + inSecurityBufferArray = list.ToArray(); + } + + SecurityBuffer outSecurityBuffer = new SecurityBuffer(_tokenSize, BufferType.Token); + + bool firstTime = _securityContext == null; + try + { + if (!_isServer) + { + // client session + statusCode = (SecurityStatus)SSPIWrapper.InitializeSecurityContext( + SSPIAuth, + _credentialsHandle, + ref _securityContext, + _spn, + _requestedContextFlags, + Endianness.Network, + inSecurityBufferArray, + outSecurityBuffer, + ref _contextFlags); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob() SSPIWrapper.InitializeSecurityContext() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + + if (statusCode == SecurityStatus.CompleteNeeded) + { + SecurityBuffer[] inSecurityBuffers = new SecurityBuffer[1]; + inSecurityBuffers[0] = outSecurityBuffer; + + statusCode = (SecurityStatus)SSPIWrapper.CompleteAuthToken( + SSPIAuth, + ref _securityContext, + inSecurityBuffers); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.CompleteAuthToken() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + outSecurityBuffer.token = null; + } + } + else + { + // server session + statusCode = (SecurityStatus)SSPIWrapper.AcceptSecurityContext( + SSPIAuth, + _credentialsHandle, + ref _securityContext, + _requestedContextFlags, + Endianness.Network, + inSecurityBufferArray, + outSecurityBuffer, + ref _contextFlags); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob() SSPIWrapper.AcceptSecurityContext() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + } + } + finally + { + // Assuming the ISC or ASC has referenced the credential on the first successful call, + // we want to decrement the effective ref count by "disposing" it. + // The real dispose will happen when the security context is closed. + // Note if the first call was not successfull the handle is physically destroyed here + + if (firstTime && _credentialsHandle != null) + { + _credentialsHandle.Dispose(); + } + } + + if (((int)statusCode & unchecked((int)0x80000000)) != 0) + { + CloseContext(); + _isCompleted = true; + if (throwOnError) + { + Win32Exception exception = new Win32Exception((int)statusCode); + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob", "Win32Exception:" + exception); + throw exception; + } + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob", "null statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + return null; + } + else if (firstTime && _credentialsHandle != null) + { + // cache until it is pushed out by newly incoming handles + SSPIHandleCache.CacheCredential(_credentialsHandle); + } + + // the return value from SSPI will tell us correctly if the + // handshake is over or not: http://msdn.microsoft.com/library/psdk/secspi/sspiref_67p0.htm + // we also have to consider the case in which SSPI formed a new context, in this case we're done as well. + if (statusCode == SecurityStatus.OK) + { + // we're sucessfully done + GlobalLog.Assert(statusCode == SecurityStatus.OK, "NTAuthentication#{0}::GetOutgoingBlob()|statusCode:[0x{1:x8}] ({2}) m_SecurityContext#{3}::Handle:[{4}] [STATUS != OK]", ValidationHelper.HashString(this), (int)statusCode, statusCode, ValidationHelper.HashString(_securityContext), ValidationHelper.ToString(_securityContext)); + _isCompleted = true; + } + else + { + // we need to continue + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob() need continue statusCode:[0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + "] (" + statusCode.ToString() + ") m_SecurityContext#" + ValidationHelper.HashString(_securityContext) + "::Handle:" + ValidationHelper.ToString(_securityContext) + "]"); + } + // GlobalLog.Print("out token = " + outSecurityBuffer.ToString()); + // GlobalLog.Dump(outSecurityBuffer.token); + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob", "IsCompleted:" + IsCompleted.ToString()); + return outSecurityBuffer.token; + } + + // for Server side (IIS 6.0) see: \\netindex\Sources\inetsrv\iis\iisrearc\iisplus\ulw3\digestprovider.cxx + // for Client side (HTTP.SYS) see: \\netindex\Sources\net\http\sys\ucauth.c + internal string GetOutgoingDigestBlob(string incomingBlob, string requestMethod, string requestedUri, string realm, bool isClientPreAuth, bool throwOnError, out SecurityStatus statusCode) + { + GlobalLog.Enter("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob", incomingBlob); + + // second time call with 3 incoming buffers to select HTTP client. + // we should get back a SecurityStatus.OK and a non null outgoingBlob. + SecurityBuffer[] inSecurityBuffers = null; + SecurityBuffer outSecurityBuffer = new SecurityBuffer(_tokenSize, isClientPreAuth ? BufferType.Parameters : BufferType.Token); + + bool firstTime = _securityContext == null; + try + { + if (!_isServer) + { + // client session + + if (!isClientPreAuth) + { + if (incomingBlob != null) + { + List list = new List(5); + + list.Add(new SecurityBuffer(HeaderEncoding.GetBytes(incomingBlob), BufferType.Token)); + list.Add(new SecurityBuffer(HeaderEncoding.GetBytes(requestMethod), BufferType.Parameters)); + list.Add(new SecurityBuffer(null, BufferType.Parameters)); + list.Add(new SecurityBuffer(Encoding.Unicode.GetBytes(_spn), BufferType.TargetHost)); + + if (_channelBinding != null) + { + list.Add(new SecurityBuffer(_channelBinding)); + } + + inSecurityBuffers = list.ToArray(); + } + + statusCode = (SecurityStatus)SSPIWrapper.InitializeSecurityContext( + SSPIAuth, + _credentialsHandle, + ref _securityContext, + requestedUri, // this must match the Uri in the HTTP status line for the current request + _requestedContextFlags, + Endianness.Network, + inSecurityBuffers, + outSecurityBuffer, + ref _contextFlags); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.InitializeSecurityContext() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + } + else + { +#if WDIGEST_PREAUTH + inSecurityBuffers = new SecurityBuffer[] { + new SecurityBuffer(null, BufferType.Token), + new SecurityBuffer(WebHeaderCollection.HeaderEncoding.GetBytes(requestMethod), BufferType.Parameters), + new SecurityBuffer(WebHeaderCollection.HeaderEncoding.GetBytes(requestedUri), BufferType.Parameters), + new SecurityBuffer(null, BufferType.Parameters), + outSecurityBuffer, + }; + + statusCode = (SecurityStatus) SSPIWrapper.MakeSignature(GlobalSSPI.SSPIAuth, m_SecurityContext, inSecurityBuffers, 0); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.MakeSignature() returns statusCode:0x" + ((int) statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); +#else + statusCode = SecurityStatus.OK; + GlobalLog.Assert("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob()", "Invalid code path."); +#endif + } + } + else + { + // server session + List list = new List(6); + + list.Add(incomingBlob == null ? new SecurityBuffer(0, BufferType.Token) : new SecurityBuffer(HeaderEncoding.GetBytes(incomingBlob), BufferType.Token)); + list.Add(requestMethod == null ? new SecurityBuffer(0, BufferType.Parameters) : new SecurityBuffer(HeaderEncoding.GetBytes(requestMethod), BufferType.Parameters)); + list.Add(requestedUri == null ? new SecurityBuffer(0, BufferType.Parameters) : new SecurityBuffer(HeaderEncoding.GetBytes(requestedUri), BufferType.Parameters)); + list.Add(new SecurityBuffer(0, BufferType.Parameters)); + list.Add(realm == null ? new SecurityBuffer(0, BufferType.Parameters) : new SecurityBuffer(Encoding.Unicode.GetBytes(realm), BufferType.Parameters)); + + if (_channelBinding != null) + { + list.Add(new SecurityBuffer(_channelBinding)); + } + + inSecurityBuffers = list.ToArray(); + + statusCode = (SecurityStatus)SSPIWrapper.AcceptSecurityContext( + SSPIAuth, + _credentialsHandle, + ref _securityContext, + _requestedContextFlags, + Endianness.Network, + inSecurityBuffers, + outSecurityBuffer, + ref _contextFlags); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.AcceptSecurityContext() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + + if (statusCode == SecurityStatus.CompleteNeeded) + { + inSecurityBuffers[4] = outSecurityBuffer; + + statusCode = (SecurityStatus)SSPIWrapper.CompleteAuthToken( + SSPIAuth, + ref _securityContext, + inSecurityBuffers); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.CompleteAuthToken() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + + outSecurityBuffer.token = null; + } + } + } + finally + { + // Assuming the ISC or ASC has referenced the credential on the first successful call, + // we want to decrement the effective ref count by "disposing" it. + // The real dispose will happen when the security context is closed. + // Note if the first call was not successfull the handle is physically destroyed here + + if (firstTime && _credentialsHandle != null) + { + _credentialsHandle.Dispose(); + } + } + + if (((int)statusCode & unchecked((int)0x80000000)) != 0) + { + CloseContext(); + if (throwOnError) + { + Win32Exception exception = new Win32Exception((int)statusCode); + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob", "Win32Exception:" + exception); + throw exception; + } + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob", "null statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + return null; + } + else if (firstTime && _credentialsHandle != null) + { + // cache until it is pushed out by newly incoming handles + SSPIHandleCache.CacheCredential(_credentialsHandle); + } + + // the return value from SSPI will tell us correctly if the + // handshake is over or not: http://msdn.microsoft.com/library/psdk/secspi/sspiref_67p0.htm + if (statusCode == SecurityStatus.OK) + { + // we're done, cleanup + _isCompleted = true; + } + else + { + // we need to continue + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() need continue statusCode:[0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + "] (" + statusCode.ToString() + ") m_SecurityContext#" + ValidationHelper.HashString(_securityContext) + "::Handle:" + ValidationHelper.ToString(_securityContext) + "]"); + } + GlobalLog.Print("out token = " + outSecurityBuffer.ToString()); + GlobalLog.Dump(outSecurityBuffer.token); + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() IsCompleted:" + IsCompleted.ToString()); + + byte[] decodedOutgoingBlob = outSecurityBuffer.token; + string outgoingBlob = null; + if (decodedOutgoingBlob != null && decodedOutgoingBlob.Length > 0) + { + outgoingBlob = HeaderEncoding.GetString(decodedOutgoingBlob, 0, outSecurityBuffer.size); + } + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob", outgoingBlob); + return outgoingBlob; + } + + private string GetClientSpecifiedSpn() + { + GlobalLog.Assert(IsValidContext && IsCompleted, "NTAuthentication: Trying to get the client SPN before handshaking is done!"); + + string spn = SSPIWrapper.QueryContextAttributes(SSPIAuth, _securityContext, + ContextAttribute.ClientSpecifiedSpn) as string; + + GlobalLog.Print("NTAuthentication: The client specified SPN is [" + spn + "]"); + return spn; + } + + private class InitializeCallbackContext + { + internal readonly NTAuthentication ThisPtr; + internal readonly bool IsServer; + internal readonly string Package; + internal readonly NetworkCredential Credential; + internal readonly string Spn; + internal readonly ContextFlags RequestedContextFlags; + internal readonly ChannelBinding ChannelBinding; + + internal InitializeCallbackContext(NTAuthentication thisPtr, bool isServer, string package, NetworkCredential credential, string spn, ContextFlags requestedContextFlags, ChannelBinding channelBinding) + { + this.ThisPtr = thisPtr; + this.IsServer = isServer; + this.Package = package; + this.Credential = credential; + this.Spn = spn; + this.RequestedContextFlags = requestedContextFlags; + this.ChannelBinding = channelBinding; + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/AuthIdentity.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/AuthIdentity.cs new file mode 100644 index 0000000000..dcd73516a9 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/AuthIdentity.cs @@ -0,0 +1,40 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Auto)] + internal struct AuthIdentity + { + // see SEC_WINNT_AUTH_IDENTITY_W + internal string UserName; + internal int UserNameLength; + internal string Domain; + internal int DomainLength; + internal string Password; + internal int PasswordLength; + internal int Flags; + + internal AuthIdentity(string userName, string password, string domain) + { + UserName = userName; + UserNameLength = userName == null ? 0 : userName.Length; + Password = password; + PasswordLength = password == null ? 0 : password.Length; + Domain = domain; + DomainLength = domain == null ? 0 : domain.Length; + // Flags are 2 for Unicode and 1 for ANSI. We use 2 on NT and 1 on Win9x. + Flags = 2; + } + + public override string ToString() + { + return ValidationHelper.ToString(Domain) + "\\" + ValidationHelper.ToString(UserName); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/ContextFlags.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/ContextFlags.cs new file mode 100644 index 0000000000..21ef3b735d --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/ContextFlags.cs @@ -0,0 +1,124 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Security.Windows +{ + // #define ISC_REQ_DELEGATE 0x00000001 + // #define ISC_REQ_MUTUAL_AUTH 0x00000002 + // #define ISC_REQ_REPLAY_DETECT 0x00000004 + // #define ISC_REQ_SEQUENCE_DETECT 0x00000008 + // #define ISC_REQ_CONFIDENTIALITY 0x00000010 + // #define ISC_REQ_USE_SESSION_KEY 0x00000020 + // #define ISC_REQ_PROMPT_FOR_CREDS 0x00000040 + // #define ISC_REQ_USE_SUPPLIED_CREDS 0x00000080 + // #define ISC_REQ_ALLOCATE_MEMORY 0x00000100 + // #define ISC_REQ_USE_DCE_STYLE 0x00000200 + // #define ISC_REQ_DATAGRAM 0x00000400 + // #define ISC_REQ_CONNECTION 0x00000800 + // #define ISC_REQ_CALL_LEVEL 0x00001000 + // #define ISC_REQ_FRAGMENT_SUPPLIED 0x00002000 + // #define ISC_REQ_EXTENDED_ERROR 0x00004000 + // #define ISC_REQ_STREAM 0x00008000 + // #define ISC_REQ_INTEGRITY 0x00010000 + // #define ISC_REQ_IDENTIFY 0x00020000 + // #define ISC_REQ_NULL_SESSION 0x00040000 + // #define ISC_REQ_MANUAL_CRED_VALIDATION 0x00080000 + // #define ISC_REQ_RESERVED1 0x00100000 + // #define ISC_REQ_FRAGMENT_TO_FIT 0x00200000 + // #define ISC_REQ_HTTP 0x10000000 + // Win7 SP1 + + // #define ISC_REQ_UNVERIFIED_TARGET_NAME 0x20000000 + + // #define ASC_REQ_DELEGATE 0x00000001 + // #define ASC_REQ_MUTUAL_AUTH 0x00000002 + // #define ASC_REQ_REPLAY_DETECT 0x00000004 + // #define ASC_REQ_SEQUENCE_DETECT 0x00000008 + // #define ASC_REQ_CONFIDENTIALITY 0x00000010 + // #define ASC_REQ_USE_SESSION_KEY 0x00000020 + // #define ASC_REQ_ALLOCATE_MEMORY 0x00000100 + // #define ASC_REQ_USE_DCE_STYLE 0x00000200 + // #define ASC_REQ_DATAGRAM 0x00000400 + // #define ASC_REQ_CONNECTION 0x00000800 + // #define ASC_REQ_CALL_LEVEL 0x00001000 + // #define ASC_REQ_EXTENDED_ERROR 0x00008000 + // #define ASC_REQ_STREAM 0x00010000 + // #define ASC_REQ_INTEGRITY 0x00020000 + // #define ASC_REQ_LICENSING 0x00040000 + // #define ASC_REQ_IDENTIFY 0x00080000 + // #define ASC_REQ_ALLOW_NULL_SESSION 0x00100000 + // #define ASC_REQ_ALLOW_NON_USER_LOGONS 0x00200000 + // #define ASC_REQ_ALLOW_CONTEXT_REPLAY 0x00400000 + // #define ASC_REQ_FRAGMENT_TO_FIT 0x00800000 + // #define ASC_REQ_FRAGMENT_SUPPLIED 0x00002000 + // #define ASC_REQ_NO_TOKEN 0x01000000 + // #define ASC_REQ_HTTP 0x10000000 + + [Flags] + internal enum ContextFlags + { + Zero = 0, + // The server in the transport application can + // build new security contexts impersonating the + // client that will be accepted by other servers + // as the client's contexts. + Delegate = 0x00000001, + // The communicating parties must authenticate + // their identities to each other. Without MutualAuth, + // the client authenticates its identity to the server. + // With MutualAuth, the server also must authenticate + // its identity to the client. + MutualAuth = 0x00000002, + // The security package detects replayed packets and + // notifies the caller if a packet has been replayed. + // The use of this flag implies all of the conditions + // specified by the Integrity flag. + ReplayDetect = 0x00000004, + // The context must be allowed to detect out-of-order + // delivery of packets later through the message support + // functions. Use of this flag implies all of the + // conditions specified by the Integrity flag. + SequenceDetect = 0x00000008, + // The context must protect data while in transit. + // Confidentiality is supported for NTLM with Microsoft + // Windows NT version 4.0, SP4 and later and with the + // Kerberos protocol in Microsoft Windows 2000 and later. + Confidentiality = 0x00000010, + UseSessionKey = 0x00000020, + AllocateMemory = 0x00000100, + + // Connection semantics must be used. + Connection = 0x00000800, + + // Client applications requiring extended error messages specify the + // ISC_REQ_EXTENDED_ERROR flag when calling the InitializeSecurityContext + // Server applications requiring extended error messages set + // the ASC_REQ_EXTENDED_ERROR flag when calling AcceptSecurityContext. + InitExtendedError = 0x00004000, + AcceptExtendedError = 0x00008000, + // A transport application requests stream semantics + // by setting the ISC_REQ_STREAM and ASC_REQ_STREAM + // flags in the calls to the InitializeSecurityContext + // and AcceptSecurityContext functions + InitStream = 0x00008000, + AcceptStream = 0x00010000, + // Buffer integrity can be verified; however, replayed + // and out-of-sequence messages will not be detected + InitIntegrity = 0x00010000, // ISC_REQ_INTEGRITY + AcceptIntegrity = 0x00020000, // ASC_REQ_INTEGRITY + + InitManualCredValidation = 0x00080000, // ISC_REQ_MANUAL_CRED_VALIDATION + InitUseSuppliedCreds = 0x00000080, // ISC_REQ_USE_SUPPLIED_CREDS + InitIdentify = 0x00020000, // ISC_REQ_IDENTIFY + AcceptIdentify = 0x00080000, // ASC_REQ_IDENTIFY + + ProxyBindings = 0x04000000, // ASC_REQ_PROXY_BINDINGS + AllowMissingBindings = 0x10000000, // ASC_REQ_ALLOW_MISSING_BINDINGS + + UnverifiedTargetName = 0x20000000, // ISC_REQ_UNVERIFIED_TARGET_NAME + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/NativeSSPI.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/NativeSSPI.cs new file mode 100644 index 0000000000..1ac0a7d055 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/NativeSSPI.cs @@ -0,0 +1,42 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + // used to define the interface for security to use. + internal interface ISSPIInterface + { + SecurityPackageInfoClass[] SecurityPackages { get; set; } + int EnumerateSecurityPackages(out int pkgnum, out SafeFreeContextBuffer pkgArray); + int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref AuthIdentity authdata, out SafeFreeCredentials outCredential); + int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref SafeSspiAuthDataHandle authdata, out SafeFreeCredentials outCredential); + int AcquireDefaultCredential(string moduleName, CredentialUse usage, out SafeFreeCredentials outCredential); + int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref SecureCredential authdata, out SafeFreeCredentials outCredential); + int AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer inputBuffer, ContextFlags inFlags, + Endianness endianness, SecurityBuffer outputBuffer, ref ContextFlags outFlags); + int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer[] inputBuffers, ContextFlags inFlags, + Endianness endianness, SecurityBuffer outputBuffer, ref ContextFlags outFlags); + int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, + Endianness endianness, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, ref ContextFlags outFlags); + int InitializeSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, + Endianness endianness, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, ref ContextFlags outFlags); + int EncryptMessage(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber); + int DecryptMessage(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber); + int MakeSignature(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber); + int VerifySignature(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber); + + int QueryContextChannelBinding(SafeDeleteContext phContext, ContextAttribute attribute, out SafeFreeContextBufferChannelBinding refHandle); + int QueryContextAttributes(SafeDeleteContext phContext, ContextAttribute attribute, byte[] buffer, Type handleType, out SafeHandle refHandle); + int SetContextAttributes(SafeDeleteContext phContext, ContextAttribute attribute, byte[] buffer); + int QuerySecurityContextToken(SafeDeleteContext phContext, out SafeCloseHandle phToken); + int CompleteAuthToken(ref SafeDeleteContext refContext, SecurityBuffer[] inputBuffers); + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIAuthType.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIAuthType.cs new file mode 100644 index 0000000000..e7951e2969 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIAuthType.cs @@ -0,0 +1,302 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class SSPIAuthType : ISSPIInterface + { + private static volatile SecurityPackageInfoClass[] _securityPackages; + + public SecurityPackageInfoClass[] SecurityPackages + { + get + { + return _securityPackages; + } + set + { + _securityPackages = value; + } + } + + public int EnumerateSecurityPackages(out int pkgnum, out SafeFreeContextBuffer pkgArray) + { + GlobalLog.Print("SSPIAuthType::EnumerateSecurityPackages()"); + return SafeFreeContextBuffer.EnumeratePackages(out pkgnum, out pkgArray); + } + + public int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref AuthIdentity authdata, out SafeFreeCredentials outCredential) + { + return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential); + } + + public int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref SafeSspiAuthDataHandle authdata, out SafeFreeCredentials outCredential) + { + return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential); + } + + public int AcquireDefaultCredential(string moduleName, CredentialUse usage, out SafeFreeCredentials outCredential) + { + return SafeFreeCredentials.AcquireDefaultCredential(moduleName, usage, out outCredential); + } + + public int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref SecureCredential authdata, out SafeFreeCredentials outCredential) + { + return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential); + } + + public int AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer inputBuffer, ContextFlags inFlags, Endianness endianness, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffer, null, outputBuffer, ref outFlags); + } + + public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer[] inputBuffers, ContextFlags inFlags, Endianness endianness, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, null, inputBuffers, outputBuffer, ref outFlags); + } + + public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, Endianness endianness, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffer, null, outputBuffer, ref outFlags); + } + + public int InitializeSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, Endianness endianness, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, null, inputBuffers, outputBuffer, ref outFlags); + } + + public int EncryptMessage(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + context.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + context.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.NativeNTSSPI.EncryptMessage(ref context._handle, 0, inputOutput, sequenceNumber); + context.DangerousRelease(); + } + } + return status; + } + + public unsafe int DecryptMessage(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + uint qop = 0; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + context.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + context.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.NativeNTSSPI.DecryptMessage(ref context._handle, inputOutput, sequenceNumber, &qop); + context.DangerousRelease(); + } + } + + const uint SECQOP_WRAP_NO_ENCRYPT = 0x80000001; + if (status == 0 && qop == SECQOP_WRAP_NO_ENCRYPT) + { + GlobalLog.Assert("NativeNTSSPI.DecryptMessage", "Expected qop = 0, returned value = " + qop.ToString("x", CultureInfo.InvariantCulture)); + throw new InvalidOperationException(SR.GetString(SR.net_auth_message_not_encrypted)); + } + + return status; + } + + public int MakeSignature(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + context.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + context.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + const uint SECQOP_WRAP_NO_ENCRYPT = 0x80000001; + status = UnsafeNclNativeMethods.NativeNTSSPI.EncryptMessage(ref context._handle, SECQOP_WRAP_NO_ENCRYPT, inputOutput, sequenceNumber); + context.DangerousRelease(); + } + } + return status; + } + + public unsafe int VerifySignature(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + uint qop = 0; + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + context.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + context.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.NativeNTSSPI.DecryptMessage(ref context._handle, inputOutput, sequenceNumber, &qop); + context.DangerousRelease(); + } + } + + return status; + } + + public int QueryContextChannelBinding(SafeDeleteContext context, ContextAttribute attribute, out SafeFreeContextBufferChannelBinding binding) + { + // Querying an auth SSP for a CBT doesn't make sense + binding = null; + throw new NotSupportedException(); + } + + public unsafe int QueryContextAttributes(SafeDeleteContext context, ContextAttribute attribute, byte[] buffer, Type handleType, out SafeHandle refHandle) + { + refHandle = null; + if (handleType != null) + { + if (handleType == typeof(SafeFreeContextBuffer)) + { + refHandle = SafeFreeContextBuffer.CreateEmptyHandle(); + } + else if (handleType == typeof(SafeFreeCertContext)) + { + refHandle = new SafeFreeCertContext(); + } + else + { + throw new ArgumentException(SR.GetString(SR.SSPIInvalidHandleType, handleType.FullName), "handleType"); + } + } + + fixed (byte* bufferPtr = buffer) + { + return SafeFreeContextBuffer.QueryContextAttributes(context, attribute, bufferPtr, refHandle); + } + } + + public int SetContextAttributes(SafeDeleteContext context, ContextAttribute attribute, byte[] buffer) + { + throw new NotImplementedException(); + } + + public int QuerySecurityContextToken(SafeDeleteContext phContext, out SafeCloseHandle phToken) + { + return GetSecurityContextToken(phContext, out phToken); + } + + public int CompleteAuthToken(ref SafeDeleteContext refContext, SecurityBuffer[] inputBuffers) + { + return SafeDeleteContext.CompleteAuthToken(ref refContext, inputBuffers); + } + + private static int GetSecurityContextToken(SafeDeleteContext phContext, out SafeCloseHandle safeHandle) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + safeHandle = null; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + phContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + phContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.SafeNetHandles.QuerySecurityContextToken(ref phContext._handle, out safeHandle); + phContext.DangerousRelease(); + } + } + + return status; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIHandle.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIHandle.cs new file mode 100644 index 0000000000..52728d3825 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIHandle.cs @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential, Pack = 1)] + internal struct SSPIHandle + { + private IntPtr HandleHi; + private IntPtr HandleLo; + + public bool IsZero + { + get { return HandleHi == IntPtr.Zero && HandleLo == IntPtr.Zero; } + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal void SetToInvalid() + { + HandleHi = IntPtr.Zero; + HandleLo = IntPtr.Zero; + } + + public override string ToString() + { + return HandleHi.ToString("x") + ":" + HandleLo.ToString("x"); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPISessionCache.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPISessionCache.cs new file mode 100644 index 0000000000..e30c1ab434 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPISessionCache.cs @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +/* +Abstract: + The file implements trivial SSPI credential caching mechanism based on lru list +*/ +using System; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + // Implements delayed SSPI handle release, like a finalizable object though the handles are kept alive until being pushed out + // by the newly incoming ones. + internal static class SSPIHandleCache + { + private const int MaxCacheSize = 0x1F; // must a (power of 2) - 1 + private static SafeCredentialReference[] _cacheSlots = new SafeCredentialReference[MaxCacheSize + 1]; + private static int _current = -1; + + internal static void CacheCredential(SafeFreeCredentials newHandle) + { + try + { + SafeCredentialReference newRef = SafeCredentialReference.CreateReference(newHandle); + if (newRef == null) + { + return; + } + unchecked + { + int index = Interlocked.Increment(ref _current) & MaxCacheSize; + newRef = Interlocked.Exchange(ref _cacheSlots[index], newRef); + } + if (newRef != null) + { + newRef.Dispose(); + } + } + catch (Exception e) + { + GlobalLog.Assert("SSPIHandlCache", "Attempted to throw: " + e.ToString()); + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIWrapper.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIWrapper.cs new file mode 100644 index 0000000000..4c04e704b2 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIWrapper.cs @@ -0,0 +1,462 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.ComponentModel; +using System.Diagnostics; +using System.Globalization; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + // From Schannel.h + + [StructLayout(LayoutKind.Sequential)] + internal struct NegotiationInfo + { + // see SecPkgContext_NegotiationInfoW in + + // [MarshalAs(UnmanagedType.LPStruct)] internal SecurityPackageInfo PackageInfo; + internal IntPtr PackageInfo; + internal uint NegotiationState; + internal static readonly int Size = Marshal.SizeOf(typeof(NegotiationInfo)); + internal static readonly int NegotiationStateOffest = (int)Marshal.OffsetOf(typeof(NegotiationInfo), "NegotiationState"); + } + + // we keep it simple since we use this only to know if NTLM or + // Kerberos are used in the context of a Negotiate handshake + + [StructLayout(LayoutKind.Sequential)] + internal struct SecurityPackageInfo + { + // see SecPkgInfoW in + internal int Capabilities; + internal short Version; + internal short RPCID; + internal int MaxToken; + internal IntPtr Name; + internal IntPtr Comment; + + internal static readonly int Size = Marshal.SizeOf(typeof(SecurityPackageInfo)); + internal static readonly int NameOffest = (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Name"); + } + + [StructLayout(LayoutKind.Sequential)] + internal struct Bindings + { + // see SecPkgContext_Bindings in + internal int BindingsLength; + internal IntPtr pBindings; + } + + internal static class SSPIWrapper + { + internal static SecurityPackageInfoClass[] EnumerateSecurityPackages(ISSPIInterface secModule) + { + GlobalLog.Enter("EnumerateSecurityPackages"); + if (secModule.SecurityPackages == null) + { + lock (secModule) + { + if (secModule.SecurityPackages == null) + { + int moduleCount = 0; + SafeFreeContextBuffer arrayBaseHandle = null; + try + { + int errorCode = secModule.EnumerateSecurityPackages(out moduleCount, out arrayBaseHandle); + GlobalLog.Print("SSPIWrapper::arrayBase: " + (arrayBaseHandle.DangerousGetHandle().ToString("x"))); + if (errorCode != 0) + { + throw new Win32Exception(errorCode); + } + SecurityPackageInfoClass[] securityPackages = new SecurityPackageInfoClass[moduleCount]; + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_sspi_enumerating_security_packages)); + } + int i; + for (i = 0; i < moduleCount; i++) + { + securityPackages[i] = new SecurityPackageInfoClass(arrayBaseHandle, i); + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, " " + securityPackages[i].Name); + } + } + secModule.SecurityPackages = securityPackages; + } + finally + { + if (arrayBaseHandle != null) + { + arrayBaseHandle.Dispose(); + } + } + } + } + } + GlobalLog.Leave("EnumerateSecurityPackages"); + return secModule.SecurityPackages; + } + + internal static SecurityPackageInfoClass GetVerifyPackageInfo(ISSPIInterface secModule, string packageName, bool throwIfMissing) + { + SecurityPackageInfoClass[] supportedSecurityPackages = EnumerateSecurityPackages(secModule); + if (supportedSecurityPackages != null) + { + for (int i = 0; i < supportedSecurityPackages.Length; i++) + { + if (string.Compare(supportedSecurityPackages[i].Name, packageName, StringComparison.OrdinalIgnoreCase) == 0) + { + return supportedSecurityPackages[i]; + } + } + } + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_sspi_security_package_not_found, packageName)); + } + + // error + if (throwIfMissing) + { + throw new NotSupportedException(SR.GetString(SR.net_securitypackagesupport)); + } + + return null; + } + + public static SafeFreeCredentials AcquireDefaultCredential(ISSPIInterface secModule, string package, CredentialUse intent) + { + GlobalLog.Print("SSPIWrapper::AcquireDefaultCredential(): using " + package); + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "AcquireDefaultCredential(" + + "package = " + package + ", " + + "intent = " + intent + ")"); + } + + SafeFreeCredentials outCredential = null; + int errorCode = secModule.AcquireDefaultCredential(package, intent, out outCredential); + + if (errorCode != 0) + { +#if TRAVE + GlobalLog.Print("SSPIWrapper::AcquireDefaultCredential(): error " + SecureChannel.MapSecurityStatus((uint)errorCode)); +#endif + if (Logging.On) + { + Logging.PrintError(Logging.Web, SR.GetString(SR.net_log_operation_failed_with_error, "AcquireDefaultCredential()", String.Format(CultureInfo.CurrentCulture, "0X{0:X}", errorCode))); + } + throw new Win32Exception(errorCode); + } + return outCredential; + } + + public static SafeFreeCredentials AcquireCredentialsHandle(ISSPIInterface secModule, string package, CredentialUse intent, ref AuthIdentity authdata) + { + GlobalLog.Print("SSPIWrapper::AcquireCredentialsHandle#2(): using " + package); + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "AcquireCredentialsHandle(" + + "package = " + package + ", " + + "intent = " + intent + ", " + + "authdata = " + authdata + ")"); + } + + SafeFreeCredentials credentialsHandle = null; + int errorCode = secModule.AcquireCredentialsHandle(package, + intent, + ref authdata, + out credentialsHandle); + + if (errorCode != 0) + { + if (Logging.On) + { + Logging.PrintError(Logging.Web, SR.GetString(SR.net_log_operation_failed_with_error, "AcquireCredentialsHandle()", String.Format(CultureInfo.CurrentCulture, "0X{0:X}", errorCode))); + } + throw new Win32Exception(errorCode); + } + return credentialsHandle; + } + + public static SafeFreeCredentials AcquireCredentialsHandle(ISSPIInterface secModule, string package, CredentialUse intent, ref SafeSspiAuthDataHandle authdata) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "AcquireCredentialsHandle(" + + "package = " + package + ", " + + "intent = " + intent + ", " + + "authdata = " + authdata + ")"); + } + + SafeFreeCredentials credentialsHandle = null; + int errorCode = secModule.AcquireCredentialsHandle(package, intent, ref authdata, out credentialsHandle); + + if (errorCode != 0) + { + if (Logging.On) + { + Logging.PrintError(Logging.Web, SR.GetString(SR.net_log_operation_failed_with_error, "AcquireCredentialsHandle()", String.Format(CultureInfo.CurrentCulture, "0X{0:X}", errorCode))); + } + throw new Win32Exception(errorCode); + } + return credentialsHandle; + } + + internal static int InitializeSecurityContext(ISSPIInterface secModule, SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, Endianness datarep, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "InitializeSecurityContext(" + + "credential = " + credential.ToString() + ", " + + "context = " + ValidationHelper.ToString(context) + ", " + + "targetName = " + targetName + ", " + + "inFlags = " + inFlags + ")"); + } + + int errorCode = secModule.InitializeSecurityContext(credential, ref context, targetName, inFlags, datarep, inputBuffers, outputBuffer, ref outFlags); + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_sspi_security_context_input_buffers, "InitializeSecurityContext", (inputBuffers == null ? 0 : inputBuffers.Length), outputBuffer.size, (SecurityStatus)errorCode)); + } + + return errorCode; + } + + internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCredentials credential, ref SafeDeleteContext context, ContextFlags inFlags, Endianness datarep, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "AcceptSecurityContext(" + + "credential = " + credential.ToString() + ", " + + "context = " + ValidationHelper.ToString(context) + ", " + + "inFlags = " + inFlags + ")"); + } + + int errorCode = secModule.AcceptSecurityContext(credential, ref context, inputBuffers, inFlags, datarep, outputBuffer, ref outFlags); + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_sspi_security_context_input_buffers, "AcceptSecurityContext", (inputBuffers == null ? 0 : inputBuffers.Length), outputBuffer.size, (SecurityStatus)errorCode)); + } + + return errorCode; + } + + internal static int CompleteAuthToken(ISSPIInterface secModule, ref SafeDeleteContext context, SecurityBuffer[] inputBuffers) + { + int errorCode = secModule.CompleteAuthToken(ref context, inputBuffers); + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_operation_returned_something, "CompleteAuthToken()", (SecurityStatus)errorCode)); + } + + return errorCode; + } + + public static int QuerySecurityContextToken(ISSPIInterface secModule, SafeDeleteContext context, out SafeCloseHandle token) + { + return secModule.QuerySecurityContextToken(context, out token); + } + + public static SafeFreeContextBufferChannelBinding QueryContextChannelBinding(ISSPIInterface secModule, SafeDeleteContext securityContext, ContextAttribute contextAttribute) + { + GlobalLog.Enter("QueryContextChannelBinding", contextAttribute.ToString()); + + SafeFreeContextBufferChannelBinding result; + int errorCode = secModule.QueryContextChannelBinding(securityContext, contextAttribute, out result); + if (errorCode != 0) + { + GlobalLog.Leave("QueryContextChannelBinding", "ERROR = " + ErrorDescription(errorCode)); + return null; + } + + GlobalLog.Leave("QueryContextChannelBinding", ValidationHelper.HashString(result)); + return result; + } + + public static object QueryContextAttributes(ISSPIInterface secModule, SafeDeleteContext securityContext, ContextAttribute contextAttribute) + { + int errorCode; + return QueryContextAttributes(secModule, securityContext, contextAttribute, out errorCode); + } + + public static object QueryContextAttributes(ISSPIInterface secModule, SafeDeleteContext securityContext, ContextAttribute contextAttribute, out int errorCode) + { + GlobalLog.Enter("QueryContextAttributes", contextAttribute.ToString()); + + int nativeBlockSize = IntPtr.Size; + Type handleType = null; + + switch (contextAttribute) + { + case ContextAttribute.Sizes: + nativeBlockSize = SecSizes.SizeOf; + break; + case ContextAttribute.StreamSizes: + nativeBlockSize = StreamSizes.SizeOf; + break; + + case ContextAttribute.Names: + handleType = typeof(SafeFreeContextBuffer); + break; + + case ContextAttribute.PackageInfo: + handleType = typeof(SafeFreeContextBuffer); + break; + + case ContextAttribute.NegotiationInfo: + handleType = typeof(SafeFreeContextBuffer); + nativeBlockSize = Marshal.SizeOf(typeof(NegotiationInfo)); + break; + + case ContextAttribute.ClientSpecifiedSpn: + handleType = typeof(SafeFreeContextBuffer); + break; + + case ContextAttribute.RemoteCertificate: + handleType = typeof(SafeFreeCertContext); + break; + + case ContextAttribute.LocalCertificate: + handleType = typeof(SafeFreeCertContext); + break; + + case ContextAttribute.IssuerListInfoEx: + nativeBlockSize = Marshal.SizeOf(typeof(IssuerListInfoEx)); + handleType = typeof(SafeFreeContextBuffer); + break; + + case ContextAttribute.ConnectionInfo: + nativeBlockSize = Marshal.SizeOf(typeof(SslConnectionInfo)); + break; + + default: + throw new ArgumentException(SR.GetString(SR.net_invalid_enum, "ContextAttribute"), "contextAttribute"); + } + + SafeHandle SspiHandle = null; + object attribute = null; + + try + { + byte[] nativeBuffer = new byte[nativeBlockSize]; + errorCode = secModule.QueryContextAttributes(securityContext, contextAttribute, nativeBuffer, handleType, out SspiHandle); + if (errorCode != 0) + { + GlobalLog.Leave("Win32:QueryContextAttributes", "ERROR = " + ErrorDescription(errorCode)); + return null; + } + + switch (contextAttribute) + { + case ContextAttribute.Sizes: + attribute = new SecSizes(nativeBuffer); + break; + + case ContextAttribute.StreamSizes: + attribute = new StreamSizes(nativeBuffer); + break; + + case ContextAttribute.Names: + attribute = Marshal.PtrToStringUni(SspiHandle.DangerousGetHandle()); + break; + + case ContextAttribute.PackageInfo: + attribute = new SecurityPackageInfoClass(SspiHandle, 0); + break; + + case ContextAttribute.NegotiationInfo: + unsafe + { + fixed (void* ptr = nativeBuffer) + { + attribute = new NegotiationInfoClass(SspiHandle, Marshal.ReadInt32(new IntPtr(ptr), NegotiationInfo.NegotiationStateOffest)); + } + } + break; + + case ContextAttribute.ClientSpecifiedSpn: + attribute = Marshal.PtrToStringUni(SspiHandle.DangerousGetHandle()); + break; + + case ContextAttribute.LocalCertificate: + goto case ContextAttribute.RemoteCertificate; + case ContextAttribute.RemoteCertificate: + attribute = SspiHandle; + SspiHandle = null; + break; + + case ContextAttribute.IssuerListInfoEx: + attribute = new IssuerListInfoEx(SspiHandle, nativeBuffer); + SspiHandle = null; + break; + + case ContextAttribute.ConnectionInfo: + attribute = new SslConnectionInfo(nativeBuffer); + break; + default: + // will return null + break; + } + } + finally + { + if (SspiHandle != null) + { + SspiHandle.Dispose(); + } + } + GlobalLog.Leave("QueryContextAttributes", ValidationHelper.ToString(attribute)); + return attribute; + } + + public static string ErrorDescription(int errorCode) + { + if (errorCode == -1) + { + return "An exception when invoking Win32 API"; + } + switch ((SecurityStatus)errorCode) + { + case SecurityStatus.InvalidHandle: + return "Invalid handle"; + case SecurityStatus.InvalidToken: + return "Invalid token"; + case SecurityStatus.ContinueNeeded: + return "Continue needed"; + case SecurityStatus.IncompleteMessage: + return "Message incomplete"; + case SecurityStatus.WrongPrincipal: + return "Wrong principal"; + case SecurityStatus.TargetUnknown: + return "Target unknown"; + case SecurityStatus.PackageNotFound: + return "Package not found"; + case SecurityStatus.BufferNotEnough: + return "Buffer not enough"; + case SecurityStatus.MessageAltered: + return "Message altered"; + case SecurityStatus.UntrustedRoot: + return "Untrusted root"; + default: + return "0x" + errorCode.ToString("x", NumberFormatInfo.InvariantInfo); + } + } + } // class SSPIWrapper +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCloseHandle.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCloseHandle.cs new file mode 100644 index 0000000000..1b720b1ad6 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCloseHandle.cs @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.ConstrainedExecution; +using System.Security; +using System.Threading; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeCloseHandle : CriticalHandleZeroOrMinusOneIsInvalid + { + private int _disposed; + + private SafeCloseHandle() + : base() + { + } + + internal IntPtr DangerousGetHandle() + { + return handle; + } + + protected override bool ReleaseHandle() + { + if (!IsInvalid) + { + if (Interlocked.Increment(ref _disposed) == 1) + { + return UnsafeNclNativeMethods.SafeNetHandles.CloseHandle(handle); + } + } + return true; + } + + // This method will bypass refCount check done by VM + // Means it will force handle release if has a valid value + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal void Abort() + { + ReleaseHandle(); + SetHandleAsInvalid(); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCredentialReference.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCredentialReference.cs new file mode 100644 index 0000000000..a0d29a1b61 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCredentialReference.cs @@ -0,0 +1,69 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + internal sealed class SafeCredentialReference : CriticalHandleMinusOneIsInvalid + { + // Static cache will return the target handle if found the reference in the table. + internal SafeFreeCredentials _Target; + + private SafeCredentialReference(SafeFreeCredentials target) + : base() + { + // Bumps up the refcount on Target to signify that target handle is statically cached so + // its dispose should be postponed + bool b = false; + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + target.DangerousAddRef(ref b); + } + catch + { + if (b) + { + target.DangerousRelease(); + b = false; + } + } + finally + { + if (b) + { + _Target = target; + SetHandle(new IntPtr(0)); // make this handle valid + } + } + } + + internal static SafeCredentialReference CreateReference(SafeFreeCredentials target) + { + SafeCredentialReference result = new SafeCredentialReference(target); + if (result.IsInvalid) + { + return null; + } + + return result; + } + + protected override bool ReleaseHandle() + { + SafeFreeCredentials target = _Target; + if (target != null) + { + target.DangerousRelease(); + } + _Target = null; + return true; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeDeleteContext.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeDeleteContext.cs new file mode 100644 index 0000000000..a356f2b3ef --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeDeleteContext.cs @@ -0,0 +1,707 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + internal sealed class SafeDeleteContext : SafeHandle + { + private const string DummyStr = " "; + private static readonly byte[] DummyBytes = new byte[] { 0 }; + + // ATN: _handle is internal since it is used on PInvokes by other wrapper methods. + // However all such wrappers MUST manually and reliably adjust refCounter of SafeDeleteContext handle. + + internal SSPIHandle _handle; + + private SafeFreeCredentials _effectiveCredential; + + private SafeDeleteContext() + : base(IntPtr.Zero, true) + { + _handle = new SSPIHandle(); + } + + public override bool IsInvalid + { + get + { + return IsClosed || _handle.IsZero; + } + } + + public override string ToString() + { + return _handle.ToString(); + } + + //------------------------------------------------------------------- + internal static unsafe int InitializeSecurityContext(ref SafeFreeCredentials inCredentials, ref SafeDeleteContext refContext, + string targetName, ContextFlags inFlags, Endianness endianness, SecurityBuffer inSecBuffer, SecurityBuffer[] inSecBuffers, + SecurityBuffer outSecBuffer, ref ContextFlags outFlags) + { + GlobalLog.Assert(outSecBuffer != null, "SafeDeleteContext::InitializeSecurityContext()|outSecBuffer != null"); + GlobalLog.Assert(inSecBuffer == null || inSecBuffers == null, "SafeDeleteContext::InitializeSecurityContext()|inSecBuffer == null || inSecBuffers == null"); + + if (inCredentials == null) + { + throw new ArgumentNullException("inCredentials"); + } + + SecurityBufferDescriptor inSecurityBufferDescriptor = null; + if (inSecBuffer != null) + { + inSecurityBufferDescriptor = new SecurityBufferDescriptor(1); + } + else if (inSecBuffers != null) + { + inSecurityBufferDescriptor = new SecurityBufferDescriptor(inSecBuffers.Length); + } + SecurityBufferDescriptor outSecurityBufferDescriptor = new SecurityBufferDescriptor(1); + + // actually this is returned in outFlags + bool isSspiAllocated = (inFlags & ContextFlags.AllocateMemory) != 0 ? true : false; + + int errorCode = -1; + + SSPIHandle contextHandle = new SSPIHandle(); + if (refContext != null) + { + contextHandle = refContext._handle; + } + + // these are pinned user byte arrays passed along with SecurityBuffers + GCHandle[] pinnedInBytes = null; + GCHandle pinnedOutBytes = new GCHandle(); + // optional output buffer that may need to be freed + SafeFreeContextBuffer outFreeContextBuffer = null; + try + { + pinnedOutBytes = GCHandle.Alloc(outSecBuffer.token, GCHandleType.Pinned); + SecurityBufferStruct[] inUnmanagedBuffer = new SecurityBufferStruct[inSecurityBufferDescriptor == null ? 1 : inSecurityBufferDescriptor.Count]; + fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer) + { + if (inSecurityBufferDescriptor != null) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + inSecurityBufferDescriptor.UnmanagedPointer = inUnmanagedBufferPtr; + pinnedInBytes = new GCHandle[inSecurityBufferDescriptor.Count]; + SecurityBuffer securityBuffer; + for (int index = 0; index < inSecurityBufferDescriptor.Count; ++index) + { + securityBuffer = inSecBuffer != null ? inSecBuffer : inSecBuffers[index]; + if (securityBuffer != null) + { + // Copy the SecurityBuffer content into unmanaged place holder + inUnmanagedBuffer[index].count = securityBuffer.size; + inUnmanagedBuffer[index].type = securityBuffer.type; + + // use the unmanaged token if it's not null; otherwise use the managed buffer + if (securityBuffer.unmanagedToken != null) + { + inUnmanagedBuffer[index].token = securityBuffer.unmanagedToken.DangerousGetHandle(); + } + else if (securityBuffer.token == null || securityBuffer.token.Length == 0) + { + inUnmanagedBuffer[index].token = IntPtr.Zero; + } + else + { + pinnedInBytes[index] = GCHandle.Alloc(securityBuffer.token, GCHandleType.Pinned); + inUnmanagedBuffer[index].token = Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset); + } + } + } + } + + SecurityBufferStruct[] outUnmanagedBuffer = new SecurityBufferStruct[1]; + fixed (void* outUnmanagedBufferPtr = outUnmanagedBuffer) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + outSecurityBufferDescriptor.UnmanagedPointer = outUnmanagedBufferPtr; + outUnmanagedBuffer[0].count = outSecBuffer.size; + outUnmanagedBuffer[0].type = outSecBuffer.type; + if (outSecBuffer.token == null || outSecBuffer.token.Length == 0) + { + outUnmanagedBuffer[0].token = IntPtr.Zero; + } + else + { + outUnmanagedBuffer[0].token = Marshal.UnsafeAddrOfPinnedArrayElement(outSecBuffer.token, outSecBuffer.offset); + } + if (isSspiAllocated) + { + outFreeContextBuffer = SafeFreeContextBuffer.CreateEmptyHandle(); + } + + if (refContext == null || refContext.IsInvalid) + { + refContext = new SafeDeleteContext(); + } + + if (targetName == null || targetName.Length == 0) + { + targetName = DummyStr; + } + + fixed (char* namePtr = targetName) + { + errorCode = MustRunInitializeSecurityContext( + ref inCredentials, + contextHandle.IsZero ? null : &contextHandle, + (byte*)(((object)targetName == (object)DummyStr) ? null : namePtr), + inFlags, + endianness, + inSecurityBufferDescriptor, + refContext, + outSecurityBufferDescriptor, + ref outFlags, + outFreeContextBuffer); + } + + GlobalLog.Print("SafeDeleteContext:InitializeSecurityContext Marshalling OUT buffer"); + // Get unmanaged buffer with index 0 as the only one passed into PInvoke + outSecBuffer.size = outUnmanagedBuffer[0].count; + outSecBuffer.type = outUnmanagedBuffer[0].type; + if (outSecBuffer.size > 0) + { + outSecBuffer.token = new byte[outSecBuffer.size]; + Marshal.Copy(outUnmanagedBuffer[0].token, outSecBuffer.token, 0, outSecBuffer.size); + } + else + { + outSecBuffer.token = null; + } + } + } + } + finally + { + if (pinnedInBytes != null) + { + for (int index = 0; index < pinnedInBytes.Length; index++) + { + if (pinnedInBytes[index].IsAllocated) + { + pinnedInBytes[index].Free(); + } + } + } + if (pinnedOutBytes.IsAllocated) + { + pinnedOutBytes.Free(); + } + + if (outFreeContextBuffer != null) + { + outFreeContextBuffer.Dispose(); + } + } + + GlobalLog.Leave("SafeDeleteContext::InitializeSecurityContext() unmanaged InitializeSecurityContext()", "errorCode:0x" + errorCode.ToString("x8") + " refContext:" + ValidationHelper.ToString(refContext)); + + return errorCode; + } + + // After PINvoke call the method will fix the handleTemplate.handle with the returned value. + // The caller is responsible for creating a correct SafeFreeContextBuffer_XXX flavour or null can be passed if no handle is returned. + // + // Since it has a CER, this method can't have any references to imports from DLLs that may not exist on the system. + + private static unsafe int MustRunInitializeSecurityContext( + ref SafeFreeCredentials inCredentials, + void* inContextPtr, + byte* targetName, + ContextFlags inFlags, + Endianness endianness, + SecurityBufferDescriptor inputBuffer, + SafeDeleteContext outContext, + SecurityBufferDescriptor outputBuffer, + ref ContextFlags attributes, + SafeFreeContextBuffer handleTemplate) + { + int errorCode = (int)SecurityStatus.InvalidHandle; + bool b1 = false; + bool b2 = false; + + // Run the body of this method as a non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + inCredentials.DangerousAddRef(ref b1); + outContext.DangerousAddRef(ref b2); + } + catch (Exception e) + { + if (b1) + { + inCredentials.DangerousRelease(); + b1 = false; + } + if (b2) + { + outContext.DangerousRelease(); + b2 = false; + } + + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + SSPIHandle credentialHandle = inCredentials._handle; + long timeStamp; + + if (!b1) + { + // caller should retry + inCredentials = null; + } + else if (b1 && b2) + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.InitializeSecurityContextW( + ref credentialHandle, + inContextPtr, + targetName, + inFlags, + 0, + endianness, + inputBuffer, + 0, + ref outContext._handle, + outputBuffer, + ref attributes, + out timeStamp); + + // When a credential handle is first associated with the context we keep credential + // ref count bumped up to ensure ordered finalization. + // If the credential handle has been changed we de-ref the old one and associate the + // context with the new cred handle but only if the call was successful. + if (outContext._effectiveCredential != inCredentials && (errorCode & 0x80000000) == 0) + { + // Disassociate the previous credential handle + if (outContext._effectiveCredential != null) + { + outContext._effectiveCredential.DangerousRelease(); + } + outContext._effectiveCredential = inCredentials; + } + else + { + inCredentials.DangerousRelease(); + } + + outContext.DangerousRelease(); + + // The idea is that SSPI has allocated a block and filled up outUnmanagedBuffer+8 slot with the pointer. + if (handleTemplate != null) + { + handleTemplate.Set(((SecurityBufferStruct*)outputBuffer.UnmanagedPointer)->token); // ATTN: on 64 BIT that is still +8 cause of 2* c++ unsigned long == 8 bytes + if (handleTemplate.IsInvalid) + { + handleTemplate.SetHandleAsInvalid(); + } + } + } + + if (inContextPtr == null && (errorCode & 0x80000000) != 0) + { + // an error on the first call, need to set the out handle to invalid value + outContext._handle.SetToInvalid(); + } + } + + return errorCode; + } + + //------------------------------------------------------------------- + internal static unsafe int AcceptSecurityContext(ref SafeFreeCredentials inCredentials, ref SafeDeleteContext refContext, + ContextFlags inFlags, Endianness endianness, SecurityBuffer inSecBuffer, SecurityBuffer[] inSecBuffers, SecurityBuffer outSecBuffer, + ref ContextFlags outFlags) + { + GlobalLog.Assert(outSecBuffer != null, "SafeDeleteContext::AcceptSecurityContext()|outSecBuffer != null"); + GlobalLog.Assert(inSecBuffer == null || inSecBuffers == null, "SafeDeleteContext::AcceptSecurityContext()|inSecBuffer == null || inSecBuffers == null"); + + if (inCredentials == null) + { + throw new ArgumentNullException("inCredentials"); + } + + SecurityBufferDescriptor inSecurityBufferDescriptor = null; + if (inSecBuffer != null) + { + inSecurityBufferDescriptor = new SecurityBufferDescriptor(1); + } + else if (inSecBuffers != null) + { + inSecurityBufferDescriptor = new SecurityBufferDescriptor(inSecBuffers.Length); + } + SecurityBufferDescriptor outSecurityBufferDescriptor = new SecurityBufferDescriptor(1); + + // actually this is returned in outFlags + bool isSspiAllocated = (inFlags & ContextFlags.AllocateMemory) != 0 ? true : false; + + int errorCode = -1; + + SSPIHandle contextHandle = new SSPIHandle(); + if (refContext != null) + { + contextHandle = refContext._handle; + } + + // these are pinned user byte arrays passed along with SecurityBuffers + GCHandle[] pinnedInBytes = null; + GCHandle pinnedOutBytes = new GCHandle(); + // optional output buffer that may need to be freed + SafeFreeContextBuffer outFreeContextBuffer = null; + try + { + pinnedOutBytes = GCHandle.Alloc(outSecBuffer.token, GCHandleType.Pinned); + SecurityBufferStruct[] inUnmanagedBuffer = new SecurityBufferStruct[inSecurityBufferDescriptor == null ? 1 : inSecurityBufferDescriptor.Count]; + fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer) + { + if (inSecurityBufferDescriptor != null) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + inSecurityBufferDescriptor.UnmanagedPointer = inUnmanagedBufferPtr; + pinnedInBytes = new GCHandle[inSecurityBufferDescriptor.Count]; + SecurityBuffer securityBuffer; + for (int index = 0; index < inSecurityBufferDescriptor.Count; ++index) + { + securityBuffer = inSecBuffer != null ? inSecBuffer : inSecBuffers[index]; + if (securityBuffer != null) + { + // Copy the SecurityBuffer content into unmanaged place holder + inUnmanagedBuffer[index].count = securityBuffer.size; + inUnmanagedBuffer[index].type = securityBuffer.type; + + // use the unmanaged token if it's not null; otherwise use the managed buffer + if (securityBuffer.unmanagedToken != null) + { + inUnmanagedBuffer[index].token = securityBuffer.unmanagedToken.DangerousGetHandle(); + } + else if (securityBuffer.token == null || securityBuffer.token.Length == 0) + { + inUnmanagedBuffer[index].token = IntPtr.Zero; + } + else + { + pinnedInBytes[index] = GCHandle.Alloc(securityBuffer.token, GCHandleType.Pinned); + inUnmanagedBuffer[index].token = Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset); + } + } + } + } + SecurityBufferStruct[] outUnmanagedBuffer = new SecurityBufferStruct[1]; + fixed (void* outUnmanagedBufferPtr = outUnmanagedBuffer) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + outSecurityBufferDescriptor.UnmanagedPointer = outUnmanagedBufferPtr; + // Copy the SecurityBuffer content into unmanaged place holder + outUnmanagedBuffer[0].count = outSecBuffer.size; + outUnmanagedBuffer[0].type = outSecBuffer.type; + + if (outSecBuffer.token == null || outSecBuffer.token.Length == 0) + { + outUnmanagedBuffer[0].token = IntPtr.Zero; + } + else + { + outUnmanagedBuffer[0].token = Marshal.UnsafeAddrOfPinnedArrayElement(outSecBuffer.token, outSecBuffer.offset); + } + if (isSspiAllocated) + { + outFreeContextBuffer = SafeFreeContextBuffer.CreateEmptyHandle(); + } + + if (refContext == null || refContext.IsInvalid) + { + refContext = new SafeDeleteContext(); + } + + errorCode = MustRunAcceptSecurityContext( + ref inCredentials, + contextHandle.IsZero ? null : &contextHandle, + inSecurityBufferDescriptor, + inFlags, + endianness, + refContext, + outSecurityBufferDescriptor, + ref outFlags, + outFreeContextBuffer); + + GlobalLog.Print("SafeDeleteContext:AcceptSecurityContext Marshalling OUT buffer"); + // Get unmanaged buffer with index 0 as the only one passed into PInvoke + outSecBuffer.size = outUnmanagedBuffer[0].count; + outSecBuffer.type = outUnmanagedBuffer[0].type; + if (outSecBuffer.size > 0) + { + outSecBuffer.token = new byte[outSecBuffer.size]; + Marshal.Copy(outUnmanagedBuffer[0].token, outSecBuffer.token, 0, outSecBuffer.size); + } + else + { + outSecBuffer.token = null; + } + } + } + } + finally + { + if (pinnedInBytes != null) + { + for (int index = 0; index < pinnedInBytes.Length; index++) + { + if (pinnedInBytes[index].IsAllocated) + { + pinnedInBytes[index].Free(); + } + } + } + + if (pinnedOutBytes.IsAllocated) + { + pinnedOutBytes.Free(); + } + + if (outFreeContextBuffer != null) + { + outFreeContextBuffer.Dispose(); + } + } + + GlobalLog.Leave("SafeDeleteContext::AcceptSecurityContex() unmanaged AcceptSecurityContex()", "errorCode:0x" + errorCode.ToString("x8") + " refContext:" + ValidationHelper.ToString(refContext)); + + return errorCode; + } + + // After PINvoke call the method will fix the handleTemplate.handle with the returned value. + // The caller is responsible for creating a correct SafeFreeContextBuffer_XXX flavour or null can be passed if no handle is returned. + // + // Since it has a CER, this method can't have any references to imports from DLLs that may not exist on the system. + + private static unsafe int MustRunAcceptSecurityContext( + ref SafeFreeCredentials inCredentials, + void* inContextPtr, + SecurityBufferDescriptor inputBuffer, + ContextFlags inFlags, + Endianness endianness, + SafeDeleteContext outContext, + SecurityBufferDescriptor outputBuffer, + ref ContextFlags outFlags, + SafeFreeContextBuffer handleTemplate) + { + int errorCode = (int)SecurityStatus.InvalidHandle; + bool b1 = false; + bool b2 = false; + + // Run the body of this method as a non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + inCredentials.DangerousAddRef(ref b1); + outContext.DangerousAddRef(ref b2); + } + catch (Exception e) + { + if (b1) + { + inCredentials.DangerousRelease(); + b1 = false; + } + if (b2) + { + outContext.DangerousRelease(); + b2 = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + SSPIHandle credentialHandle = inCredentials._handle; + long timeStamp; + + if (!b1) + { + // caller should retry + inCredentials = null; + } + else if (b1 && b2) + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcceptSecurityContext( + ref credentialHandle, + inContextPtr, + inputBuffer, + inFlags, + endianness, + ref outContext._handle, + outputBuffer, + ref outFlags, + out timeStamp); + + // When a credential handle is first associated with the context we keep credential + // ref count bumped up to ensure ordered finalization. + // If the credential handle has been changed we de-ref the old one and associate the + // context with the new cred handle but only if the call was successful. + if (outContext._effectiveCredential != inCredentials && (errorCode & 0x80000000) == 0) + { + // Disassociate the previous credential handle + if (outContext._effectiveCredential != null) + { + outContext._effectiveCredential.DangerousRelease(); + } + outContext._effectiveCredential = inCredentials; + } + else + { + inCredentials.DangerousRelease(); + } + + outContext.DangerousRelease(); + + // The idea is that SSPI has allocated a block and filled up outUnmanagedBuffer+8 slot with the pointer. + if (handleTemplate != null) + { + handleTemplate.Set(((SecurityBufferStruct*)outputBuffer.UnmanagedPointer)->token); // ATTN: on 64 BIT that is still +8 cause of 2* c++ unsigned long == 8 bytes + if (handleTemplate.IsInvalid) + { + handleTemplate.SetHandleAsInvalid(); + } + } + } + + if (inContextPtr == null && (errorCode & 0x80000000) != 0) + { + // an error on the first call, need to set the out handle to invalid value + outContext._handle.SetToInvalid(); + } + } + + return errorCode; + } + + internal static unsafe int CompleteAuthToken(ref SafeDeleteContext refContext, SecurityBuffer[] inSecBuffers) + { + GlobalLog.Enter("SafeDeleteContext::CompleteAuthToken"); + GlobalLog.Print(" refContext = " + ValidationHelper.ToString(refContext)); + GlobalLog.Assert(inSecBuffers != null, "SafeDeleteContext::CompleteAuthToken()|inSecBuffers == null"); + SecurityBufferDescriptor inSecurityBufferDescriptor = new SecurityBufferDescriptor(inSecBuffers.Length); + + int errorCode = (int)SecurityStatus.InvalidHandle; + + // these are pinned user byte arrays passed along with SecurityBuffers + GCHandle[] pinnedInBytes = null; + + SecurityBufferStruct[] inUnmanagedBuffer = new SecurityBufferStruct[inSecurityBufferDescriptor.Count]; + fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + inSecurityBufferDescriptor.UnmanagedPointer = inUnmanagedBufferPtr; + pinnedInBytes = new GCHandle[inSecurityBufferDescriptor.Count]; + SecurityBuffer securityBuffer; + for (int index = 0; index < inSecurityBufferDescriptor.Count; ++index) + { + securityBuffer = inSecBuffers[index]; + if (securityBuffer != null) + { + inUnmanagedBuffer[index].count = securityBuffer.size; + inUnmanagedBuffer[index].type = securityBuffer.type; + + // use the unmanaged token if it's not null; otherwise use the managed buffer + if (securityBuffer.unmanagedToken != null) + { + inUnmanagedBuffer[index].token = securityBuffer.unmanagedToken.DangerousGetHandle(); + } + else if (securityBuffer.token == null || securityBuffer.token.Length == 0) + { + inUnmanagedBuffer[index].token = IntPtr.Zero; + } + else + { + pinnedInBytes[index] = GCHandle.Alloc(securityBuffer.token, GCHandleType.Pinned); + inUnmanagedBuffer[index].token = Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset); + } + } + } + + SSPIHandle contextHandle = new SSPIHandle(); + if (refContext != null) + { + contextHandle = refContext._handle; + } + try + { + if (refContext == null || refContext.IsInvalid) + { + refContext = new SafeDeleteContext(); + } + + bool b = false; + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + refContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + refContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.CompleteAuthToken(contextHandle.IsZero ? null : &contextHandle, inSecurityBufferDescriptor); + refContext.DangerousRelease(); + } + } + } + finally + { + if (pinnedInBytes != null) + { + for (int index = 0; index < pinnedInBytes.Length; index++) + { + if (pinnedInBytes[index].IsAllocated) + { + pinnedInBytes[index].Free(); + } + } + } + } + } + + GlobalLog.Leave("SafeDeleteContext::CompleteAuthToken() unmanaged CompleteAuthToken()", "errorCode:0x" + errorCode.ToString("x8") + " refContext:" + ValidationHelper.ToString(refContext)); + + return errorCode; + } + + protected override bool ReleaseHandle() + { + if (this._effectiveCredential != null) + { + this._effectiveCredential.DangerousRelease(); + } + + return UnsafeNclNativeMethods.SafeNetHandles.DeleteSecurityContext(ref _handle) == 0; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCertContext.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCertContext.cs new file mode 100644 index 0000000000..20f85f5010 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCertContext.cs @@ -0,0 +1,35 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.ConstrainedExecution; +using System.Security; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeFreeCertContext : SafeHandleZeroOrMinusOneIsInvalid + { + internal SafeFreeCertContext() + : base(true) + { + } + + // This must be ONLY called from this file within a CER. + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal unsafe void Set(IntPtr value) + { + this.handle = value; + } + + protected override bool ReleaseHandle() + { + UnsafeNclNativeMethods.SafeNetHandles.CertFreeCertificateContext(handle); + return true; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBuffer.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBuffer.cs new file mode 100644 index 0000000000..db31e3018a --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBuffer.cs @@ -0,0 +1,148 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeFreeContextBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + private SafeFreeContextBuffer() + : base(true) + { + } + + // This must be ONLY called from this file and in the context of a CER + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal unsafe void Set(IntPtr value) + { + this.handle = value; + } + + internal static int EnumeratePackages(out int pkgnum, out SafeFreeContextBuffer pkgArray) + { + int res = -1; + res = UnsafeNclNativeMethods.SafeNetHandles.EnumerateSecurityPackagesW(out pkgnum, out pkgArray); + if (res != 0 && pkgArray != null) + { + pkgArray.SetHandleAsInvalid(); + } + return res; + } + + internal static SafeFreeContextBuffer CreateEmptyHandle() + { + return new SafeFreeContextBuffer(); + } + + // After PINvoke call the method will fix the refHandle.handle with the returned value. + // The caller is responsible for creating a correct SafeHandle template or null can be passed if no handle is returned. + // + // This method switches between three non-interruptible helper methods. (This method can't be both non-interruptible and + // reference imports from all three DLLs - doing so would cause all three DLLs to try to be bound to.) + + public static unsafe int QueryContextAttributes(SafeDeleteContext phContext, ContextAttribute contextAttribute, byte* buffer, SafeHandle refHandle) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + // We don't want to be interrupted by thread abort exceptions or unexpected out-of-memory errors failing to jit + // one of the following methods. So run within a CER non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + phContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + phContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.SafeNetHandles.QueryContextAttributesW(ref phContext._handle, contextAttribute, buffer); + phContext.DangerousRelease(); + } + + if (status == 0 && refHandle != null) + { + if (refHandle is SafeFreeContextBuffer) + { + ((SafeFreeContextBuffer)refHandle).Set(*(IntPtr*)buffer); + } + else + { + ((SafeFreeCertContext)refHandle).Set(*(IntPtr*)buffer); + } + } + + if (status != 0 && refHandle != null) + { + refHandle.SetHandleAsInvalid(); + } + } + + return status; + } + + public static int SetContextAttributes(SafeDeleteContext phContext, ContextAttribute contextAttribute, byte[] buffer) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + // We don't want to be interrupted by thread abort exceptions or unexpected out-of-memory errors failing + // to jit one of the following methods. So run within a CER non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + phContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + phContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.SafeNetHandles.SetContextAttributesW( + ref phContext._handle, contextAttribute, buffer, buffer.Length); + phContext.DangerousRelease(); + } + } + + return status; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.FreeContextBuffer(handle) == 0; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBufferChannelBinding.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBufferChannelBinding.cs new file mode 100644 index 0000000000..344ef38a59 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBufferChannelBinding.cs @@ -0,0 +1,89 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.ConstrainedExecution; +using System.Security; +using System.Security.Authentication.ExtendedProtection; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeFreeContextBufferChannelBinding : ChannelBinding + { + private int size; + + public override int Size + { + get { return size; } + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal unsafe void Set(IntPtr value) + { + this.handle = value; + } + + internal static SafeFreeContextBufferChannelBinding CreateEmptyHandle() + { + return new SafeFreeContextBufferChannelBinding(); + } + + public static unsafe int QueryContextChannelBinding(SafeDeleteContext phContext, ContextAttribute contextAttribute, Bindings* buffer, + SafeFreeContextBufferChannelBinding refHandle) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + // We don't want to be interrupted by thread abort exceptions or unexpected out-of-memory errors failing to jit + // one of the following methods. So run within a CER non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + phContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + phContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.SafeNetHandles.QueryContextAttributesW(ref phContext._handle, contextAttribute, buffer); + phContext.DangerousRelease(); + } + + if (status == 0 && refHandle != null) + { + refHandle.Set((*buffer).pBindings); + refHandle.size = (*buffer).BindingsLength; + } + + if (status != 0 && refHandle != null) + { + refHandle.SetHandleAsInvalid(); + } + } + + return status; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.FreeContextBuffer(handle) == 0; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCredentials.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCredentials.cs new file mode 100644 index 0000000000..0b42913028 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCredentials.cs @@ -0,0 +1,199 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + internal sealed class SafeFreeCredentials : SafeHandle + { + internal SSPIHandle _handle; // should be always used as by ref in PINvokes parameters + + private SafeFreeCredentials() + : base(IntPtr.Zero, true) + { + _handle = new SSPIHandle(); + } + + public override bool IsInvalid + { + get { return IsClosed || _handle.IsZero; } + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.FreeCredentialsHandle(ref _handle) == 0; + } + + public static unsafe int AcquireCredentialsHandle(string package, CredentialUse intent, ref AuthIdentity authdata, + out SafeFreeCredentials outCredential) + { + GlobalLog.Print("SafeFreeCredentials::AcquireCredentialsHandle#1(" + + package + ", " + + intent + ", " + + authdata + ")"); + + int errorCode = -1; + long timeStamp; + + outCredential = new SafeFreeCredentials(); + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + } + finally + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcquireCredentialsHandleW( + null, + package, + (int)intent, + null, + ref authdata, + null, + null, + ref outCredential._handle, + out timeStamp); + } + + if (errorCode != 0) + { + outCredential.SetHandleAsInvalid(); + } + return errorCode; + } + + public static unsafe int AcquireDefaultCredential(string package, CredentialUse intent, out SafeFreeCredentials outCredential) + { + GlobalLog.Print("SafeFreeCredentials::AcquireDefaultCredential(" + + package + ", " + + intent + ")"); + + int errorCode = -1; + long timeStamp; + + outCredential = new SafeFreeCredentials(); + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + } + finally + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcquireCredentialsHandleW( + null, + package, + (int)intent, + null, + IntPtr.Zero, + null, + null, + ref outCredential._handle, + out timeStamp); + } + + if (errorCode != 0) + { + outCredential.SetHandleAsInvalid(); + } + return errorCode; + } + + // This overload is only called on Win7+ where SspiEncodeStringsAsAuthIdentity() was used to + // create the authData blob. + public static unsafe int AcquireCredentialsHandle( + string package, + CredentialUse intent, + ref SafeSspiAuthDataHandle authdata, + out SafeFreeCredentials outCredential) + { + int errorCode = -1; + long timeStamp; + + outCredential = new SafeFreeCredentials(); + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + } + finally + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcquireCredentialsHandleW( + null, + package, + (int)intent, + null, + authdata, + null, + null, + ref outCredential._handle, + out timeStamp); + } + + if (errorCode != 0) + { + outCredential.SetHandleAsInvalid(); + } + return errorCode; + } + + public static unsafe int AcquireCredentialsHandle(string package, CredentialUse intent, ref SecureCredential authdata, + out SafeFreeCredentials outCredential) + { + GlobalLog.Print("SafeFreeCredentials::AcquireCredentialsHandle#2(" + + package + ", " + + intent + ", " + + authdata + ")"); + + int errorCode = -1; + long timeStamp; + + // If there is a certificate, wrap it into an array. + // Not threadsafe. + IntPtr copiedPtr = authdata.certContextArray; + try + { + IntPtr certArrayPtr = new IntPtr(&copiedPtr); + if (copiedPtr != IntPtr.Zero) + { + authdata.certContextArray = certArrayPtr; + } + + outCredential = new SafeFreeCredentials(); + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + } + finally + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcquireCredentialsHandleW( + null, + package, + (int)intent, + null, + ref authdata, + null, + null, + ref outCredential._handle, + out timeStamp); + } + } + finally + { + authdata.certContextArray = copiedPtr; + } + + if (errorCode != 0) + { + outCredential.SetHandleAsInvalid(); + } + return errorCode; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeLocalFree.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeLocalFree.cs new file mode 100644 index 0000000000..30437854d7 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeLocalFree.cs @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Security; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeLocalFree : SafeHandleZeroOrMinusOneIsInvalid + { + private const int LmemFixed = 0; + private const int NULL = 0; + + // This returned handle cannot be modified by the application. + public static SafeLocalFree Zero = new SafeLocalFree(false); + + private SafeLocalFree() + : base(true) + { + } + + private SafeLocalFree(bool ownsHandle) + : base(ownsHandle) + { + } + + public static SafeLocalFree LocalAlloc(int cb) + { + SafeLocalFree result = UnsafeNclNativeMethods.SafeNetHandles.LocalAlloc(LmemFixed, (UIntPtr)cb); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + throw new OutOfMemoryException(); + } + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.LocalFree(handle) == IntPtr.Zero; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeSspiAuthDataHandle.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeSspiAuthDataHandle.cs new file mode 100644 index 0000000000..e02c75aaba --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeSspiAuthDataHandle.cs @@ -0,0 +1,32 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.Runtime.CompilerServices; + using System.Runtime.ConstrainedExecution; + using System.Runtime.InteropServices; + using System.Security; + using System.Security.Authentication.ExtendedProtection; + using System.Threading; + using Microsoft.Win32.SafeHandles; + + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeSspiAuthDataHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeSspiAuthDataHandle() + : base(true) + { + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SspiHelper.SspiFreeAuthIdentity(handle) == SecurityStatus.OK; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SchProtocols.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SchProtocols.cs new file mode 100644 index 0000000000..eb48aa1d3f --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SchProtocols.cs @@ -0,0 +1,41 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Security.Windows +{ + // From Schannel.h + [Flags] + internal enum SchProtocols + { + Zero = 0, + PctClient = 0x00000002, + PctServer = 0x00000001, + Pct = (PctClient | PctServer), + Ssl2Client = 0x00000008, + Ssl2Server = 0x00000004, + Ssl2 = (Ssl2Client | Ssl2Server), + Ssl3Client = 0x00000020, + Ssl3Server = 0x00000010, + Ssl3 = (Ssl3Client | Ssl3Server), + Tls10Client = 0x00000080, + Tls10Server = 0x00000040, + Tls10 = (Tls10Client | Tls10Server), + Tls11Client = 0x00000200, + Tls11Server = 0x00000100, + Tls11 = (Tls11Client | Tls11Server), + Tls12Client = 0x00000800, + Tls12Server = 0x00000400, + Tls12 = (Tls12Client | Tls12Server), + Ssl3Tls = (Ssl3 | Tls10), + UniClient = unchecked((int)0x80000000), + UniServer = 0x40000000, + Unified = (UniClient | UniServer), + ClientMask = (PctClient | Ssl2Client | Ssl3Client | Tls10Client | Tls11Client | Tls12Client | UniClient), + ServerMask = (PctServer | Ssl2Server | Ssl3Server | Tls10Server | Tls11Server | Tls12Server | UniServer) + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecSizes.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecSizes.cs new file mode 100644 index 0000000000..0f1bb3e7a3 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecSizes.cs @@ -0,0 +1,44 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.ComponentModel; + using System.Diagnostics; + using System.Globalization; + using System.Runtime.InteropServices; + + [StructLayout(LayoutKind.Sequential)] + internal class SecSizes + { + public readonly int MaxToken; + public readonly int MaxSignature; + public readonly int BlockSize; + public readonly int SecurityTrailer; + + internal unsafe SecSizes(byte[] memory) + { + fixed (void* voidPtr = memory) + { + IntPtr unmanagedAddress = new IntPtr(voidPtr); + try + { + MaxToken = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress)); + MaxSignature = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 4)); + BlockSize = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 8)); + SecurityTrailer = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 12)); + } + catch (OverflowException) + { + GlobalLog.Assert(false, "SecSizes::.ctor", "Negative size."); + throw; + } + } + } + public static readonly int SizeOf = Marshal.SizeOf(typeof(SecSizes)); + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBuffer.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBuffer.cs new file mode 100644 index 0000000000..069223dd4f --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBuffer.cs @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Globalization; +using System.Runtime.InteropServices; +using System.Security.Authentication.ExtendedProtection; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class SecurityBuffer + { + public int size; + public BufferType type; + public byte[] token; + public SafeHandle unmanagedToken; + public int offset; + + public SecurityBuffer(byte[] data, int offset, int size, BufferType tokentype) + { + GlobalLog.Assert(offset >= 0 && offset <= (data == null ? 0 : data.Length), "SecurityBuffer::.ctor", "'offset' out of range. [" + offset + "]"); + GlobalLog.Assert(size >= 0 && size <= (data == null ? 0 : data.Length - offset), "SecurityBuffer::.ctor", "'size' out of range. [" + size + "]"); + + this.offset = data == null || offset < 0 ? 0 : Math.Min(offset, data.Length); + this.size = data == null || size < 0 ? 0 : Math.Min(size, data.Length - this.offset); + this.type = tokentype; + this.token = size == 0 ? null : data; + } + + public SecurityBuffer(byte[] data, BufferType tokentype) + { + this.size = data == null ? 0 : data.Length; + this.type = tokentype; + this.token = size == 0 ? null : data; + } + + public SecurityBuffer(int size, BufferType tokentype) + { + GlobalLog.Assert(size >= 0, "SecurityBuffer::.ctor", "'size' out of range. [" + size.ToString(NumberFormatInfo.InvariantInfo) + "]"); + + this.size = size; + this.type = tokentype; + this.token = size == 0 ? null : new byte[size]; + } + + public SecurityBuffer(ChannelBinding binding) + { + this.size = (binding == null ? 0 : binding.Size); + this.type = BufferType.ChannelBindings; + this.unmanagedToken = binding; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBufferDescriptor.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBufferDescriptor.cs new file mode 100644 index 0000000000..138f12b2bd --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBufferDescriptor.cs @@ -0,0 +1,33 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential)] + internal unsafe class SecurityBufferDescriptor + { + /* + typedef struct _SecBufferDesc { + ULONG ulVersion; + ULONG cBuffers; + PSecBuffer pBuffers; + } SecBufferDesc, * PSecBufferDesc; + */ + public readonly int Version; + public readonly int Count; + public void* UnmanagedPointer; + + public SecurityBufferDescriptor(int count) + { + Version = 0; + Count = count; + UnmanagedPointer = null; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityPackageInfoClass.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityPackageInfoClass.cs new file mode 100644 index 0000000000..df19cf7e37 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityPackageInfoClass.cs @@ -0,0 +1,80 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Globalization; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class SecurityPackageInfoClass + { + private int _capabilities = 0; + private short _version = 0; + private short _rpcid = 0; + private int _maxToken = 0; + private string _name = null; + private string _comment = null; + + /* + * This is to support SSL under semi trusted environment. + * Note that it is only for SSL with no client cert + */ + internal SecurityPackageInfoClass(SafeHandle safeHandle, int index) + { + if (safeHandle.IsInvalid) + { + GlobalLog.Print("SecurityPackageInfoClass::.ctor() the pointer is invalid: " + (safeHandle.DangerousGetHandle()).ToString("x")); + return; + } + IntPtr unmanagedAddress = IntPtrHelper.Add(safeHandle.DangerousGetHandle(), SecurityPackageInfo.Size * index); + GlobalLog.Print("SecurityPackageInfoClass::.ctor() unmanagedPointer: " + ((long)unmanagedAddress).ToString("x")); + + _capabilities = Marshal.ReadInt32(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Capabilities")); + _version = Marshal.ReadInt16(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Version")); + _rpcid = Marshal.ReadInt16(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "RPCID")); + _maxToken = Marshal.ReadInt32(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "MaxToken")); + + IntPtr unmanagedString; + + unmanagedString = Marshal.ReadIntPtr(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Name")); + if (unmanagedString != IntPtr.Zero) + { + _name = Marshal.PtrToStringUni(unmanagedString); + GlobalLog.Print("Name: " + Name); + } + + unmanagedString = Marshal.ReadIntPtr(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Comment")); + if (unmanagedString != IntPtr.Zero) + { + _comment = Marshal.PtrToStringUni(unmanagedString); + GlobalLog.Print("Comment: " + _comment); + } + + GlobalLog.Print("SecurityPackageInfoClass::.ctor(): " + ToString()); + } + + internal int MaxToken + { + get { return _maxToken; } + } + + internal string Name + { + get { return _name; } + } + + public override string ToString() + { + return "Capabilities:" + String.Format(CultureInfo.InvariantCulture, "0x{0:x}", _capabilities) + + " Version:" + _version.ToString(NumberFormatInfo.InvariantInfo) + + " RPCID:" + _rpcid.ToString(NumberFormatInfo.InvariantInfo) + + " MaxToken:" + MaxToken.ToString(NumberFormatInfo.InvariantInfo) + + " Name:" + ((Name == null) ? "(null)" : Name) + + " Comment:" + ((_comment == null) ? "(null)" : _comment); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SslConnectionInfo.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SslConnectionInfo.cs new file mode 100644 index 0000000000..2430b73b8d --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SslConnectionInfo.cs @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential)] + internal class SslConnectionInfo + { + public readonly int Protocol; + public readonly int DataCipherAlg; + public readonly int DataKeySize; + public readonly int DataHashAlg; + public readonly int DataHashKeySize; + public readonly int KeyExchangeAlg; + public readonly int KeyExchKeySize; + + internal unsafe SslConnectionInfo(byte[] nativeBuffer) + { + fixed (void* voidPtr = nativeBuffer) + { + IntPtr unmanagedAddress = new IntPtr(voidPtr); + Protocol = Marshal.ReadInt32(unmanagedAddress); + DataCipherAlg = Marshal.ReadInt32(unmanagedAddress, 4); + DataKeySize = Marshal.ReadInt32(unmanagedAddress, 8); + DataHashAlg = Marshal.ReadInt32(unmanagedAddress, 12); + DataHashKeySize = Marshal.ReadInt32(unmanagedAddress, 16); + KeyExchangeAlg = Marshal.ReadInt32(unmanagedAddress, 20); + KeyExchKeySize = Marshal.ReadInt32(unmanagedAddress, 24); + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/StreamSizes.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/StreamSizes.cs new file mode 100644 index 0000000000..6be68aa96b --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/StreamSizes.cs @@ -0,0 +1,43 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential)] + internal class StreamSizes + { + public int header; + public int trailer; + public int maximumMessage; + public int buffersCount; + public int blockSize; + + internal unsafe StreamSizes(byte[] memory) + { + fixed (void* voidPtr = memory) + { + IntPtr unmanagedAddress = new IntPtr(voidPtr); + try + { + header = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress)); + trailer = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 4)); + maximumMessage = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 8)); + buffersCount = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 12)); + blockSize = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 16)); + } + catch (OverflowException) + { + GlobalLog.Assert(false, "StreamSizes::.ctor", "Negative size."); + throw; + } + } + } + public static readonly int SizeOf = Marshal.SizeOf(typeof(StreamSizes)); + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/UnsafeNativeMethods.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/UnsafeNativeMethods.cs new file mode 100644 index 0000000000..3e8a40d344 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/UnsafeNativeMethods.cs @@ -0,0 +1,222 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security; + +namespace Microsoft.AspNet.Security.Windows +{ + [System.Security.SuppressUnmanagedCodeSecurityAttribute] + internal static class UnsafeNclNativeMethods + { + private const string KERNEL32 = "kernel32.dll"; + private const string SECUR32 = "secur32.dll"; + private const string CRYPT32 = "crypt32.dll"; + + [DllImport(KERNEL32, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint GetCurrentThreadId(); + + [System.Security.SuppressUnmanagedCodeSecurityAttribute] + internal static class SafeNetHandles + { + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern int FreeContextBuffer( + [In] IntPtr contextBuffer); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern int FreeCredentialsHandle( + ref SSPIHandle handlePtr); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern int DeleteSecurityContext( + ref SSPIHandle handlePtr); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int AcceptSecurityContext( + ref SSPIHandle credentialHandle, + [In] void* inContextPtr, + [In] SecurityBufferDescriptor inputBuffer, + [In] ContextFlags inFlags, + [In] Endianness endianness, + ref SSPIHandle outContextPtr, + [In, Out] SecurityBufferDescriptor outputBuffer, + [In, Out] ref ContextFlags attributes, + out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int QueryContextAttributesW( + ref SSPIHandle contextHandle, + [In] ContextAttribute attribute, + [In] void* buffer); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int SetContextAttributesW( + ref SSPIHandle contextHandle, + [In] ContextAttribute attribute, + [In] byte[] buffer, + [In] int bufferSize); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static extern int EnumerateSecurityPackagesW( + [Out] out int pkgnum, + [Out] out SafeFreeContextBuffer handle); + + [DllImport(SECUR32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static unsafe extern int AcquireCredentialsHandleW( + [In] string principal, + [In] string moduleName, + [In] int usage, + [In] void* logonID, + [In] ref AuthIdentity authdata, + [In] void* keyCallback, + [In] void* keyArgument, + ref SSPIHandle handlePtr, + [Out] out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static unsafe extern int AcquireCredentialsHandleW( + [In] string principal, + [In] string moduleName, + [In] int usage, + [In] void* logonID, + [In] IntPtr zero, + [In] void* keyCallback, + [In] void* keyArgument, + ref SSPIHandle handlePtr, + [Out] out long timeStamp); + + // Win7+ + [DllImport(SECUR32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static unsafe extern int AcquireCredentialsHandleW( + [In] string principal, + [In] string moduleName, + [In] int usage, + [In] void* logonID, + [In] SafeSspiAuthDataHandle authdata, + [In] void* keyCallback, + [In] void* keyArgument, + ref SSPIHandle handlePtr, + [Out] out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static unsafe extern int AcquireCredentialsHandleW( + [In] string principal, + [In] string moduleName, + [In] int usage, + [In] void* logonID, + [In] ref SecureCredential authData, + [In] void* keyCallback, + [In] void* keyArgument, + ref SSPIHandle handlePtr, + [Out] out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int InitializeSecurityContextW( + ref SSPIHandle credentialHandle, + [In] void* inContextPtr, + [In] byte* targetName, + [In] ContextFlags inFlags, + [In] int reservedI, + [In] Endianness endianness, + [In] SecurityBufferDescriptor inputBuffer, + [In] int reservedII, + ref SSPIHandle outContextPtr, + [In, Out] SecurityBufferDescriptor outputBuffer, + [In, Out] ref ContextFlags attributes, + out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int CompleteAuthToken( + [In] void* inContextPtr, + [In, Out] SecurityBufferDescriptor inputBuffers); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static extern int QuerySecurityContextToken(ref SSPIHandle phContext, [Out] out SafeCloseHandle handle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern bool CloseHandle(IntPtr handle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern SafeLocalFree LocalAlloc(int uFlags, UIntPtr sizetdwBytes); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern IntPtr LocalFree(IntPtr handle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern unsafe bool FreeLibrary([In] IntPtr hModule); + + [DllImport(CRYPT32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern void CertFreeCertificateChain( + [In] IntPtr pChainContext); + + [DllImport(CRYPT32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern void CertFreeCertificateChainList( + [In] IntPtr ppChainContext); + + [DllImport(CRYPT32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern bool CertFreeCertificateContext( // Suppressing returned status check, it's always==TRUE, + [In] IntPtr certContext); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern IntPtr GlobalFree(IntPtr handle); + } + + [System.Security.SuppressUnmanagedCodeSecurityAttribute] + internal static class NativeNTSSPI + { + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static extern int EncryptMessage( + ref SSPIHandle contextHandle, + [In] uint qualityOfProtection, + [In, Out] SecurityBufferDescriptor inputOutput, + [In] uint sequenceNumber); + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static unsafe extern int DecryptMessage( + [In] ref SSPIHandle contextHandle, + [In, Out] SecurityBufferDescriptor inputOutput, + [In] uint sequenceNumber, + uint* qualityOfProtection); + } // class UnsafeNclNativeMethods.NativeNTSSPI + + [SuppressUnmanagedCodeSecurityAttribute] + internal static class SspiHelper + { + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern SecurityStatus SspiFreeAuthIdentity( + [In] IntPtr authData); + + [SuppressMessage("Microsoft.Security", "CA2118:ReviewSuppressUnmanagedCodeSecurityUsage", Justification = "Implementation requires unmanaged code usage")] + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true, CharSet = CharSet.Unicode)] + internal static unsafe extern SecurityStatus SspiEncodeStringsAsAuthIdentity( + [In] string userName, + [In] string domainName, + [In] string password, + [Out] out SafeSspiAuthDataHandle authData); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NegotiationInfoClass.cs b/src/Microsoft.AspNet.Security.Windows/NegotiationInfoClass.cs new file mode 100644 index 0000000000..ab7367c456 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NegotiationInfoClass.cs @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.ComponentModel; + using System.Diagnostics; + using System.Globalization; + using System.Runtime.InteropServices; + + internal class NegotiationInfoClass + { + internal const string NTLM = "NTLM"; + internal const string Kerberos = "Kerberos"; + internal const string WDigest = "WDigest"; + internal const string Digest = "Digest"; + internal const string Negotiate = "Negotiate"; + internal string AuthenticationPackage; + + internal NegotiationInfoClass(SafeHandle safeHandle, int negotiationState) + { + if (safeHandle.IsInvalid) + { + GlobalLog.Print("NegotiationInfoClass::.ctor() the handle is invalid:" + (safeHandle.DangerousGetHandle()).ToString("x")); + return; + } + IntPtr packageInfo = safeHandle.DangerousGetHandle(); + GlobalLog.Print("NegotiationInfoClass::.ctor() packageInfo:" + packageInfo.ToString("x8") + " negotiationState:" + negotiationState.ToString("x8")); + + const int SECPKG_NEGOTIATION_COMPLETE = 0; + const int SECPKG_NEGOTIATION_OPTIMISTIC = 1; + // const int SECPKG_NEGOTIATION_IN_PROGRESS = 2; + // const int SECPKG_NEGOTIATION_DIRECT = 3; + // const int SECPKG_NEGOTIATION_TRY_MULTICRED = 4; + + if (negotiationState == SECPKG_NEGOTIATION_COMPLETE || negotiationState == SECPKG_NEGOTIATION_OPTIMISTIC) + { + IntPtr unmanagedString = Marshal.ReadIntPtr(packageInfo, SecurityPackageInfo.NameOffest); + string name = null; + if (unmanagedString != IntPtr.Zero) + { + name = Marshal.PtrToStringUni(unmanagedString); + } + GlobalLog.Print("NegotiationInfoClass::.ctor() packageInfo:" + packageInfo.ToString("x8") + " negotiationState:" + negotiationState.ToString("x8") + " name:" + ValidationHelper.ToString(name)); + + // an optimization for future string comparisons + if (string.Compare(name, Kerberos, StringComparison.OrdinalIgnoreCase) == 0) + { + AuthenticationPackage = Kerberos; + } + else if (string.Compare(name, NTLM, StringComparison.OrdinalIgnoreCase) == 0) + { + AuthenticationPackage = NTLM; + } + else if (string.Compare(name, WDigest, StringComparison.OrdinalIgnoreCase) == 0) + { + AuthenticationPackage = WDigest; + } + else + { + AuthenticationPackage = name; + } + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/PrefixCollection.cs b/src/Microsoft.AspNet.Security.Windows/PrefixCollection.cs new file mode 100644 index 0000000000..37a2caa049 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/PrefixCollection.cs @@ -0,0 +1,109 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class PrefixCollection : ICollection + { + private WindowsAuthMiddleware _winAuth; + + internal PrefixCollection(WindowsAuthMiddleware winAuth) + { + _winAuth = winAuth; + } + + public int Count + { + get + { + return _winAuth._uriPrefixes.Count; + } + } + + public bool IsSynchronized + { + get + { + return false; + } + } + + public bool IsReadOnly + { + get + { + return false; + } + } + + public void CopyTo(Array array, int offset) + { + if (Count > array.Length) + { + throw new ArgumentOutOfRangeException("array", SR.GetString(SR.net_array_too_small)); + } + if (offset + Count > array.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + int index = 0; + foreach (string uriPrefix in _winAuth._uriPrefixes.Keys) + { + array.SetValue(uriPrefix, offset + index++); + } + } + + public void CopyTo(string[] array, int offset) + { + if (Count > array.Length) + { + throw new ArgumentOutOfRangeException("array", SR.GetString(SR.net_array_too_small)); + } + if (offset + Count > array.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + int index = 0; + foreach (string uriPrefix in _winAuth._uriPrefixes.Keys) + { + array[offset + index++] = uriPrefix; + } + } + + public void Add(string uriPrefix) + { + _winAuth.AddPrefix(uriPrefix); + } + + public bool Contains(string uriPrefix) + { + return _winAuth._uriPrefixes.Contains(uriPrefix); + } + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new PrefixEnumerator(_winAuth._uriPrefixes.Keys.GetEnumerator()); + } + + public bool Remove(string uriPrefix) + { + return _winAuth.RemovePrefix(uriPrefix); + } + + public void Clear() + { + _winAuth.RemoveAll(true); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/PrefixEnumerator.cs b/src/Microsoft.AspNet.Security.Windows/PrefixEnumerator.cs new file mode 100644 index 0000000000..40de3afd6d --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/PrefixEnumerator.cs @@ -0,0 +1,51 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class PrefixEnumerator : IEnumerator + { + private IEnumerator enumerator; + + internal PrefixEnumerator(IEnumerator enumerator) + { + this.enumerator = enumerator; + } + + public string Current + { + get + { + return (string)enumerator.Current; + } + } + + object System.Collections.IEnumerator.Current + { + get + { + return enumerator.Current; + } + } + + public bool MoveNext() + { + return enumerator.MoveNext(); + } + + public void Dispose() + { + } + + void System.Collections.IEnumerator.Reset() + { + enumerator.Reset(); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Properties/AssemblyInfo.cs b/src/Microsoft.AspNet.Security.Windows/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..501727a48a --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.Security.Windows")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.Security.Windows")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("1f471909-581f-4060-a147-430891e9c3c1")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/src/Microsoft.AspNet.Security.Windows/ServiceNameStore.cs b/src/Microsoft.AspNet.Security.Windows/ServiceNameStore.cs new file mode 100644 index 0000000000..7d2d33ff00 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/ServiceNameStore.cs @@ -0,0 +1,368 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Security.Authentication.ExtendedProtection; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class ServiceNameStore + { + private List serviceNames; + private ServiceNameCollection serviceNameCollection; + + public ServiceNameStore() + { + serviceNames = new List(); + serviceNameCollection = null; // set only when needed (due to expensive item-by-item copy) + } + + public ServiceNameCollection ServiceNames + { + get + { + if (serviceNameCollection == null) + { + serviceNameCollection = new ServiceNameCollection(serviceNames); + } + return serviceNameCollection; + } + } + + private bool AddSingleServiceName(string spn) + { + spn = NormalizeServiceName(spn); + if (Contains(spn)) + { + return false; + } + else + { + serviceNames.Add(spn); + return true; + } + } + + public bool Add(string uriPrefix) + { + Debug.Assert(!String.IsNullOrEmpty(uriPrefix)); + + string[] newServiceNames = BuildServiceNames(uriPrefix); + + bool addedAny = false; + foreach (string spn in newServiceNames) + { + if (AddSingleServiceName(spn)) + { + addedAny = true; + + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, "ServiceNameStore#" + + ValidationHelper.HashString(this) + "::Add() " + + SR.GetString(SR.net_log_listener_spn_add, spn, uriPrefix)); + } + } + } + + if (addedAny) + { + serviceNameCollection = null; + } + else if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, "ServiceNameStore#" + + ValidationHelper.HashString(this) + "::Add() " + + SR.GetString(SR.net_log_listener_spn_not_add, uriPrefix)); + } + + return addedAny; + } + + public bool Remove(string uriPrefix) + { + Debug.Assert(!String.IsNullOrEmpty(uriPrefix)); + + string newServiceName = BuildSimpleServiceName(uriPrefix); + newServiceName = NormalizeServiceName(newServiceName); + bool needToRemove = Contains(newServiceName); + + if (needToRemove) + { + serviceNames.Remove(newServiceName); + serviceNameCollection = null; // invalidate (readonly) ServiceNameCollection + } + + if (Logging.On) + { + if (needToRemove) + { + Logging.PrintInfo(Logging.HttpListener, "ServiceNameStore#" + + ValidationHelper.HashString(this) + "::Remove() " + + SR.GetString(SR.net_log_listener_spn_remove, newServiceName, uriPrefix)); + } + else + { + Logging.PrintInfo(Logging.HttpListener, "ServiceNameStore#" + + ValidationHelper.HashString(this) + "::Remove() " + + SR.GetString(SR.net_log_listener_spn_not_remove, uriPrefix)); + } + } + + return needToRemove; + } + + // Assumes already normalized + private bool Contains(string newServiceName) + { + if (newServiceName == null) + { + return false; + } + + return Contains(newServiceName, serviceNames); + } + + // Assumes searchServiceName and serviceNames have already been normalized + internal static bool Contains(string searchServiceName, ICollection serviceNames) + { + Debug.Assert(serviceNames != null); + Debug.Assert(!String.IsNullOrEmpty(searchServiceName)); + + foreach (string serviceName in serviceNames) + { + if (Match(serviceName, searchServiceName)) + { + return true; + } + } + + return false; + } + + // Assumes already normalized + internal static bool Match(string serviceName1, string serviceName2) + { + return (String.Compare(serviceName1, serviceName2, StringComparison.OrdinalIgnoreCase) == 0); + } + + public void Clear() + { + serviceNames.Clear(); + serviceNameCollection = null; // invalidate (readonly) ServiceNameCollection + } + + private string ExtractHostname(string uriPrefix, bool allowInvalidUriStrings) + { + if (Uri.IsWellFormedUriString(uriPrefix, UriKind.Absolute)) + { + Uri hostUri = new Uri(uriPrefix); + return hostUri.Host; + } + else if (allowInvalidUriStrings) + { + int i = uriPrefix.IndexOf("://") + 3; + int j = i; + + bool inSquareBrackets = false; + while (j < uriPrefix.Length && uriPrefix[j] != '/' && (uriPrefix[j] != ':' || inSquareBrackets)) + { + if (uriPrefix[j] == '[') + { + if (inSquareBrackets) + { + j = i; + break; + } + inSquareBrackets = true; + } + if (inSquareBrackets && uriPrefix[j] == ']') + { + inSquareBrackets = false; + } + j++; + } + + return uriPrefix.Substring(i, j - i); + } + + return null; + } + + public string BuildSimpleServiceName(string uriPrefix) + { + string hostname = ExtractHostname(uriPrefix, false); + + if (hostname != null) + { + return "HTTP/" + hostname; + } + else + { + return null; + } + } + + public string[] BuildServiceNames(string uriPrefix) + { + string hostname = ExtractHostname(uriPrefix, true); + + IPAddress ipAddress = null; + if (String.Compare(hostname, "*", StringComparison.InvariantCultureIgnoreCase) == 0 || + String.Compare(hostname, "+", StringComparison.InvariantCultureIgnoreCase) == 0 || + IPAddress.TryParse(hostname, out ipAddress)) + { + // for a wildcard, register the machine name. If the caller doesn't have DNS permission + // or the query fails for some reason, don't add an SPN. + try + { + string machineName = Dns.GetHostEntry(String.Empty).HostName; + return new string[] { "HTTP/" + machineName }; + } + catch (System.Net.Sockets.SocketException) + { + return new string[0]; + } + catch (System.Security.SecurityException) + { + return new string[0]; + } + } + else if (!hostname.Contains(".")) + { + // for a dotless name, try to resolve the FQDN. If the caller doesn't have DNS permission + // or the query fails for some reason, add only the dotless name. + try + { + string fqdn = Dns.GetHostEntry(hostname).HostName; + return new string[] { "HTTP/" + hostname, "HTTP/" + fqdn }; + } + catch (System.Net.Sockets.SocketException) + { + return new string[] { "HTTP/" + hostname }; + } + catch (System.Security.SecurityException) + { + return new string[] { "HTTP/" + hostname }; + } + } + else + { + return new string[] { "HTTP/" + hostname }; + } + } + + // Normalizes any punycode to unicode in an Service Name (SPN) host. + // If the algorithm fails at any point then the original input is returned. + // ServiceName is in one of the following forms: + // prefix/host + // prefix/host:port + // prefix/host/DistinguishedName + // prefix/host:port/DistinguishedName + internal static string NormalizeServiceName(string inputServiceName) + { + if (string.IsNullOrWhiteSpace(inputServiceName)) + { + return inputServiceName; + } + + // Separate out the prefix + int shashIndex = inputServiceName.IndexOf('/'); + if (shashIndex < 0) + { + return inputServiceName; + } + string prefix = inputServiceName.Substring(0, shashIndex + 1); // Includes slash + string hostPortAndDistinguisher = inputServiceName.Substring(shashIndex + 1); // Excludes slash + + if (string.IsNullOrWhiteSpace(hostPortAndDistinguisher)) + { + return inputServiceName; + } + + string host = hostPortAndDistinguisher; + string port = string.Empty; + string distinguisher = string.Empty; + + // Check for the absence of a port or distinguisher. + UriHostNameType hostType = Uri.CheckHostName(hostPortAndDistinguisher); + if (hostType == UriHostNameType.Unknown) + { + string hostAndPort = hostPortAndDistinguisher; + + // Check for distinguisher + int nextSlashIndex = hostPortAndDistinguisher.IndexOf('/'); + if (nextSlashIndex >= 0) + { + // host:port/distinguisher or host/distinguisher + hostAndPort = hostPortAndDistinguisher.Substring(0, nextSlashIndex); // Excludes Slash + distinguisher = hostPortAndDistinguisher.Substring(nextSlashIndex); // Includes Slash + host = hostAndPort; // We don't know if there is a port yet. + + // No need to validate the distinguisher + } + + // Check for port + int colonIndex = hostAndPort.LastIndexOf(':'); // Allow IPv6 addresses + if (colonIndex >= 0) + { + // host:port + host = hostAndPort.Substring(0, colonIndex); // Excludes colon + port = hostAndPort.Substring(colonIndex + 1); // Excludes colon + + // Loosely validate the port just to make sure it was a port and not something else + UInt16 portValue; + if (!UInt16.TryParse(port, NumberStyles.Integer, CultureInfo.InvariantCulture, out portValue)) + { + return inputServiceName; + } + + // Re-include the colon for the final output. Do not change the port format. + port = hostAndPort.Substring(colonIndex); + } + + hostType = Uri.CheckHostName(host); // Revaidate the host + } + + if (hostType != UriHostNameType.Dns) + { + // UriHostNameType.IPv4, UriHostNameType.IPv6: Do not normalize IPv4/6 hosts. + // UriHostNameType.Basic: This is never returned by CheckHostName today + // UriHostNameType.Unknown: Nothing recognizable to normalize + // default Some new UriHostNameType? + return inputServiceName; + } + + // Now we have a valid DNS host, normalize it. + + Uri constructedUri; + // This shouldn't fail, but we need to avoid any unexpected exceptions on this code path. + if (!Uri.TryCreate(Uri.UriSchemeHttp + Uri.SchemeDelimiter + host, UriKind.Absolute, out constructedUri)) + { + return inputServiceName; + } + + string normalizedHost = constructedUri.GetComponents( + UriComponents.NormalizedHost, UriFormat.SafeUnescaped); + + string normalizedServiceName = string.Format(CultureInfo.InvariantCulture, + "{0}{1}{2}{3}", prefix, normalizedHost, port, distinguisher); + + // Don't return the new one unless we absolutely have to. It may have only changed casing. + if (String.Compare(inputServiceName, normalizedServiceName, StringComparison.OrdinalIgnoreCase) == 0) + { + return inputServiceName; + } + + return normalizedServiceName; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/WindowsAuthMiddleware.cs b/src/Microsoft.AspNet.Security.Windows/WindowsAuthMiddleware.cs new file mode 100644 index 0000000000..85f696728f --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/WindowsAuthMiddleware.cs @@ -0,0 +1,1281 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Net; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Permissions; +using System.Security.Principal; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Security.Windows +{ + using AppFunc = Func, Task>; + + /// + /// A middleware that performs Windows Authentication of the specified types. + /// + public sealed class WindowsAuthMiddleware + { + private Func, AuthTypes> _authenticationDelegate; + private AuthTypes _authenticationScheme = AuthTypes.Negotiate | AuthTypes.Ntlm | AuthTypes.Digest; + private string _realm; + private PrefixCollection _prefixes; + private bool _unsafeConnectionNtlmAuthentication; + private Func, ExtendedProtectionPolicy> _extendedProtectionSelectorDelegate; + private ExtendedProtectionPolicy _extendedProtectionPolicy; + private ServiceNameStore _defaultServiceNames; + + private Hashtable _disconnectResults; // ulong -> DisconnectAsyncResult + private object _internalLock; + + internal Hashtable _uriPrefixes; + private DigestCache _digestCache; + + private AppFunc _nextApp; + + // TODO: Support proxy auth + // private bool _doProxyAuth; + + /// + /// + /// + /// + public WindowsAuthMiddleware(AppFunc nextApp) + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "WindowsAuth", string.Empty); + } + + _internalLock = new object(); + _defaultServiceNames = new ServiceNameStore(); + + // default: no CBT checks on any platform (appcompat reasons); applies also to PolicyEnforcement + // config element + _extendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Never); + _uriPrefixes = new Hashtable(); + _digestCache = new DigestCache(); + + _nextApp = nextApp; + + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "WindowsAuth", string.Empty); + } + } + + /// + /// Dynamically select the type of authentication to apply per request. + /// + public Func, AuthTypes> AuthenticationSchemeSelectorDelegate + { + get + { + return _authenticationDelegate; + } + set + { + _authenticationDelegate = value; + } + } + + /// + /// Dynamically select the type of extended protection to apply per request. + /// + public Func, ExtendedProtectionPolicy> ExtendedProtectionSelectorDelegate + { + get + { + return _extendedProtectionSelectorDelegate; + } + set + { + if (value == null) + { + throw new ArgumentNullException(); + } + + if (!ExtendedProtectionPolicy.OSSupportsExtendedProtection) + { + throw new PlatformNotSupportedException(SR.GetString(SR.security_ExtendedProtection_NoOSSupport)); + } + + _extendedProtectionSelectorDelegate = value; + } + } + + /// + /// Specifies which types of Windows authentication are enabled. + /// + public AuthTypes AuthenticationSchemes + { + get + { + return _authenticationScheme; + } + set + { + _authenticationScheme = value; + } + } + + /// + /// Configures extended protection. + /// + public ExtendedProtectionPolicy ExtendedProtectionPolicy + { + get + { + return _extendedProtectionPolicy; + } + set + { + if (value == null) + { + throw new ArgumentNullException("value"); + } + if (!ExtendedProtectionPolicy.OSSupportsExtendedProtection && value.PolicyEnforcement == PolicyEnforcement.Always) + { + throw new PlatformNotSupportedException(SR.GetString(SR.security_ExtendedProtection_NoOSSupport)); + } + if (value.CustomChannelBinding != null) + { + throw new ArgumentException(SR.GetString(SR.net_listener_cannot_set_custom_cbt), "CustomChannelBinding"); + } + + _extendedProtectionPolicy = value; + } + } + + /// + /// Configures the service names for extended protection. + /// + public ServiceNameCollection DefaultServiceNames + { + get + { + return _defaultServiceNames.ServiceNames; + } + } + + /// + /// The Realm for use in digest authentication. + /// + public string Realm + { + get + { + return _realm; + } + set + { + _realm = value; + } + } + + /// + /// Enables authenticated connection sharing with NTLM. + /// + public bool UnsafeConnectionNtlmAuthentication + { + get + { + return _unsafeConnectionNtlmAuthentication; + } + + set + { + if (_unsafeConnectionNtlmAuthentication == value) + { + return; + } + lock (DisconnectResults.SyncRoot) + { + if (_unsafeConnectionNtlmAuthentication == value) + { + return; + } + _unsafeConnectionNtlmAuthentication = value; + if (!value) + { + foreach (DisconnectAsyncResult result in DisconnectResults.Values) + { + result.AuthenticatedUser = null; + } + } + } + } + } + + internal Hashtable DisconnectResults + { + get + { + if (_disconnectResults == null) + { + Interlocked.CompareExchange(ref _disconnectResults, Hashtable.Synchronized(new Hashtable()), null); + } + return _disconnectResults; + } + } + + internal unsafe void AddPrefix(string uriPrefix) + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "AddPrefix", "uriPrefix:" + uriPrefix); + } + string registeredPrefix = null; + try + { + if (uriPrefix == null) + { + throw new ArgumentNullException("uriPrefix"); + } + (new WebPermission(NetworkAccess.Accept, uriPrefix)).Demand(); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::AddPrefix() uriPrefix:" + uriPrefix); + int i; + if (string.Compare(uriPrefix, 0, "http://", 0, 7, StringComparison.OrdinalIgnoreCase) == 0) + { + i = 7; + } + else if (string.Compare(uriPrefix, 0, "https://", 0, 8, StringComparison.OrdinalIgnoreCase) == 0) + { + i = 8; + } + else + { + throw new ArgumentException(SR.GetString(SR.net_listener_scheme), "uriPrefix"); + } + bool inSquareBrakets = false; + int j = i; + while (j < uriPrefix.Length && uriPrefix[j] != '/' && (uriPrefix[j] != ':' || inSquareBrakets)) + { + if (uriPrefix[j] == '[') + { + if (inSquareBrakets) + { + j = i; + break; + } + inSquareBrakets = true; + } + if (inSquareBrakets && uriPrefix[j] == ']') + { + inSquareBrakets = false; + } + j++; + } + if (i == j) + { + throw new ArgumentException(SR.GetString(SR.net_listener_host), "uriPrefix"); + } + if (uriPrefix[uriPrefix.Length - 1] != '/') + { + throw new ArgumentException(SR.GetString(SR.net_listener_slash), "uriPrefix"); + } + registeredPrefix = uriPrefix[j] == ':' ? String.Copy(uriPrefix) : uriPrefix.Substring(0, j) + (i == 7 ? ":80" : ":443") + uriPrefix.Substring(j); + fixed (char* pChar = registeredPrefix) + { + i = 0; + while (pChar[i] != ':') + { + pChar[i] = (char)CaseInsensitiveAscii.AsciiToLower[(byte)pChar[i]]; + i++; + } + } + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::AddPrefix() mapped uriPrefix:" + uriPrefix + " to registeredPrefix:" + registeredPrefix); + + _uriPrefixes[uriPrefix] = registeredPrefix; + _defaultServiceNames.Add(uriPrefix); + } + catch (Exception exception) + { + if (Logging.On) + { + Logging.Exception(Logging.HttpListener, this, "AddPrefix", exception); + } + throw; + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "AddPrefix", "prefix:" + registeredPrefix); + } + } + } + + internal PrefixCollection Prefixes + { + get + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "Prefixes_get", string.Empty); + } + if (_prefixes == null) + { + _prefixes = new PrefixCollection(this); + } + return _prefixes; + } + } + + internal bool RemovePrefix(string uriPrefix) + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "RemovePrefix", "uriPrefix:" + uriPrefix); + } + try + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::RemovePrefix() uriPrefix:" + uriPrefix); + if (uriPrefix == null) + { + throw new ArgumentNullException("uriPrefix"); + } + + if (!_uriPrefixes.Contains(uriPrefix)) + { + return false; + } + + _uriPrefixes.Remove(uriPrefix); + _defaultServiceNames.Remove(uriPrefix); + } + catch (Exception exception) + { + if (Logging.On) + { + Logging.Exception(Logging.HttpListener, this, "RemovePrefix", exception); + } + throw; + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "RemovePrefix", "uriPrefix:" + uriPrefix); + } + } + return true; + } + + internal void RemoveAll(bool clear) + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "RemoveAll", string.Empty); + } + try + { + // go through the uri list and unregister for each one of them + if (_uriPrefixes.Count > 0) + { + if (clear) + { + _uriPrefixes.Clear(); + _defaultServiceNames.Clear(); + } + } + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "RemoveAll", string.Empty); + } + } + } + + // old API, now private, and helper methods + private void Dispose(bool disposing) + { + GlobalLog.Assert(disposing, "Dispose(bool) does nothing if called from the finalizer."); + + if (!disposing) + { + return; + } + + try + { + _digestCache.Dispose(); + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "Dispose", string.Empty); + } + } + } + + /// + /// + /// + /// + /// + public Task Invoke(IDictionary env) + { + // Process the auth header, if any + if (!TryHandleAuthentication(env)) + { + // If failed and a 400/401/500 was sent. + return Task.FromResult(null); + } + + // If passing through, register for OnSendingHeaders. Add an auth header challenge on 401. + var registerOnSendingHeaders = env.Get, object>>(Constants.ServerOnSendingHeadersKey); + if (registerOnSendingHeaders == null) + { + // This module requires OnSendingHeaders support. + throw new PlatformNotSupportedException(); + } + registerOnSendingHeaders(Set401Challenges, env); + + // Invoke the next item in the app chain + return _nextApp(env); + } + + // Returns true if auth completed successfully (or anonymous), false if there was an auth header + // but processing it failed. + private bool TryHandleAuthentication(IDictionary env) + { + DisconnectAsyncResult disconnectResult; + object connectionId = env.Get(Constants.ServerConnectionIdKey, -1); + string authorizationHeader = null; + if (!TryGetIncomingAuthHeader(env, out authorizationHeader)) + { + if (UnsafeConnectionNtlmAuthentication) + { + disconnectResult = (DisconnectAsyncResult)DisconnectResults[connectionId]; + if (disconnectResult != null) + { + WindowsPrincipal principal = disconnectResult.AuthenticatedUser; + if (principal != null) + { + // This connection has already been authenticated; + SetIdentity(env, principal, null); + } + } + } + + return true; // Anonymous or UnsafeConnectionNtlmAuthentication + } + + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() authorizationHeader:" + ValidationHelper.ToString(authorizationHeader)); + + if (UnsafeConnectionNtlmAuthentication) + { + disconnectResult = (DisconnectAsyncResult)DisconnectResults[connectionId]; + // They sent an authorization header - destroy their previous credentials. + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() clearing principal cache"); + if (disconnectResult != null) + { + disconnectResult.AuthenticatedUser = null; + } + } + + try + { + AuthTypes headerScheme; + string inBlob; + if (!TryGetRecognizedAuthScheme(authorizationHeader, out headerScheme, out inBlob)) + { + return true; // Anonymous / pass through + } + Contract.Assert(headerScheme != AuthTypes.None); + Contract.Assert(inBlob != null); + + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() Performing Authentication headerScheme:" + ValidationHelper.ToString(headerScheme)); + switch (headerScheme) + { + case AuthTypes.Digest: + return TryAuthenticateWithDigest(env, inBlob); + + case AuthTypes.Negotiate: + case AuthTypes.Ntlm: + string package = headerScheme == AuthTypes.Ntlm ? NegotiationInfoClass.NTLM : NegotiationInfoClass.Negotiate; + return TryAuthenticateWithNegotiate(env, package, inBlob); + + default: + throw new NotImplementedException(headerScheme.ToString()); + } + } + catch (Exception) + { + SendError(env, HttpStatusCode.InternalServerError, null); + return false; + } + } + + // TODO: Support proxy auth + private bool TryGetIncomingAuthHeader(IDictionary env, out string authorizationHeader) + { + IDictionary headers = env.Get>(Constants.RequestHeadersKey); + authorizationHeader = headers.Get("Authorization"); + return !string.IsNullOrWhiteSpace(authorizationHeader); + } + + private bool TryGetRecognizedAuthScheme(string authorizationHeader, out AuthTypes headerScheme, out string inBlob) + { + headerScheme = AuthTypes.None; + + int index; + // Find the end of the scheme name. Trust that HTTP.SYS parsed out just our header ok. + for (index = 0; index < authorizationHeader.Length; index++) + { + if (authorizationHeader[index] == ' ' || authorizationHeader[index] == '\t' || + authorizationHeader[index] == '\r' || authorizationHeader[index] == '\n') + { + break; + } + } + + // Currently only allow one Authorization scheme/header per request. + if (index < authorizationHeader.Length) + { + if ((AuthenticationSchemes & AuthTypes.Negotiate) != AuthTypes.None && + string.Compare(authorizationHeader, 0, NegotiationInfoClass.Negotiate, 0, index, StringComparison.OrdinalIgnoreCase) == 0) + { + headerScheme = AuthTypes.Negotiate; + } + else if ((AuthenticationSchemes & AuthTypes.Ntlm) != AuthTypes.None && + string.Compare(authorizationHeader, 0, NegotiationInfoClass.NTLM, 0, index, StringComparison.OrdinalIgnoreCase) == 0) + { + headerScheme = AuthTypes.Ntlm; + } + else if ((AuthenticationSchemes & AuthTypes.Digest) != AuthTypes.None && + string.Compare(authorizationHeader, 0, NegotiationInfoClass.Digest, 0, index, StringComparison.OrdinalIgnoreCase) == 0) + { + headerScheme = AuthTypes.Digest; + } + } + + // Find the beginning of the blob. Trust that HTTP.SYS parsed out just our header ok. + for (index++; index < authorizationHeader.Length; index++) + { + if (authorizationHeader[index] != ' ' && authorizationHeader[index] != '\t' && + authorizationHeader[index] != '\r' && authorizationHeader[index] != '\n') + { + break; + } + } + inBlob = index < authorizationHeader.Length ? authorizationHeader.Substring(index) : string.Empty; + + return headerScheme != AuthTypes.None; + } + + // Returns true if successfully authenticated via Digest. Returns false if a 401 was sent. + private bool TryAuthenticateWithDigest(IDictionary env, string inBlob) + { + NTAuthentication context = null; + IPrincipal principal = null; + SecurityStatus statusCodeNew; + ChannelBinding binding; + string outBlob; + HttpStatusCode httpError = HttpStatusCode.OK; + string verb = env.Get(Constants.RequestMethodKey); + bool isSecureConnection = IsSecureConnection(env); + // GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() package:WDigest headerScheme:" + headerScheme); + + // WDigest had some weird behavior. This is what I have discovered: + // Local accounts don't work, only domain accounts. The domain (i.e. REDMOND) is implied. Not sure how it is chosen. + // If the domain is specified and the credentials are correct, it works. If they're not (domain, username or password): + // AcceptSecurityContext (GetOutgoingDigestBlob) returns success but with a bogus 4k challenge, and + // QuerySecurityContextToken (GetContextToken) fails with NoImpersonation. + // If the domain isn't specified, AcceptSecurityContext returns NoAuthenticatingAuthority for a bad username, + // and LogonDenied for a bad password. + + // Also interesting is that WDigest requires us to keep a reference to the previous context, but fails if we + // actually pass it in! (It't ok to pass it in for the first request, but not if nc > 1.) For Whidbey, + // we create a new context and associate it with the connection, just like NTLM, but instead of using it for + // the next request on the connection, we always create a new context and swap the old one out. As long + // as we keep the old one around until after we authenticate with the new one, it works. For this reason, + // we also keep these contexts around past the lifetime of the connection, so that KeepAlive=false works. + binding = GetChannelBinding(env, isSecureConnection, ExtendedProtectionPolicy); + + context = new NTAuthentication(true, NegotiationInfoClass.WDigest, null, + GetContextFlags(ExtendedProtectionPolicy, isSecureConnection), binding); + + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() verb:" + verb + " context.IsValidContext:" + context.IsValidContext.ToString()); + + outBlob = context.GetOutgoingDigestBlob(inBlob, verb, null, Realm, false, false, out statusCodeNew); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() GetOutgoingDigestBlob() returned IsCompleted:" + context.IsCompleted + " statusCodeNew:" + statusCodeNew + " outBlob:[" + outBlob + "]"); + + // WDigest bug: sometimes when AcceptSecurityContext returns success, it provides a bogus, empty 4k buffer. + // Ignore it. (Should find out what's going on here from WDigest people.) + if (statusCodeNew == SecurityStatus.OK) + { + outBlob = null; + } + + IList challenges = null; + if (outBlob != null) + { + string challenge = NegotiationInfoClass.Digest + " " + outBlob; + AddChallenge(ref challenges, challenge); + } + + if (context.IsValidContext) + { + SafeCloseHandle userContext = null; + try + { + if (!CheckSpn(context, isSecureConnection, ExtendedProtectionPolicy)) + { + httpError = HttpStatusCode.Unauthorized; + } + else + { + SetServiceName(env, context.ClientSpecifiedSpn); + + userContext = context.GetContextToken(out statusCodeNew); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() GetContextToken() returned:" + statusCodeNew.ToString()); + if (statusCodeNew != SecurityStatus.OK) + { + httpError = HttpStatusFromSecurityStatus(statusCodeNew); + } + else if (userContext == null) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() error: GetContextToken() returned:null statusCodeNew:" + statusCodeNew.ToString()); + httpError = HttpStatusCode.Unauthorized; + } + else + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() creating new WindowsIdentity() from userContext:" + userContext.DangerousGetHandle().ToString("x8")); + principal = new WindowsPrincipal(CreateWindowsIdentity(userContext.DangerousGetHandle(), "Digest"/*DigestClient.AuthType*/, WindowsAccountType.Normal, true)); + SetIdentity(env, principal, null); + _digestCache.SaveDigestContext(context); + } + } + } + finally + { + if (userContext != null) + { + userContext.Dispose(); + } + } + } + else + { + httpError = HttpStatusFromSecurityStatus(statusCodeNew); + } + + if (httpError != HttpStatusCode.OK) + { + SendError(env, httpError, challenges); + return false; + } + return true; + } + + // Negotiate or NTLM + private bool TryAuthenticateWithNegotiate(IDictionary env, string package, string inBlob) + { + object connectionId = env.Get(Constants.ServerConnectionIdKey, null); + if (connectionId == null) + { + // We need a connection ID from the server to correctly track in-progress auth. + throw new PlatformNotSupportedException(); + } + + NTAuthentication oldContext = null, context; + DisconnectAsyncResult disconnectResult = (DisconnectAsyncResult)DisconnectResults[connectionId]; + if (disconnectResult != null) + { + oldContext = disconnectResult.Session; + } + ChannelBinding binding; + bool isSecureConnection = IsSecureConnection(env); + byte[] bytes = null; + HttpStatusCode httpError = HttpStatusCode.OK; + bool error = false; + string outBlob = null; + + if (oldContext != null && oldContext.Package == package) + { + context = oldContext; + } + else + { + binding = GetChannelBinding(env, isSecureConnection, ExtendedProtectionPolicy); + + context = new NTAuthentication(true, package, null, + GetContextFlags(ExtendedProtectionPolicy, isSecureConnection), binding); + + // Clean up old context + if (oldContext != null) + { + oldContext.CloseContext(); + } + } + + try + { + bytes = Convert.FromBase64String(inBlob); + } + catch (FormatException) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() FromBase64String threw a FormatException."); + httpError = HttpStatusCode.BadRequest; + error = true; + } + + byte[] decodedOutgoingBlob = null; + SecurityStatus statusCodeNew; + if (!error) + { + decodedOutgoingBlob = context.GetOutgoingBlob(bytes, false, out statusCodeNew); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() GetOutgoingBlob() returned IsCompleted:" + context.IsCompleted + " statusCodeNew:" + statusCodeNew); + error = !context.IsValidContext; + if (error) + { + // Bug #474228: SSPI Workaround + // If a client sends up a blob on the initial request, Negotiate returns SEC_E_INVALID_HANDLE + // when it should return SEC_E_INVALID_TOKEN. + if (statusCodeNew == SecurityStatus.InvalidHandle && oldContext == null && bytes != null && bytes.Length > 0) + { + statusCodeNew = SecurityStatus.InvalidToken; + } + + httpError = HttpStatusFromSecurityStatus(statusCodeNew); + } + } + + if (decodedOutgoingBlob != null) + { + outBlob = Convert.ToBase64String(decodedOutgoingBlob); + } + + if (!error) + { + if (context.IsCompleted) + { + SafeCloseHandle userContext = null; + try + { + if (!CheckSpn(context, isSecureConnection, ExtendedProtectionPolicy)) + { + httpError = HttpStatusCode.Unauthorized; + } + else + { + SetServiceName(env, context.ClientSpecifiedSpn); + + userContext = context.GetContextToken(out statusCodeNew); + if (statusCodeNew != SecurityStatus.OK) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() GetContextToken() failed with statusCodeNew:" + statusCodeNew.ToString()); + httpError = HttpStatusFromSecurityStatus(statusCodeNew); + } + else + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() creating new WindowsIdentity() from userContext:" + userContext.DangerousGetHandle().ToString("x8")); + WindowsPrincipal windowsPrincipal = new WindowsPrincipal(CreateWindowsIdentity(userContext.DangerousGetHandle(), context.ProtocolName, WindowsAccountType.Normal, true)); + SetIdentity(env, windowsPrincipal, outBlob); + + // if appropriate, cache this credential on this connection + if (UnsafeConnectionNtlmAuthentication + && context.ProtocolName.Equals(NegotiationInfoClass.NTLM, StringComparison.OrdinalIgnoreCase)) + { + // We may need to call WaitForDisconnect. + if (disconnectResult == null) + { + RegisterForDisconnectNotification(env, out disconnectResult); + } + + if (disconnectResult != null) + { + lock (DisconnectResults.SyncRoot) + { + if (UnsafeConnectionNtlmAuthentication) + { + disconnectResult.AuthenticatedUser = windowsPrincipal; + } + } + } + } + } + } + } + finally + { + if (userContext != null) + { + userContext.Dispose(); + } + } + return true; + } + else + { + // auth incomplete + if (disconnectResult == null) + { + RegisterForDisconnectNotification(env, out disconnectResult); + + // Failed - send 500. + if (disconnectResult == null) + { + context.CloseContext(); + SendError(env, HttpStatusCode.InternalServerError, null); + return false; + } + } + + disconnectResult.Session = context; + + string challenge = package; + if (!String.IsNullOrEmpty(outBlob)) + { + challenge += " " + outBlob; + } + IList challenges = null; + AddChallenge(ref challenges, challenge); + SendError(env, HttpStatusCode.Unauthorized, challenges); + return false; + } + } + + SendError(env, httpError, null); + return false; + } + + private void SetIdentity(IDictionary env, IPrincipal principal, string mutualAuth) + { + env[Constants.ServerUserKey] = principal; + if (!string.IsNullOrWhiteSpace(mutualAuth)) + { + var responseHeaders = env.Get>(Constants.ResponseHeadersKey); + responseHeaders.Append(HttpKnownHeaderNames.WWWAuthenticate, mutualAuth); + } + } + + // For user info only + private void SetServiceName(IDictionary env, string serviceName) + { + if (!string.IsNullOrWhiteSpace(serviceName)) + { + env[Constants.SslSpnKey] = serviceName; + } + } + + [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.UnmanagedCode)] + [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.ControlPrincipal)] + internal static WindowsIdentity CreateWindowsIdentity(IntPtr userToken, string type, WindowsAccountType acctType, bool isAuthenticated) + { + return new WindowsIdentity(userToken, type, acctType, isAuthenticated); + } + + // On a 401 response, set any appropriate challenges + private void Set401Challenges(object state) + { + var env = (IDictionary)state; + var responseHeaders = env.Get>(Constants.ResponseHeadersKey); + + // We use the cached results from the delegates so that we don't have to call them again here. + NTAuthentication newContext; + IList challenges = BuildChallenge(env, AuthenticationSchemes, out newContext, ExtendedProtectionPolicy); + + // null == Anonymous + if (challenges != null) + { + // Digest challenge, keep it alive for 10s - 5min. + if (newContext != null) + { + _digestCache.SaveDigestContext(newContext); + } + + responseHeaders.Append(HttpKnownHeaderNames.WWWAuthenticate, challenges); + } + } + + private static bool IsSecureConnection(IDictionary env) + { + return "https".Equals(env.Get(Constants.RequestSchemeKey, "http"), StringComparison.OrdinalIgnoreCase); + } + + private static bool ScenarioChecksChannelBinding(bool isSecureConnection, ProtectionScenario scenario) + { + return (isSecureConnection && scenario == ProtectionScenario.TransportSelected); + } + + private ChannelBinding GetChannelBinding(IDictionary env, bool isSecureConnection, ExtendedProtectionPolicy policy) + { + if (policy.PolicyEnforcement == PolicyEnforcement.Never) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_cbt_disabled)); + } + return null; + } + + if (!isSecureConnection) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_cbt_http)); + } + return null; + } + + if (!ExtendedProtectionPolicy.OSSupportsExtendedProtection) + { + GlobalLog.Assert(policy.PolicyEnforcement != PolicyEnforcement.Always, "User managed to set PolicyEnforcement.Always when the OS does not support extended protection!"); + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_cbt_platform)); + } + return null; + } + + if (policy.ProtectionScenario == ProtectionScenario.TrustedProxy) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_cbt_trustedproxy)); + } + return null; + } + + ChannelBinding result = env.Get(Constants.SslChannelBindingKey); + if (result == null) + { + // A channel binding object is required. + throw new InvalidOperationException(); + } + + GlobalLog.Assert(result != null, "GetChannelBindingFromTls returned null even though OS supposedly supports Extended Protection"); + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_cbt)); + } + return result; + } + + private bool CheckSpn(NTAuthentication context, bool isSecureConnection, ExtendedProtectionPolicy policy) + { + // Kerberos does SPN check already in ASC + if (context.IsKerberos) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_kerberos)); + } + return true; + } + + // Don't check the SPN if Extended Protection is off or we already checked the CBT + if (policy.PolicyEnforcement == PolicyEnforcement.Never) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_disabled)); + } + return true; + } + + if (ScenarioChecksChannelBinding(isSecureConnection, policy.ProtectionScenario)) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_cbt)); + } + return true; + } + + if (!ExtendedProtectionPolicy.OSSupportsExtendedProtection) + { + GlobalLog.Assert(policy.PolicyEnforcement != PolicyEnforcement.Always, "User managed to set PolicyEnforcement.Always when the OS does not support extended protection!"); + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_platform)); + } + return true; + } + + string clientSpn = context.ClientSpecifiedSpn; + + // An empty SPN is only allowed in the WhenSupported case + if (String.IsNullOrEmpty(clientSpn)) + { + if (policy.PolicyEnforcement == PolicyEnforcement.WhenSupported) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, + SR.GetString(SR.net_log_listener_no_spn_whensupported)); + } + return true; + } + else + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, + SR.GetString(SR.net_log_listener_spn_failed_always)); + } + return false; + } + } + else if (String.Compare(clientSpn, "http/localhost", StringComparison.OrdinalIgnoreCase) == 0) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_loopback)); + } + + return true; + } + else + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_spn, clientSpn)); + } + + ServiceNameCollection serviceNames = GetServiceNames(policy); + + bool found = serviceNames.Contains(clientSpn); + + if (Logging.On) + { + if (found) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_spn_passed)); + } + else + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_spn_failed)); + + if (serviceNames.Count == 0) + { + Logging.PrintWarning(Logging.HttpListener, this, "CheckSpn", + SR.GetString(SR.net_log_listener_spn_failed_empty)); + } + else + { + Logging.PrintInfo(Logging.HttpListener, this, + SR.GetString(SR.net_log_listener_spn_failed_dump)); + + foreach (string serviceName in serviceNames) + { + Logging.PrintInfo(Logging.HttpListener, this, "\t" + serviceName); + } + } + } + } + + return found; + } + } + + private ServiceNameCollection GetServiceNames(ExtendedProtectionPolicy policy) + { + ServiceNameCollection serviceNames; + + if (policy.CustomServiceNames == null) + { + if (_defaultServiceNames.ServiceNames.Count == 0) + { + throw new InvalidOperationException(SR.GetString(SR.net_listener_no_spns)); + } + serviceNames = _defaultServiceNames.ServiceNames; + } + else + { + serviceNames = policy.CustomServiceNames; + } + return serviceNames; + } + + private ContextFlags GetContextFlags(ExtendedProtectionPolicy policy, bool isSecureConnection) + { + ContextFlags result = ContextFlags.Connection; + + if (policy.PolicyEnforcement != PolicyEnforcement.Never) + { + if (policy.PolicyEnforcement == PolicyEnforcement.WhenSupported) + { + result |= ContextFlags.AllowMissingBindings; + } + + if (policy.ProtectionScenario == ProtectionScenario.TrustedProxy) + { + result |= ContextFlags.ProxyBindings; + } + } + + return result; + } + + private static void AddChallenge(ref IList challenges, string challenge) + { + if (challenge != null) + { + challenge = challenge.Trim(); + if (challenge.Length > 0) + { + GlobalLog.Print("HttpListener:AddChallenge() challenge:" + challenge); + if (challenges == null) + { + challenges = new List(4); + } + challenges.Add(challenge); + } + } + } + + private IList BuildChallenge(IDictionary env, AuthTypes authenticationScheme, out NTAuthentication digestContext, + ExtendedProtectionPolicy policy) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::BuildChallenge() authenticationScheme:" + authenticationScheme.ToString()); + IList challenges = null; + digestContext = null; + + if ((authenticationScheme & AuthTypes.Negotiate) != 0) + { + AddChallenge(ref challenges, NegotiationInfoClass.Negotiate); + } + + if ((authenticationScheme & AuthTypes.Ntlm) != 0) + { + AddChallenge(ref challenges, NegotiationInfoClass.NTLM); + } + + if ((authenticationScheme & AuthTypes.Digest) != 0) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::BuildChallenge() package:WDigest"); + + NTAuthentication context = null; + try + { + bool isSecureConnection = IsSecureConnection(env); + string outBlob = null; + ChannelBinding binding = GetChannelBinding(env, isSecureConnection, policy); + + context = new NTAuthentication(true, NegotiationInfoClass.WDigest, null, + GetContextFlags(policy, isSecureConnection), binding); + + SecurityStatus statusCode; + outBlob = context.GetOutgoingDigestBlob(null, null, null, Realm, false, false, out statusCode); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::BuildChallenge() GetOutgoingDigestBlob() returned IsCompleted:" + context.IsCompleted + " statusCode:" + statusCode + " outBlob:[" + outBlob + "]"); + + if (context.IsValidContext) + { + digestContext = context; + _digestCache.SaveDigestContext(digestContext); + } + + AddChallenge(ref challenges, NegotiationInfoClass.Digest + (string.IsNullOrEmpty(outBlob) ? string.Empty : " " + outBlob)); + } + catch (InvalidOperationException) + { + // No CBT available, therefore no digest challenge can be issued. + } + finally + { + if (context != null && digestContext != context) + { + context.CloseContext(); + } + } + } + + return challenges; + } + + private void RegisterForDisconnectNotification(IDictionary env, out DisconnectAsyncResult disconnectResult) + { + object connectionId = env[Constants.ServerConnectionIdKey]; + CancellationToken connectionDisconnect = env.Get(Constants.ServerConnectionDisconnectKey); + if (!connectionDisconnect.CanBeCanceled || connectionDisconnect.IsCancellationRequested) + { + disconnectResult = null; + return; + } + try + { + disconnectResult = new DisconnectAsyncResult(this, connectionId, connectionDisconnect); + } + catch (ObjectDisposedException) + { + // Just disconnected + disconnectResult = null; + return; + } + } + + private void SendError(IDictionary env, HttpStatusCode httpStatusCode, IList challenges) + { + // Send an OWIN HTTP response with the given error status code. + env[Constants.ResponseStatusCodeKey] = (int)httpStatusCode; + + if (challenges != null) + { + var responseHeaders = env.Get>(Constants.ResponseHeadersKey); + responseHeaders.Append(HttpKnownHeaderNames.WWWAuthenticate, challenges); + } + } + + // This only works for context-destroying errors. + private HttpStatusCode HttpStatusFromSecurityStatus(SecurityStatus status) + { + if (IsCredentialFailure(status)) + { + return HttpStatusCode.Unauthorized; + } + if (IsClientFault(status)) + { + return HttpStatusCode.BadRequest; + } + return HttpStatusCode.InternalServerError; + } + + // This only works for context-destroying errors. + private static bool IsCredentialFailure(SecurityStatus error) + { + return error == SecurityStatus.LogonDenied || + error == SecurityStatus.UnknownCredentials || + error == SecurityStatus.NoImpersonation || + error == SecurityStatus.NoAuthenticatingAuthority || + error == SecurityStatus.UntrustedRoot || + error == SecurityStatus.CertExpired || + error == SecurityStatus.SmartcardLogonRequired || + error == SecurityStatus.BadBinding; + } + + // This only works for context-destroying errors. + private static bool IsClientFault(SecurityStatus error) + { + return error == SecurityStatus.InvalidToken || + error == SecurityStatus.CannotPack || + error == SecurityStatus.QopNotSupported || + error == SecurityStatus.NoCredentials || + error == SecurityStatus.MessageAltered || + error == SecurityStatus.OutOfSequence || + error == SecurityStatus.IncompleteMessage || + error == SecurityStatus.IncompleteCredentials || + error == SecurityStatus.WrongPrincipal || + error == SecurityStatus.TimeSkew || + error == SecurityStatus.IllegalMessage || + error == SecurityStatus.CertUnknown || + error == SecurityStatus.AlgorithmMismatch || + error == SecurityStatus.SecurityQosFailed || + error == SecurityStatus.UnsupportedPreauth; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/packages.config b/src/Microsoft.AspNet.Security.Windows/packages.config new file mode 100644 index 0000000000..7432196421 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/packages.config @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/src/Microsoft.AspNet.Security.Windows/project.json b/src/Microsoft.AspNet.Security.Windows/project.json new file mode 100644 index 0000000000..3fe6a340df --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/project.json @@ -0,0 +1,14 @@ +{ + "version": "0.1-alpha-*", + "dependencies": { + }, + "compilationOptions" : { "allowUnsafe": true }, + "configurations": + { + "net45" : { + "dependencies": { + "Owin": "1.0" + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/AsyncAcceptContext.cs b/src/Microsoft.AspNet.Server.WebListener/AsyncAcceptContext.cs new file mode 100644 index 0000000000..a929f8a854 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/AsyncAcceptContext.cs @@ -0,0 +1,224 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal unsafe class AsyncAcceptContext : IAsyncResult, IDisposable + { + internal static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(IOWaitCallback); + + private TaskCompletionSource _tcs; + private OwinWebListener _server; + private NativeRequestContext _nativeRequestContext; + + internal AsyncAcceptContext(OwinWebListener server) + { + _server = server; + _tcs = new TaskCompletionSource(); + _nativeRequestContext = new NativeRequestContext(this); + } + + internal Task Task + { + get + { + return _tcs.Task; + } + } + + private TaskCompletionSource Tcs + { + get + { + return _tcs; + } + } + + private OwinWebListener Server + { + get + { + return _server; + } + } + + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirecting to callback")] + [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Disposed by callback")] + private static void IOCompleted(AsyncAcceptContext asyncResult, uint errorCode, uint numBytes) + { + bool complete = false; + try + { + if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + asyncResult.Tcs.TrySetException(new WebListenerException((int)errorCode)); + complete = true; + } + else + { + OwinWebListener server = asyncResult.Server; + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + // at this point we have received an unmanaged HTTP_REQUEST and memoryBlob + // points to it we need to hook up our authentication handling code here. + bool stoleBlob = false; + try + { + if (server.ValidateRequest(asyncResult._nativeRequestContext)) + { + stoleBlob = true; + RequestContext requestContext = new RequestContext(server, asyncResult._nativeRequestContext); + asyncResult.Tcs.TrySetResult(requestContext); + complete = true; + } + } + finally + { + if (stoleBlob) + { + // The request has been handed to the user, which means this code can't reuse the blob. Reset it here. + asyncResult._nativeRequestContext = complete ? null : new NativeRequestContext(asyncResult); + } + else + { + asyncResult._nativeRequestContext.Reset(0, 0); + } + } + } + else + { + asyncResult._nativeRequestContext.Reset(asyncResult._nativeRequestContext.RequestBlob->RequestId, numBytes); + } + + // We need to issue a new request, either because auth failed, or because our buffer was too small the first time. + if (!complete) + { + uint statusCode = asyncResult.QueueBeginGetContext(); + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + // someother bad error, possible(?) return values are: + // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED + asyncResult.Tcs.TrySetException(new WebListenerException((int)statusCode)); + complete = true; + } + } + if (!complete) + { + return; + } + } + + if (complete) + { + asyncResult.Dispose(); + } + } + catch (Exception exception) + { + // Logged by caller + asyncResult.Tcs.TrySetException(exception); + asyncResult.Dispose(); + } + } + + private static unsafe void IOWaitCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + // take the ListenerAsyncResult object from the state + Overlapped callbackOverlapped = Overlapped.Unpack(nativeOverlapped); + AsyncAcceptContext asyncResult = (AsyncAcceptContext)callbackOverlapped.AsyncResult; + + IOCompleted(asyncResult, errorCode, numBytes); + } + + internal uint QueueBeginGetContext() + { + uint statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + bool retry; + do + { + retry = false; + uint bytesTransferred = 0; + statusCode = UnsafeNclNativeMethods.HttpApi.HttpReceiveHttpRequest( + Server.RequestQueueHandle, + _nativeRequestContext.RequestBlob->RequestId, + (uint)UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + _nativeRequestContext.RequestBlob, + _nativeRequestContext.Size, + &bytesTransferred, + _nativeRequestContext.NativeOverlapped); + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_INVALID_PARAMETER && _nativeRequestContext.RequestBlob->RequestId != 0) + { + // we might get this if somebody stole our RequestId, + // set RequestId to 0 and start all over again with the buffer we just allocated + // BUGBUG: how can someone steal our request ID? seems really bad and in need of fix. + _nativeRequestContext.RequestBlob->RequestId = 0; + retry = true; + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + // the buffer was not big enough to fit the headers, we need + // to read the RequestId returned, allocate a new buffer of the required size + _nativeRequestContext.Reset(_nativeRequestContext.RequestBlob->RequestId, bytesTransferred); + retry = true; + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS + && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + IOCompleted(this, statusCode, bytesTransferred); + } + } + while (retry); + return statusCode; + } + + public object AsyncState + { + get { return _tcs.Task.AsyncState; } + } + + public WaitHandle AsyncWaitHandle + { + get { return ((IAsyncResult)_tcs.Task).AsyncWaitHandle; } + } + + public bool CompletedSynchronously + { + get { return ((IAsyncResult)_tcs.Task).CompletedSynchronously; } + } + + public bool IsCompleted + { + get { return _tcs.Task.IsCompleted; } + } + + public void Dispose() + { + Dispose(true); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + if (_nativeRequestContext != null) + { + _nativeRequestContext.ReleasePins(); + _nativeRequestContext.Dispose(); + } + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/AuthenticationManager.cs b/src/Microsoft.AspNet.Server.WebListener/AuthenticationManager.cs new file mode 100644 index 0000000000..9fa737049a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/AuthenticationManager.cs @@ -0,0 +1,129 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ + // See the native HTTP_SERVER_AUTHENTICATION_INFO structure documentation for additional information. + // http://msdn.microsoft.com/en-us/library/windows/desktop/aa364638(v=vs.85).aspx + + /// + /// Exposes the Http.Sys authentication configurations. + /// + public sealed class AuthenticationManager + { +#if NET45 + private static readonly int AuthInfoSize = + Marshal.SizeOf(typeof(UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_AUTHENTICATION_INFO)); +#else + private static readonly int AuthInfoSize = + Marshal.SizeOf(); +#endif + + private OwinWebListener _server; + AuthenticationType _authTypes; + + internal AuthenticationManager(OwinWebListener context) + { + _server = context; + _authTypes = AuthenticationType.None; + } + + #region Properties + + public AuthenticationType AuthenticationTypes + { + get + { + return _authTypes; + } + set + { + _authTypes = value; + SetServerSecurity(); + } + } + + #endregion Properties + + private unsafe void SetServerSecurity() + { + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_AUTHENTICATION_INFO authInfo = + new UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_AUTHENTICATION_INFO(); + + authInfo.Flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_PROPERTY_FLAG_PRESENT; + authInfo.AuthSchemes = (UnsafeNclNativeMethods.HttpApi.HTTP_AUTH_TYPES)_authTypes; + + // TODO: + // NTLM auth sharing (on by default?) DisableNTLMCredentialCaching + // Kerberos auth sharing (off by default?) HTTP_AUTH_EX_FLAG_ENABLE_KERBEROS_CREDENTIAL_CACHING + // Mutual Auth - ReceiveMutualAuth + // Digest domain and realm - HTTP_SERVER_AUTHENTICATION_DIGEST_PARAMS + // Basic realm - HTTP_SERVER_AUTHENTICATION_BASIC_PARAMS + + IntPtr infoptr = new IntPtr(&authInfo); + + _server.SetUrlGroupProperty( + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerAuthenticationProperty, + infoptr, (uint)AuthInfoSize); + } + + internal void SetAuthenticationChallenge(Response response) + { + if (_authTypes == AuthenticationType.None) + { + return; + } + + IList challenges = new List(); + + // Order by strength. + if ((_authTypes & AuthenticationType.Kerberos) == AuthenticationType.Kerberos) + { + challenges.Add("Kerberos"); + } + if ((_authTypes & AuthenticationType.Negotiate) == AuthenticationType.Negotiate) + { + challenges.Add("Negotiate"); + } + if ((_authTypes & AuthenticationType.Ntlm) == AuthenticationType.Ntlm) + { + challenges.Add("NTLM"); + } + if ((_authTypes & AuthenticationType.Digest) == AuthenticationType.Digest) + { + // TODO: + throw new NotImplementedException("Digest challenge generation has not been implemented."); + // challenges.Add("Digest"); + } + if ((_authTypes & AuthenticationType.Basic) == AuthenticationType.Basic) + { + // TODO: Realm + challenges.Add("Basic"); + } + + // Append to the existing header, if any. Some clients (IE, Chrome) require each challenges to be sent on their own line/header. + string[] oldValues; + string[] newValues; + if (response.Headers.TryGetValue(HttpKnownHeaderNames.WWWAuthenticate, out oldValues)) + { + newValues = new string[oldValues.Length + challenges.Count]; + Array.Copy(oldValues, newValues, oldValues.Length); + challenges.CopyTo(newValues, oldValues.Length); + } + else + { + newValues = new string[challenges.Count]; + challenges.CopyTo(newValues, 0); + } + response.Headers[HttpKnownHeaderNames.WWWAuthenticate] = newValues; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/AuthenticationTypes.cs b/src/Microsoft.AspNet.Server.WebListener/AuthenticationTypes.cs new file mode 100644 index 0000000000..2e47212706 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/AuthenticationTypes.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + [Flags] + public enum AuthenticationType + { + None = 0x0, + Basic = 0x1, + Digest = 0x2, + Ntlm = 0x4, + Negotiate = 0x8, + Kerberos = 0x10, + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Constants.cs b/src/Microsoft.AspNet.Server.WebListener/Constants.cs new file mode 100644 index 0000000000..53e3325471 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Constants.cs @@ -0,0 +1,63 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class Constants + { + internal const string VersionKey = "owin.Version"; + internal const string OwinVersion = "1.0"; + internal const string CallCancelledKey = "owin.CallCancelled"; + + internal const string ServerCapabilitiesKey = "server.Capabilities"; + + internal const string RequestBodyKey = "owin.RequestBody"; + internal const string RequestHeadersKey = "owin.RequestHeaders"; + internal const string RequestSchemeKey = "owin.RequestScheme"; + internal const string RequestMethodKey = "owin.RequestMethod"; + internal const string RequestPathBaseKey = "owin.RequestPathBase"; + internal const string RequestPathKey = "owin.RequestPath"; + internal const string RequestQueryStringKey = "owin.RequestQueryString"; + internal const string HttpRequestProtocolKey = "owin.RequestProtocol"; + + internal const string HttpResponseProtocolKey = "owin.ResponseProtocol"; + internal const string ResponseStatusCodeKey = "owin.ResponseStatusCode"; + internal const string ResponseReasonPhraseKey = "owin.ResponseReasonPhrase"; + internal const string ResponseHeadersKey = "owin.ResponseHeaders"; + internal const string ResponseBodyKey = "owin.ResponseBody"; + + internal const string ClientCertifiateKey = "ssl.ClientCertificate"; + + internal const string RemoteIpAddressKey = "server.RemoteIpAddress"; + internal const string RemotePortKey = "server.RemotePort"; + internal const string LocalIpAddressKey = "server.LocalIpAddress"; + internal const string LocalPortKey = "server.LocalPort"; + internal const string IsLocalKey = "server.IsLocal"; + internal const string ServerOnSendingHeadersKey = "server.OnSendingHeaders"; + internal const string ServerLoggerFactoryKey = "server.LoggerFactory"; + + internal const string OpaqueVersionKey = "opaque.Version"; + internal const string OpaqueVersion = "1.0"; + internal const string OpaqueFuncKey = "opaque.Upgrade"; + internal const string OpaqueStreamKey = "opaque.Stream"; + internal const string OpaqueCallCancelledKey = "opaque.CallCancelled"; + + internal const string SendFileVersionKey = "sendfile.Version"; + internal const string SendFileVersion = "1.0"; + internal const string SendFileSupportKey = "sendfile.Support"; + internal const string SendFileConcurrencyKey = "sendfile.Concurrency"; + internal const string Overlapped = "Overlapped"; + + internal const string HttpScheme = "http"; + internal const string HttpsScheme = "https"; + internal const string SchemeDelimiter = "://"; + + internal static Version V1_0 = new Version(1, 0); + internal static Version V1_1 = new Version(1, 1); + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/CustomDictionary.xml b/src/Microsoft.AspNet.Server.WebListener/CustomDictionary.xml new file mode 100644 index 0000000000..78a76142f7 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/CustomDictionary.xml @@ -0,0 +1,10 @@ + + + + + Owin + + + + + diff --git a/src/Microsoft.AspNet.Server.WebListener/DictionaryExtensions.cs b/src/Microsoft.AspNet.Server.WebListener/DictionaryExtensions.cs new file mode 100644 index 0000000000..a1a6e3577a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/DictionaryExtensions.cs @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Linq; +using System.Text; + +namespace System.Collections.Generic +{ + internal static class DictionaryExtensions + { + internal static void Append(this IDictionary dictionary, string key, string value) + { + string[] orriginalValues; + if (dictionary.TryGetValue(key, out orriginalValues)) + { + string[] newValues = new string[orriginalValues.Length + 1]; + orriginalValues.CopyTo(newValues, 0); + newValues[newValues.Length - 1] = value; + dictionary[key] = newValues; + } + else + { + dictionary[key] = new string[] { value }; + } + } + + internal static string Get(this IDictionary dictionary, string key) + { + string[] values; + if (dictionary.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + internal static T Get(this IDictionary dictionary, string key, T fallback = default(T)) + { + object values; + if (dictionary.TryGetValue(key, out values)) + { + return (T)values; + } + return fallback; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/GlobalSuppressions.cs b/src/Microsoft.AspNet.Server.WebListener/GlobalSuppressions.cs new file mode 100644 index 0000000000..2b0ddbdbf4 Binary files /dev/null and b/src/Microsoft.AspNet.Server.WebListener/GlobalSuppressions.cs differ diff --git a/src/Microsoft.AspNet.Server.WebListener/Helpers.cs b/src/Microsoft.AspNet.Server.WebListener/Helpers.cs new file mode 100644 index 0000000000..ea0ffd4f04 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Helpers.cs @@ -0,0 +1,29 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class Helpers + { + internal static Task CompletedTask() + { + return Task.FromResult(null); + } + + internal static ConfiguredTaskAwaitable SupressContext(this Task task) + { + return task.ConfigureAwait(continueOnCapturedContext: false); + } + + internal static ConfiguredTaskAwaitable SupressContext(this Task task) + { + return task.ConfigureAwait(continueOnCapturedContext: false); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/LogHelper.cs b/src/Microsoft.AspNet.Server.WebListener/LogHelper.cs new file mode 100644 index 0000000000..a6513df571 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/LogHelper.cs @@ -0,0 +1,82 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Globalization; + +namespace Microsoft.AspNet.Server.WebListener +{ + using LoggerFactoryFunc = Func, bool>>; + using LoggerFunc = Func, bool>; + + internal static class LogHelper + { + private static readonly Func LogState = + (state, error) => Convert.ToString(state, CultureInfo.CurrentCulture); + + private static readonly Func LogStateAndError = + (state, error) => string.Format(CultureInfo.CurrentCulture, "{0}\r\n{1}", state, error); + + internal static LoggerFunc CreateLogger(LoggerFactoryFunc factory, Type type) + { + if (factory == null) + { + return null; + } + + return factory(type.FullName); + } + + internal static void LogInfo(LoggerFunc logger, string data) + { + if (logger == null) + { + Debug.WriteLine(data); + } + else + { + logger(TraceEventType.Information, 0, data, null, LogState); + } + } + + internal static void LogVerbose(LoggerFunc logger, string data) + { + if (logger == null) + { + Debug.WriteLine(data); + } + else + { + logger(TraceEventType.Verbose, 0, data, null, LogState); + } + } + + internal static void LogException(LoggerFunc logger, string location, Exception exception) + { + if (logger == null) + { + Debug.WriteLine(exception); + } + else + { + logger(TraceEventType.Error, 0, location, exception, LogStateAndError); + } + } + + internal static void LogError(LoggerFunc logger, string location, string message) + { + if (logger == null) + { + Debug.WriteLine(message); + } + else + { + logger(TraceEventType.Error, 0, location + ": " + message, null, LogState); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/AddressFamily.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/AddressFamily.cs new file mode 100644 index 0000000000..5f78516900 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/AddressFamily.cs @@ -0,0 +1,172 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + /// + /// + /// Specifies the address families that an instance of the + /// class can use. + /// + /// + internal enum AddressFamily + { + /// + /// [To be supplied.] + /// + Unknown = -1, // Unknown + + /// + /// [To be supplied.] + /// + Unspecified = 0, // unspecified + + /// + /// [To be supplied.] + /// + Unix = 1, // local to host (pipes, portals) + + /// + /// [To be supplied.] + /// + InterNetwork = 2, // internetwork: UDP, TCP, etc. + + /// + /// [To be supplied.] + /// + ImpLink = 3, // arpanet imp addresses + + /// + /// [To be supplied.] + /// + Pup = 4, // pup protocols: e.g. BSP + + /// + /// [To be supplied.] + /// + Chaos = 5, // mit CHAOS protocols + + /// + /// [To be supplied.] + /// + NS = 6, // XEROX NS protocols + + /// + /// [To be supplied.] + /// + Ipx = NS, // IPX and SPX + + /// + /// [To be supplied.] + /// + Iso = 7, // ISO protocols + + /// + /// [To be supplied.] + /// + Osi = Iso, // OSI is ISO + + /// + /// [To be supplied.] + /// + Ecma = 8, // european computer manufacturers + + /// + /// [To be supplied.] + /// + DataKit = 9, // datakit protocols + + /// + /// [To be supplied.] + /// + Ccitt = 10, // CCITT protocols, X.25 etc + + /// + /// [To be supplied.] + /// + Sna = 11, // IBM SNA + + /// + /// [To be supplied.] + /// + DecNet = 12, // DECnet + + /// + /// [To be supplied.] + /// + DataLink = 13, // Direct data link interface + + /// + /// [To be supplied.] + /// + Lat = 14, // LAT + + /// + /// [To be supplied.] + /// + HyperChannel = 15, // NSC Hyperchannel + + /// + /// [To be supplied.] + /// + AppleTalk = 16, // AppleTalk + + /// + /// [To be supplied.] + /// + NetBios = 17, // NetBios-style addresses + + /// + /// [To be supplied.] + /// + VoiceView = 18, // VoiceView + + /// + /// [To be supplied.] + /// + FireFox = 19, // FireFox + + /// + /// [To be supplied.] + /// + Banyan = 21, // Banyan + + /// + /// [To be supplied.] + /// + Atm = 22, // Native ATM Services + + /// + /// [To be supplied.] + /// + InterNetworkV6 = 23, // Internetwork Version 6 + + /// + /// [To be supplied.] + /// + Cluster = 24, // Microsoft Wolfpack + + /// + /// [To be supplied.] + /// + Ieee12844 = 25, // IEEE 1284.4 WG AF + + /// + /// [To be supplied.] + /// + Irda = 26, // IrDA + + /// + /// [To be supplied.] + /// + NetworkDesigners = 28, // Network Designers OSI & gateway enabled protocols + + /// + /// [To be supplied.] + /// + Max = 29, // Max + }; // enum AddressFamily +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ComNetOS.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ComNetOS.cs new file mode 100644 index 0000000000..80e2719065 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ComNetOS.cs @@ -0,0 +1,26 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class ComNetOS + { + // Minimum support for Windows 7 is assumed. + internal static readonly bool IsWin8orLater; + + static ComNetOS() + { +#if NET45 + var win8Version = new Version(6, 2); + IsWin8orLater = (Environment.OSVersion.Version >= win8Version); +#else + IsWin8orLater = true; +#endif + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ContextAttribute.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ContextAttribute.cs new file mode 100644 index 0000000000..a3440bc7a3 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ContextAttribute.cs @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum ContextAttribute + { + // look into and + Sizes = 0x00, + Names = 0x01, + Lifespan = 0x02, + DceInfo = 0x03, + StreamSizes = 0x04, + // KeyInfo = 0x05, must not be used, see ConnectionInfo instead + Authority = 0x06, + // SECPKG_ATTR_PROTO_INFO = 7, + // SECPKG_ATTR_PASSWORD_EXPIRY = 8, + // SECPKG_ATTR_SESSION_KEY = 9, + PackageInfo = 0x0A, + // SECPKG_ATTR_USER_FLAGS = 11, + NegotiationInfo = 0x0C, + // SECPKG_ATTR_NATIVE_NAMES = 13, + // SECPKG_ATTR_FLAGS = 14, + // SECPKG_ATTR_USE_VALIDATED = 15, + // SECPKG_ATTR_CREDENTIAL_NAME = 16, + // SECPKG_ATTR_TARGET_INFORMATION = 17, + // SECPKG_ATTR_ACCESS_TOKEN = 18, + // SECPKG_ATTR_TARGET = 19, + // SECPKG_ATTR_AUTHENTICATION_ID = 20, + UniqueBindings = 0x19, + EndpointBindings = 0x1A, + ClientSpecifiedSpn = 0x1B, // SECPKG_ATTR_CLIENT_SPECIFIED_TARGET = 27 + RemoteCertificate = 0x53, + LocalCertificate = 0x54, + RootStore = 0x55, + IssuerListInfoEx = 0x59, + ConnectionInfo = 0x5A, + // SECPKG_ATTR_EAP_KEY_BLOCK 0x5b // returns SecPkgContext_EapKeyBlock + // SECPKG_ATTR_MAPPED_CRED_ATTR 0x5c // returns SecPkgContext_MappedCredAttr + // SECPKG_ATTR_SESSION_INFO 0x5d // returns SecPkgContext_SessionInfo + // SECPKG_ATTR_APP_DATA 0x5e // sets/returns SecPkgContext_SessionAppData + // SECPKG_ATTR_REMOTE_CERTIFICATES 0x5F // returns SecPkgContext_Certificates + // SECPKG_ATTR_CLIENT_CERT_POLICY 0x60 // sets SecPkgCred_ClientCertCtlPolicy + // SECPKG_ATTR_CC_POLICY_RESULT 0x61 // returns SecPkgContext_ClientCertPolicyResult + // SECPKG_ATTR_USE_NCRYPT 0x62 // Sets the CRED_FLAG_USE_NCRYPT_PROVIDER FLAG on cred group + // SECPKG_ATTR_LOCAL_CERT_INFO 0x63 // returns SecPkgContext_CertInfo + // SECPKG_ATTR_CIPHER_INFO 0x64 // returns new CNG SecPkgContext_CipherInfo + // SECPKG_ATTR_EAP_PRF_INFO 0x65 // sets SecPkgContext_EapPrfInfo + // SECPKG_ATTR_SUPPORTED_SIGNATURES 0x66 // returns SecPkgContext_SupportedSignatures + // SECPKG_ATTR_REMOTE_CERT_CHAIN 0x67 // returns PCCERT_CONTEXT + UiInfo = 0x68, // sets SEcPkgContext_UiInfo + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpRequestQueueV2Handle.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpRequestQueueV2Handle.cs new file mode 100644 index 0000000000..ad9ae1d391 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpRequestQueueV2Handle.cs @@ -0,0 +1,25 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + // This class is a wrapper for Http.sys V2 request queue handle. + internal sealed class HttpRequestQueueV2Handle : SafeHandleZeroOrMinusOneIsInvalid + { + private HttpRequestQueueV2Handle() + : base(true) + { + } + + protected override bool ReleaseHandle() + { + return (UnsafeNclNativeMethods.SafeNetHandles.HttpCloseRequestQueue(handle) == + UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpServerSessionHandle.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpServerSessionHandle.cs new file mode 100644 index 0000000000..65c0506f5c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpServerSessionHandle.cs @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Threading; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed class HttpServerSessionHandle : CriticalHandleZeroOrMinusOneIsInvalid + { + private int disposed; + private ulong serverSessionId; + + internal HttpServerSessionHandle(ulong id) + : base() + { + serverSessionId = id; + + // This class uses no real handle so we need to set a dummy handle. Otherwise, IsInvalid always remains + // true. + + SetHandle(new IntPtr(1)); + } + + internal ulong DangerousGetServerSessionId() + { + return serverSessionId; + } + + protected override bool ReleaseHandle() + { + if (!IsInvalid) + { + if (Interlocked.Increment(ref disposed) == 1) + { + // Closing server session also closes all open url groups under that server session. + return (UnsafeNclNativeMethods.HttpApi.HttpCloseServerSession(serverSessionId) == + UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS); + } + } + return true; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysRequestHeader.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysRequestHeader.cs new file mode 100644 index 0000000000..1202ab2c80 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysRequestHeader.cs @@ -0,0 +1,54 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum HttpSysRequestHeader + { + CacheControl = 0, // general-header [section 4.5] + Connection = 1, // general-header [section 4.5] + Date = 2, // general-header [section 4.5] + KeepAlive = 3, // general-header [not in rfc] + Pragma = 4, // general-header [section 4.5] + Trailer = 5, // general-header [section 4.5] + TransferEncoding = 6, // general-header [section 4.5] + Upgrade = 7, // general-header [section 4.5] + Via = 8, // general-header [section 4.5] + Warning = 9, // general-header [section 4.5] + Allow = 10, // entity-header [section 7.1] + ContentLength = 11, // entity-header [section 7.1] + ContentType = 12, // entity-header [section 7.1] + ContentEncoding = 13, // entity-header [section 7.1] + ContentLanguage = 14, // entity-header [section 7.1] + ContentLocation = 15, // entity-header [section 7.1] + ContentMd5 = 16, // entity-header [section 7.1] + ContentRange = 17, // entity-header [section 7.1] + Expires = 18, // entity-header [section 7.1] + LastModified = 19, // entity-header [section 7.1] + + Accept = 20, // request-header [section 5.3] + AcceptCharset = 21, // request-header [section 5.3] + AcceptEncoding = 22, // request-header [section 5.3] + AcceptLanguage = 23, // request-header [section 5.3] + Authorization = 24, // request-header [section 5.3] + Cookie = 25, // request-header [not in rfc] + Expect = 26, // request-header [section 5.3] + From = 27, // request-header [section 5.3] + Host = 28, // request-header [section 5.3] + IfMatch = 29, // request-header [section 5.3] + IfModifiedSince = 30, // request-header [section 5.3] + IfNoneMatch = 31, // request-header [section 5.3] + IfRange = 32, // request-header [section 5.3] + IfUnmodifiedSince = 33, // request-header [section 5.3] + MaxForwards = 34, // request-header [section 5.3] + ProxyAuthorization = 35, // request-header [section 5.3] + Referer = 36, // request-header [section 5.3] + Range = 37, // request-header [section 5.3] + Te = 38, // request-header [section 5.3] + Translate = 39, // request-header [webDAV, not in rfc 2518] + UserAgent = 40, // request-header [section 5.3] + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysResponseHeader.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysResponseHeader.cs new file mode 100644 index 0000000000..0c7707e943 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysResponseHeader.cs @@ -0,0 +1,43 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum HttpSysResponseHeader + { + CacheControl = 0, // general-header [section 4.5] + Connection = 1, // general-header [section 4.5] + Date = 2, // general-header [section 4.5] + KeepAlive = 3, // general-header [not in rfc] + Pragma = 4, // general-header [section 4.5] + Trailer = 5, // general-header [section 4.5] + TransferEncoding = 6, // general-header [section 4.5] + Upgrade = 7, // general-header [section 4.5] + Via = 8, // general-header [section 4.5] + Warning = 9, // general-header [section 4.5] + Allow = 10, // entity-header [section 7.1] + ContentLength = 11, // entity-header [section 7.1] + ContentType = 12, // entity-header [section 7.1] + ContentEncoding = 13, // entity-header [section 7.1] + ContentLanguage = 14, // entity-header [section 7.1] + ContentLocation = 15, // entity-header [section 7.1] + ContentMd5 = 16, // entity-header [section 7.1] + ContentRange = 17, // entity-header [section 7.1] + Expires = 18, // entity-header [section 7.1] + LastModified = 19, // entity-header [section 7.1] + + AcceptRanges = 20, // response-header [section 6.2] + Age = 21, // response-header [section 6.2] + ETag = 22, // response-header [section 6.2] + Location = 23, // response-header [section 6.2] + ProxyAuthenticate = 24, // response-header [section 6.2] + RetryAfter = 25, // response-header [section 6.2] + Server = 26, // response-header [section 6.2] + SetCookie = 27, // response-header [not in rfc] + Vary = 28, // response-header [section 6.2] + WwwAuthenticate = 29, // response-header [section 6.2] + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysSettings.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysSettings.cs new file mode 100644 index 0000000000..98067e6fd6 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysSettings.cs @@ -0,0 +1,125 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Security; +#if NET45 +using Microsoft.Win32; +#endif + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class HttpSysSettings + { +#if NET45 + private const string HttpSysParametersKey = @"System\CurrentControlSet\Services\HTTP\Parameters"; +#endif + private const bool EnableNonUtf8Default = true; + private const bool FavorUtf8Default = true; + private const string EnableNonUtf8Name = "EnableNonUtf8"; + private const string FavorUtf8Name = "FavorUtf8"; + + private static volatile bool enableNonUtf8 = EnableNonUtf8Default; + private static volatile bool favorUtf8 = FavorUtf8Default; + + static HttpSysSettings() + { + ReadHttpSysRegistrySettings(); + } + + internal static bool EnableNonUtf8 + { + get { return enableNonUtf8; } + } + + internal static bool FavorUtf8 + { + get { return favorUtf8; } + } + + private static void ReadHttpSysRegistrySettings() +#if !NET45 + { + } +#else + { + try + { + RegistryKey httpSysParameters = Registry.LocalMachine.OpenSubKey(HttpSysParametersKey); + + if (httpSysParameters == null) + { + LogWarning("ReadHttpSysRegistrySettings", "The Http.Sys registry key is null.", + HttpSysParametersKey); + } + else + { + using (httpSysParameters) + { + enableNonUtf8 = ReadRegistryValue(httpSysParameters, EnableNonUtf8Name, EnableNonUtf8Default); + favorUtf8 = ReadRegistryValue(httpSysParameters, FavorUtf8Name, FavorUtf8Default); + } + } + } + catch (SecurityException e) + { + LogRegistryException("ReadHttpSysRegistrySettings", e); + } + catch (ObjectDisposedException e) + { + LogRegistryException("ReadHttpSysRegistrySettings", e); + } + } + + private static bool ReadRegistryValue(RegistryKey key, string valueName, bool defaultValue) + { + Debug.Assert(key != null, "'key' must not be null"); + + try + { + if (key.GetValue(valueName) != null && key.GetValueKind(valueName) == RegistryValueKind.DWord) + { + // At this point we know the Registry value exists and it must be valid (any DWORD value + // can be converted to a bool). + return Convert.ToBoolean(key.GetValue(valueName), CultureInfo.InvariantCulture); + } + } + catch (UnauthorizedAccessException e) + { + LogRegistryException("ReadRegistryValue", e); + } + catch (IOException e) + { + LogRegistryException("ReadRegistryValue", e); + } + catch (SecurityException e) + { + LogRegistryException("ReadRegistryValue", e); + } + catch (ObjectDisposedException e) + { + LogRegistryException("ReadRegistryValue", e); + } + + return defaultValue; + } + + private static void LogRegistryException(string methodName, Exception e) + { + LogWarning(methodName, "Unable to access the Http.Sys registry value.", HttpSysParametersKey, e); + } + + private static void LogWarning(string methodName, string message, params object[] args) + { + // TODO: log + // Logging.PrintWarning(Logging.HttpListener, typeof(HttpSysSettings), methodName, SR.GetString(message, args)); + } +#endif + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/IntPtrHelper.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/IntPtrHelper.cs new file mode 100644 index 0000000000..87ee72ce34 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/IntPtrHelper.cs @@ -0,0 +1,23 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class IntPtrHelper + { + internal static IntPtr Add(IntPtr a, int b) + { + return (IntPtr)((long)a + (long)b); + } + + internal static long Subtract(IntPtr a, IntPtr b) + { + return ((long)a - (long)b); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/NclUtilities.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/NclUtilities.cs new file mode 100644 index 0000000000..4e74a95560 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/NclUtilities.cs @@ -0,0 +1,25 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class NclUtilities + { + internal static bool HasShutdownStarted + { + get + { + return Environment.HasShutdownStarted +#if NET45 + || AppDomain.CurrentDomain.IsFinalizingForUnload() +#endif + ; + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SSPIHandle.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SSPIHandle.cs new file mode 100644 index 0000000000..6089860776 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SSPIHandle.cs @@ -0,0 +1,34 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ + [StructLayout(LayoutKind.Sequential, Pack = 1)] + internal struct SSPIHandle + { + private IntPtr handleHi; + private IntPtr handleLo; + + public bool IsZero + { + get { return handleHi == IntPtr.Zero && handleLo == IntPtr.Zero; } + } + + internal void SetToInvalid() + { + handleHi = IntPtr.Zero; + handleLo = IntPtr.Zero; + } + + public override string ToString() + { + return handleHi.ToString("x") + ":" + handleLo.ToString("x"); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLoadLibrary.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLoadLibrary.cs new file mode 100644 index 0000000000..998fbe014a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLoadLibrary.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed class SafeLoadLibrary : SafeHandleZeroOrMinusOneIsInvalid + { + private const string KERNEL32 = "kernel32.dll"; + + public static readonly SafeLoadLibrary Zero = new SafeLoadLibrary(false); + + private SafeLoadLibrary() + : base(true) + { + } + + private SafeLoadLibrary(bool ownsHandle) + : base(ownsHandle) + { + } + + public static unsafe SafeLoadLibrary LoadLibraryEx(string library) + { + SafeLoadLibrary result = UnsafeNclNativeMethods.SafeNetHandles.LoadLibraryExW(library, null, 0); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + } + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.FreeLibrary(handle); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFree.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFree.cs new file mode 100644 index 0000000000..ad35b83ff0 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFree.cs @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed class SafeLocalFree : SafeHandleZeroOrMinusOneIsInvalid + { + private const int LMEM_FIXED = 0; + private const int NULL = 0; + + // This returned handle cannot be modified by the application. + public static SafeLocalFree Zero = new SafeLocalFree(false); + + private SafeLocalFree() + : base(true) + { + } + + private SafeLocalFree(bool ownsHandle) + : base(ownsHandle) + { + } + + public static SafeLocalFree LocalAlloc(int cb) + { + SafeLocalFree result = UnsafeNclNativeMethods.SafeNetHandles.LocalAlloc(LMEM_FIXED, (UIntPtr)cb); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + throw new OutOfMemoryException(); + } + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.LocalFree(handle) == IntPtr.Zero; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFreeChannelBinding.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFreeChannelBinding.cs new file mode 100644 index 0000000000..66124161ca --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFreeChannelBinding.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Security.Authentication.ExtendedProtection; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class SafeLocalFreeChannelBinding : ChannelBinding + { + private const int LMEM_FIXED = 0; + private int size; + + public override int Size + { + get { return size; } + } + + public static SafeLocalFreeChannelBinding LocalAlloc(int cb) + { + SafeLocalFreeChannelBinding result; + + result = UnsafeNclNativeMethods.SafeNetHandles.LocalAllocChannelBinding(LMEM_FIXED, (UIntPtr)cb); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + throw new OutOfMemoryException(); + } + + result.size = cb; + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.LocalFree(handle) == IntPtr.Zero; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalMemHandle.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalMemHandle.cs new file mode 100644 index 0000000000..6bc76f512f --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalMemHandle.cs @@ -0,0 +1,30 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed class SafeLocalMemHandle : SafeHandleZeroOrMinusOneIsInvalid + { + internal SafeLocalMemHandle() + : base(true) + { + } + + internal SafeLocalMemHandle(IntPtr existingHandle, bool ownsHandle) + : base(ownsHandle) + { + SetHandle(existingHandle); + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.LocalFree(handle) == IntPtr.Zero; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeNativeOverlapped.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeNativeOverlapped.cs new file mode 100644 index 0000000000..c36466ec84 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeNativeOverlapped.cs @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class SafeNativeOverlapped : SafeHandle + { + internal static readonly SafeNativeOverlapped Zero = new SafeNativeOverlapped(); + + internal SafeNativeOverlapped() + : this(IntPtr.Zero) + { + } + + internal unsafe SafeNativeOverlapped(NativeOverlapped* handle) + : this((IntPtr)handle) + { + } + + internal SafeNativeOverlapped(IntPtr handle) + : base(IntPtr.Zero, true) + { + SetHandle(handle); + } + + public override bool IsInvalid + { + get { return handle == IntPtr.Zero; } + } + + public void ReinitializeNativeOverlapped() + { + IntPtr handleSnapshot = handle; + + if (handleSnapshot != IntPtr.Zero) + { + unsafe + { + ((NativeOverlapped*)handleSnapshot)->InternalHigh = IntPtr.Zero; + ((NativeOverlapped*)handleSnapshot)->InternalLow = IntPtr.Zero; + ((NativeOverlapped*)handleSnapshot)->EventHandle = IntPtr.Zero; + } + } + } + + protected override bool ReleaseHandle() + { + IntPtr oldHandle = Interlocked.Exchange(ref handle, IntPtr.Zero); + // Do not call free durring AppDomain shutdown, there may be an outstanding operation. + // Overlapped will take care calling free when the native callback completes. + if (oldHandle != IntPtr.Zero && !NclUtilities.HasShutdownStarted) + { + unsafe + { + Overlapped.Free((NativeOverlapped*)oldHandle); + } + } + return true; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SchProtocols.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SchProtocols.cs new file mode 100644 index 0000000000..be1d240cd2 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SchProtocols.cs @@ -0,0 +1,51 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + using System; + using System.Globalization; + using System.Runtime.InteropServices; + + // From Schannel.h + [Flags] + internal enum SchProtocols + { + Zero = 0, + PctClient = 0x00000002, + PctServer = 0x00000001, + Pct = (PctClient | PctServer), + Ssl2Client = 0x00000008, + Ssl2Server = 0x00000004, + Ssl2 = (Ssl2Client | Ssl2Server), + Ssl3Client = 0x00000020, + Ssl3Server = 0x00000010, + Ssl3 = (Ssl3Client | Ssl3Server), + Tls10Client = 0x00000080, + Tls10Server = 0x00000040, + Tls10 = (Tls10Client | Tls10Server), + Tls11Client = 0x00000200, + Tls11Server = 0x00000100, + Tls11 = (Tls11Client | Tls11Server), + Tls12Client = 0x00000800, + Tls12Server = 0x00000400, + Tls12 = (Tls12Client | Tls12Server), + Ssl3Tls = (Ssl3 | Tls10), + UniClient = unchecked((int)0x80000000), + UniServer = 0x40000000, + Unified = (UniClient | UniServer), + ClientMask = (PctClient | Ssl2Client | Ssl3Client | Tls10Client | Tls11Client | Tls12Client | UniClient), + ServerMask = (PctServer | Ssl2Server | Ssl3Server | Tls10Server | Tls11Server | Tls12Server | UniServer) + } + + [StructLayout(LayoutKind.Sequential)] + internal struct Bindings + { + // see SecPkgContext_Bindings in + internal int BindingsLength; + internal IntPtr bindings; + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SecurityStatus.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SecurityStatus.cs new file mode 100644 index 0000000000..294fa1d863 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SecurityStatus.cs @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum SecurityStatus + { + // Success / Informational + OK = 0x00000000, + ContinueNeeded = unchecked((int)0x00090312), + CompleteNeeded = unchecked((int)0x00090313), + CompAndContinue = unchecked((int)0x00090314), + ContextExpired = unchecked((int)0x00090317), + CredentialsNeeded = unchecked((int)0x00090320), + Renegotiate = unchecked((int)0x00090321), + + // Errors + OutOfMemory = unchecked((int)0x80090300), + InvalidHandle = unchecked((int)0x80090301), + Unsupported = unchecked((int)0x80090302), + TargetUnknown = unchecked((int)0x80090303), + InternalError = unchecked((int)0x80090304), + PackageNotFound = unchecked((int)0x80090305), + NotOwner = unchecked((int)0x80090306), + CannotInstall = unchecked((int)0x80090307), + InvalidToken = unchecked((int)0x80090308), + CannotPack = unchecked((int)0x80090309), + QopNotSupported = unchecked((int)0x8009030A), + NoImpersonation = unchecked((int)0x8009030B), + LogonDenied = unchecked((int)0x8009030C), + UnknownCredentials = unchecked((int)0x8009030D), + NoCredentials = unchecked((int)0x8009030E), + MessageAltered = unchecked((int)0x8009030F), + OutOfSequence = unchecked((int)0x80090310), + NoAuthenticatingAuthority = unchecked((int)0x80090311), + IncompleteMessage = unchecked((int)0x80090318), + IncompleteCredentials = unchecked((int)0x80090320), + BufferNotEnough = unchecked((int)0x80090321), + WrongPrincipal = unchecked((int)0x80090322), + TimeSkew = unchecked((int)0x80090324), + UntrustedRoot = unchecked((int)0x80090325), + IllegalMessage = unchecked((int)0x80090326), + CertUnknown = unchecked((int)0x80090327), + CertExpired = unchecked((int)0x80090328), + AlgorithmMismatch = unchecked((int)0x80090331), + SecurityQosFailed = unchecked((int)0x80090332), + SmartcardLogonRequired = unchecked((int)0x8009033E), + UnsupportedPreauth = unchecked((int)0x80090343), + BadBinding = unchecked((int)0x80090346) + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SocketAddress.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SocketAddress.cs new file mode 100644 index 0000000000..9dfdb69ac3 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SocketAddress.cs @@ -0,0 +1,342 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.Text; + +namespace Microsoft.AspNet.Server.WebListener +{ + // a little perf app measured these times when comparing the internal + // buffer implemented as a managed byte[] or unmanaged memory IntPtr + // that's why we use byte[] + // byte[] total ms:19656 + // IntPtr total ms:25671 + + /// + /// + /// This class is used when subclassing EndPoint, and provides indication + /// on how to format the memory buffers that winsock uses for network addresses. + /// + /// + internal class SocketAddress + { + private const int NumberOfIPv6Labels = 8; + // Lower case hex, no leading zeros + private const string IPv6NumberFormat = "{0:x}"; + private const string IPv6StringSeparator = ":"; + private const string IPv4StringFormat = "{0:d}.{1:d}.{2:d}.{3:d}"; + + internal const int IPv6AddressSize = 28; + internal const int IPv4AddressSize = 16; + + private const int WriteableOffset = 2; + + private int _size; + private byte[] _buffer; + private int _hash; + + /// + /// [To be supplied.] + /// + public SocketAddress(AddressFamily family, int size) + { + if (size < WriteableOffset) + { + // it doesn't make sense to create a socket address with less tha + // 2 bytes, that's where we store the address family. + + throw new ArgumentOutOfRangeException("size"); + } + _size = size; + _buffer = new byte[((size / IntPtr.Size) + 2) * IntPtr.Size]; // sizeof DWORD + +#if BIGENDIAN + m_Buffer[0] = unchecked((byte)((int)family>>8)); + m_Buffer[1] = unchecked((byte)((int)family )); +#else + _buffer[0] = unchecked((byte)((int)family)); + _buffer[1] = unchecked((byte)((int)family >> 8)); +#endif + } + + internal byte[] Buffer + { + get { return _buffer; } + } + + internal AddressFamily Family + { + get + { + int family; +#if BIGENDIAN + family = ((int)m_Buffer[0]<<8) | m_Buffer[1]; +#else + family = _buffer[0] | ((int)_buffer[1] << 8); +#endif + return (AddressFamily)family; + } + } + + internal int Size + { + get + { + return _size; + } + } + + // access to unmanaged serialized data. this doesn't + // allow access to the first 2 bytes of unmanaged memory + // that are supposed to contain the address family which + // is readonly. + // + // you can still use negative offsets as a back door in case + // winsock changes the way it uses SOCKADDR. maybe we want to prohibit it? + // maybe we should make the class sealed to avoid potentially dangerous calls + // into winsock with unproperly formatted data? + + /// + /// [To be supplied.] + /// + private byte this[int offset] + { + get + { + // access + if (offset < 0 || offset >= Size) + { + throw new ArgumentOutOfRangeException("offset"); + } + return _buffer[offset]; + } + } + + internal int GetPort() + { + return (int)((_buffer[2] << 8 & 0xFF00) | (_buffer[3])); + } + + public override bool Equals(object comparand) + { + SocketAddress castedComparand = comparand as SocketAddress; + if (castedComparand == null || this.Size != castedComparand.Size) + { + return false; + } + for (int i = 0; i < this.Size; i++) + { + if (this[i] != castedComparand[i]) + { + return false; + } + } + return true; + } + + public override int GetHashCode() + { + if (_hash == 0) + { + int i; + int size = Size & ~3; + + for (i = 0; i < size; i += 4) + { + _hash ^= (int)_buffer[i] + | ((int)_buffer[i + 1] << 8) + | ((int)_buffer[i + 2] << 16) + | ((int)_buffer[i + 3] << 24); + } + if ((Size & 3) != 0) + { + int remnant = 0; + int shift = 0; + + for (; i < Size; ++i) + { + remnant |= ((int)_buffer[i]) << shift; + shift += 8; + } + _hash ^= remnant; + } + } + return _hash; + } + + public override string ToString() + { + StringBuilder bytes = new StringBuilder(); + for (int i = WriteableOffset; i < this.Size; i++) + { + if (i > WriteableOffset) + { + bytes.Append(","); + } + bytes.Append(this[i].ToString(NumberFormatInfo.InvariantInfo)); + } + return Family.ToString() + ":" + Size.ToString(NumberFormatInfo.InvariantInfo) + ":{" + bytes.ToString() + "}"; + } + + internal string GetIPAddressString() + { + if (Family == AddressFamily.InterNetworkV6) + { + return GetIpv6AddressString(); + } + else if (Family == AddressFamily.InterNetwork) + { + return GetIPv4AddressString(); + } + else + { + return null; + } + } + + private string GetIPv4AddressString() + { + Contract.Assert(Size >= IPv4AddressSize); + + return string.Format(CultureInfo.InvariantCulture, IPv4StringFormat, + _buffer[4], _buffer[5], _buffer[6], _buffer[7]); + } + + // TODO: Does scope ID ever matter? + private unsafe string GetIpv6AddressString() + { + Contract.Assert(Size >= IPv6AddressSize); + + fixed (byte* rawBytes = _buffer) + { + // Convert from bytes to shorts. + ushort* rawShorts = stackalloc ushort[NumberOfIPv6Labels]; + int numbersOffset = 0; + // The address doesn't start at the beginning of the buffer. + for (int i = 8; i < ((NumberOfIPv6Labels * 2) + 8); i += 2) + { + rawShorts[numbersOffset++] = (ushort)(rawBytes[i] << 8 | rawBytes[i + 1]); + } + return GetIPv6AddressString(rawShorts); + } + } + + private static unsafe string GetIPv6AddressString(ushort* numbers) + { + // RFC 5952 Sections 4 & 5 - Compressed, lower case, with possible embedded IPv4 addresses. + + // Start to finish, inclusive. <-1, -1> for no compression + KeyValuePair range = FindCompressionRange(numbers); + bool ipv4Embedded = ShouldHaveIpv4Embedded(numbers); + + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < NumberOfIPv6Labels; i++) + { + if (ipv4Embedded && i == (NumberOfIPv6Labels - 2)) + { + // Write the remaining digits as an IPv4 address + builder.Append(IPv6StringSeparator); + builder.Append(string.Format(CultureInfo.InvariantCulture, IPv4StringFormat, + numbers[i] >> 8, numbers[i] & 0xFF, numbers[i + 1] >> 8, numbers[i + 1] & 0xFF)); + break; + } + + // Compression; 1::1, ::1, 1:: + if (range.Key == i) + { + // Start compression, add : + builder.Append(IPv6StringSeparator); + } + if (range.Key <= i && range.Value == (NumberOfIPv6Labels - 1)) + { + // Remainder compressed; 1:: + builder.Append(IPv6StringSeparator); + break; + } + if (range.Key <= i && i <= range.Value) + { + continue; // Compressed + } + + if (i != 0) + { + builder.Append(IPv6StringSeparator); + } + builder.Append(string.Format(CultureInfo.InvariantCulture, IPv6NumberFormat, numbers[i])); + } + + return builder.ToString(); + } + + // RFC 5952 Section 4.2.3 + // Longest consecutive sequence of zero segments, minimum 2. + // On equal, first sequence wins. + // <-1, -1> for no compression. + private static unsafe KeyValuePair FindCompressionRange(ushort* numbers) + { + int longestSequenceLength = 0; + int longestSequenceStart = -1; + + int currentSequenceLength = 0; + for (int i = 0; i < NumberOfIPv6Labels; i++) + { + if (numbers[i] == 0) + { + // In a sequence + currentSequenceLength++; + if (currentSequenceLength > longestSequenceLength) + { + longestSequenceLength = currentSequenceLength; + longestSequenceStart = i - currentSequenceLength + 1; + } + } + else + { + currentSequenceLength = 0; + } + } + + if (longestSequenceLength >= 2) + { + return new KeyValuePair(longestSequenceStart, + longestSequenceStart + longestSequenceLength - 1); + } + + return new KeyValuePair(-1, -1); // No compression + } + + // Returns true if the IPv6 address should be formated with an embedded IPv4 address: + // ::192.168.1.1 + private static unsafe bool ShouldHaveIpv4Embedded(ushort* numbers) + { + // 0:0 : 0:0 : x:x : x.x.x.x + if (numbers[0] == 0 && numbers[1] == 0 && numbers[2] == 0 && numbers[3] == 0 && numbers[6] != 0) + { + // RFC 5952 Section 5 - 0:0 : 0:0 : 0:[0 | FFFF] : x.x.x.x + if (numbers[4] == 0 && (numbers[5] == 0 || numbers[5] == 0xFFFF)) + { + return true; + + // SIIT - 0:0 : 0:0 : FFFF:0 : x.x.x.x + } + else if (numbers[4] == 0xFFFF && numbers[5] == 0) + { + return true; + } + } + // ISATAP + if (numbers[4] == 0 && numbers[5] == 0x5EFE) + { + return true; + } + + return false; + } + } // class SocketAddress +} // namespace System.Net diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/UnsafeNativeMethods.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/UnsafeNativeMethods.cs new file mode 100644 index 0000000000..0eb890cf8b --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/UnsafeNativeMethods.cs @@ -0,0 +1,1129 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class UnsafeNclNativeMethods + { + private const string KERNEL32 = "kernel32.dll"; + private const string SECUR32 = "secur32.dll"; + private const string HTTPAPI = "httpapi.dll"; + + // CONSIDER: Make this an enum, requires changing a lot of types from uint to ErrorCodes. + internal static class ErrorCodes + { + internal const uint ERROR_SUCCESS = 0; + internal const uint ERROR_HANDLE_EOF = 38; + internal const uint ERROR_NOT_SUPPORTED = 50; + internal const uint ERROR_INVALID_PARAMETER = 87; + internal const uint ERROR_ALREADY_EXISTS = 183; + internal const uint ERROR_MORE_DATA = 234; + internal const uint ERROR_OPERATION_ABORTED = 995; + internal const uint ERROR_IO_PENDING = 997; + internal const uint ERROR_NOT_FOUND = 1168; + internal const uint ERROR_CONNECTION_INVALID = 1229; + } + + [DllImport(KERNEL32, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint GetCurrentThreadId(); + + [DllImport(KERNEL32, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static unsafe extern uint CancelIoEx(SafeHandle handle, SafeNativeOverlapped overlapped); + + [DllImport(KERNEL32, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static unsafe extern bool SetFileCompletionNotificationModes(SafeHandle handle, FileCompletionNotificationModes modes); + + [Flags] + internal enum FileCompletionNotificationModes : byte + { + None = 0, + SkipCompletionPortOnSuccess = 1, + SkipSetEventOnHandle = 2 + } + + internal static class SafeNetHandles + { + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static extern int FreeContextBuffer( + [In] IntPtr contextBuffer); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static unsafe extern int QueryContextAttributesW( + ref SSPIHandle contextHandle, + [In] ContextAttribute attribute, + [In] void* buffer); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern unsafe uint HttpCreateRequestQueue(HttpApi.HTTPAPI_VERSION version, string pName, + Microsoft.AspNet.Server.WebListener.UnsafeNclNativeMethods.SECURITY_ATTRIBUTES pSecurityAttributes, uint flags, out HttpRequestQueueV2Handle pReqQueueHandle); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern unsafe uint HttpCloseRequestQueue(IntPtr pReqQueueHandle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern bool CloseHandle(IntPtr handle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern SafeLocalFree LocalAlloc(int uFlags, UIntPtr sizetdwBytes); + + [DllImport(KERNEL32, EntryPoint = "LocalAlloc", SetLastError = true)] + internal static extern SafeLocalFreeChannelBinding LocalAllocChannelBinding(int uFlags, UIntPtr sizetdwBytes); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern IntPtr LocalFree(IntPtr handle); + + [DllImport(KERNEL32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern unsafe SafeLoadLibrary LoadLibraryExW([In] string lpwLibFileName, [In] void* hFile, [In] uint dwFlags); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern unsafe bool FreeLibrary([In] IntPtr hModule); + } + + internal static unsafe class HttpApi + { + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpInitialize(HTTPAPI_VERSION version, uint flags, void* pReserved); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveRequestEntityBody(SafeHandle requestQueueHandle, ulong requestId, uint flags, IntPtr pEntityBuffer, uint entityBufferLength, out uint bytesReturned, SafeNativeOverlapped pOverlapped); + + [DllImport(HTTPAPI, EntryPoint = "HttpReceiveRequestEntityBody", ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveRequestEntityBody2(SafeHandle requestQueueHandle, ulong requestId, uint flags, void* pEntityBuffer, uint entityBufferLength, out uint bytesReturned, [In] SafeHandle pOverlapped); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveClientCertificate(SafeHandle requestQueueHandle, ulong connectionId, uint flags, HTTP_SSL_CLIENT_CERT_INFO* pSslClientCertInfo, uint sslClientCertInfoSize, uint* pBytesReceived, SafeNativeOverlapped pOverlapped); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveClientCertificate(SafeHandle requestQueueHandle, ulong connectionId, uint flags, byte* pSslClientCertInfo, uint sslClientCertInfoSize, uint* pBytesReceived, SafeNativeOverlapped pOverlapped); + + [SuppressMessage("Microsoft.Interoperability", "CA1415:DeclarePInvokesCorrectly", Justification = "NativeOverlapped is now wrapped by SafeNativeOverlapped")] + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveHttpRequest(SafeHandle requestQueueHandle, ulong requestId, uint flags, HTTP_REQUEST* pRequestBuffer, uint requestBufferLength, uint* pBytesReturned, SafeNativeOverlapped pOverlapped); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSendHttpResponse(SafeHandle requestQueueHandle, ulong requestId, uint flags, HTTP_RESPONSE* pHttpResponse, void* pCachePolicy, uint* pBytesSent, SafeLocalFree pRequestBuffer, uint requestBufferLength, SafeNativeOverlapped pOverlapped, IntPtr pLogData); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSendResponseEntityBody(SafeHandle requestQueueHandle, ulong requestId, uint flags, ushort entityChunkCount, HTTP_DATA_CHUNK* pEntityChunks, uint* pBytesSent, SafeLocalFree pRequestBuffer, uint requestBufferLength, SafeNativeOverlapped pOverlapped, IntPtr pLogData); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCancelHttpRequest(SafeHandle requestQueueHandle, ulong requestId, IntPtr pOverlapped); + + [DllImport(HTTPAPI, EntryPoint = "HttpSendResponseEntityBody", ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSendResponseEntityBody2(SafeHandle requestQueueHandle, ulong requestId, uint flags, ushort entityChunkCount, IntPtr pEntityChunks, out uint pBytesSent, SafeLocalFree pRequestBuffer, uint requestBufferLength, SafeHandle pOverlapped, IntPtr pLogData); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpWaitForDisconnect(SafeHandle requestQueueHandle, ulong connectionId, SafeNativeOverlapped pOverlapped); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCreateServerSession(HTTPAPI_VERSION version, ulong* serverSessionId, uint reserved); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCreateUrlGroup(ulong serverSessionId, ulong* urlGroupId, uint reserved); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern uint HttpAddUrlToUrlGroup(ulong urlGroupId, string pFullyQualifiedUrl, ulong context, uint pReserved); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSetUrlGroupProperty(ulong urlGroupId, HTTP_SERVER_PROPERTY serverProperty, IntPtr pPropertyInfo, uint propertyInfoLength); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern uint HttpRemoveUrlFromUrlGroup(ulong urlGroupId, string pFullyQualifiedUrl, uint flags); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCloseServerSession(ulong serverSessionId); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCloseUrlGroup(ulong urlGroupId); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSetRequestQueueProperty(SafeHandle requestQueueHandle, HTTP_SERVER_PROPERTY serverProperty, IntPtr pPropertyInfo, uint propertyInfoLength, uint reserved, IntPtr pReserved); + + internal enum HTTP_API_VERSION + { + Invalid, + Version10, + Version20, + } + + // see http.w for definitions + internal enum HTTP_SERVER_PROPERTY + { + HttpServerAuthenticationProperty, + HttpServerLoggingProperty, + HttpServerQosProperty, + HttpServerTimeoutsProperty, + HttpServerQueueLengthProperty, + HttpServerStateProperty, + HttpServer503VerbosityProperty, + HttpServerBindingProperty, + HttpServerExtendedAuthenticationProperty, + HttpServerListenEndpointProperty, + HttpServerChannelBindProperty, + HttpServerProtectionLevelProperty, + } + + // Currently only one request info type is supported but the enum is for future extensibility. + + internal enum HTTP_REQUEST_INFO_TYPE + { + HttpRequestInfoTypeAuth, + } + + internal enum HTTP_RESPONSE_INFO_TYPE + { + HttpResponseInfoTypeMultipleKnownHeaders, + HttpResponseInfoTypeAuthenticationProperty, + HttpResponseInfoTypeQosProperty, + } + + internal enum HTTP_TIMEOUT_TYPE + { + EntityBody, + DrainEntityBody, + RequestQueue, + IdleConnection, + HeaderWait, + MinSendRate, + } + + internal const int MaxTimeout = 6; + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_VERSION + { + internal ushort MajorVersion; + internal ushort MinorVersion; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_KNOWN_HEADER + { + internal ushort RawValueLength; + internal sbyte* pRawValue; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct HTTP_DATA_CHUNK + { + [SuppressMessage("Microsoft.Performance", "CA1823:AvoidUnusedPrivateFields", Justification = "Used natively")] + [FieldOffset(0)] + internal HTTP_DATA_CHUNK_TYPE DataChunkType; + + [FieldOffset(8)] + internal FromMemory fromMemory; + + [FieldOffset(8)] + internal FromFileHandle fromFile; + } + + [SuppressMessage("Microsoft.Design", "CA1049:TypesThatOwnNativeResourcesShouldBeDisposable", + Justification = "This type does not own the native buffer")] + [StructLayout(LayoutKind.Sequential)] + internal struct FromMemory + { + // 4 bytes for 32bit, 8 bytes for 64bit + internal IntPtr pBuffer; + internal uint BufferLength; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct FromFileHandle + { + internal ulong offset; + internal ulong count; + internal IntPtr fileHandle; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTPAPI_VERSION + { + internal ushort HttpApiMajorVersion; + internal ushort HttpApiMinorVersion; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_COOKED_URL + { + internal ushort FullUrlLength; + internal ushort HostLength; + internal ushort AbsPathLength; + internal ushort QueryStringLength; + internal ushort* pFullUrl; + internal ushort* pHost; + internal ushort* pAbsPath; + internal ushort* pQueryString; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct SOCKADDR + { + internal ushort sa_family; + internal byte sa_data; + internal byte sa_data_02; + internal byte sa_data_03; + internal byte sa_data_04; + internal byte sa_data_05; + internal byte sa_data_06; + internal byte sa_data_07; + internal byte sa_data_08; + internal byte sa_data_09; + internal byte sa_data_10; + internal byte sa_data_11; + internal byte sa_data_12; + internal byte sa_data_13; + internal byte sa_data_14; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_TRANSPORT_ADDRESS + { + internal SOCKADDR* pRemoteAddress; + internal SOCKADDR* pLocalAddress; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SSL_CLIENT_CERT_INFO + { + internal uint CertFlags; + internal uint CertEncodedSize; + internal byte* pCertEncoded; + internal void* Token; + internal byte CertDeniedByMapper; + } + + internal enum HTTP_SERVICE_BINDING_TYPE : uint + { + HttpServiceBindingTypeNone = 0, + HttpServiceBindingTypeW, + HttpServiceBindingTypeA + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SERVICE_BINDING_BASE + { + internal HTTP_SERVICE_BINDING_TYPE Type; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_CHANNEL_BIND_STATUS + { + internal IntPtr ServiceName; + internal IntPtr ChannelToken; + internal uint ChannelTokenSize; + internal uint Flags; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_UNKNOWN_HEADER + { + internal ushort NameLength; + internal ushort RawValueLength; + internal sbyte* pName; + internal sbyte* pRawValue; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SSL_INFO + { + internal ushort ServerCertKeySize; + internal ushort ConnectionKeySize; + internal uint ServerCertIssuerSize; + internal uint ServerCertSubjectSize; + internal sbyte* pServerCertIssuer; + internal sbyte* pServerCertSubject; + internal HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo; + internal uint SslClientCertNegotiated; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_RESPONSE_HEADERS + { + internal ushort UnknownHeaderCount; + internal HTTP_UNKNOWN_HEADER* pUnknownHeaders; + internal ushort TrailerCount; + internal HTTP_UNKNOWN_HEADER* pTrailers; + internal HTTP_KNOWN_HEADER KnownHeaders; + internal HTTP_KNOWN_HEADER KnownHeaders_02; + internal HTTP_KNOWN_HEADER KnownHeaders_03; + internal HTTP_KNOWN_HEADER KnownHeaders_04; + internal HTTP_KNOWN_HEADER KnownHeaders_05; + internal HTTP_KNOWN_HEADER KnownHeaders_06; + internal HTTP_KNOWN_HEADER KnownHeaders_07; + internal HTTP_KNOWN_HEADER KnownHeaders_08; + internal HTTP_KNOWN_HEADER KnownHeaders_09; + internal HTTP_KNOWN_HEADER KnownHeaders_10; + internal HTTP_KNOWN_HEADER KnownHeaders_11; + internal HTTP_KNOWN_HEADER KnownHeaders_12; + internal HTTP_KNOWN_HEADER KnownHeaders_13; + internal HTTP_KNOWN_HEADER KnownHeaders_14; + internal HTTP_KNOWN_HEADER KnownHeaders_15; + internal HTTP_KNOWN_HEADER KnownHeaders_16; + internal HTTP_KNOWN_HEADER KnownHeaders_17; + internal HTTP_KNOWN_HEADER KnownHeaders_18; + internal HTTP_KNOWN_HEADER KnownHeaders_19; + internal HTTP_KNOWN_HEADER KnownHeaders_20; + internal HTTP_KNOWN_HEADER KnownHeaders_21; + internal HTTP_KNOWN_HEADER KnownHeaders_22; + internal HTTP_KNOWN_HEADER KnownHeaders_23; + internal HTTP_KNOWN_HEADER KnownHeaders_24; + internal HTTP_KNOWN_HEADER KnownHeaders_25; + internal HTTP_KNOWN_HEADER KnownHeaders_26; + internal HTTP_KNOWN_HEADER KnownHeaders_27; + internal HTTP_KNOWN_HEADER KnownHeaders_28; + internal HTTP_KNOWN_HEADER KnownHeaders_29; + internal HTTP_KNOWN_HEADER KnownHeaders_30; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_HEADERS + { + internal ushort UnknownHeaderCount; + internal HTTP_UNKNOWN_HEADER* pUnknownHeaders; + internal ushort TrailerCount; + internal HTTP_UNKNOWN_HEADER* pTrailers; + internal HTTP_KNOWN_HEADER KnownHeaders; + internal HTTP_KNOWN_HEADER KnownHeaders_02; + internal HTTP_KNOWN_HEADER KnownHeaders_03; + internal HTTP_KNOWN_HEADER KnownHeaders_04; + internal HTTP_KNOWN_HEADER KnownHeaders_05; + internal HTTP_KNOWN_HEADER KnownHeaders_06; + internal HTTP_KNOWN_HEADER KnownHeaders_07; + internal HTTP_KNOWN_HEADER KnownHeaders_08; + internal HTTP_KNOWN_HEADER KnownHeaders_09; + internal HTTP_KNOWN_HEADER KnownHeaders_10; + internal HTTP_KNOWN_HEADER KnownHeaders_11; + internal HTTP_KNOWN_HEADER KnownHeaders_12; + internal HTTP_KNOWN_HEADER KnownHeaders_13; + internal HTTP_KNOWN_HEADER KnownHeaders_14; + internal HTTP_KNOWN_HEADER KnownHeaders_15; + internal HTTP_KNOWN_HEADER KnownHeaders_16; + internal HTTP_KNOWN_HEADER KnownHeaders_17; + internal HTTP_KNOWN_HEADER KnownHeaders_18; + internal HTTP_KNOWN_HEADER KnownHeaders_19; + internal HTTP_KNOWN_HEADER KnownHeaders_20; + internal HTTP_KNOWN_HEADER KnownHeaders_21; + internal HTTP_KNOWN_HEADER KnownHeaders_22; + internal HTTP_KNOWN_HEADER KnownHeaders_23; + internal HTTP_KNOWN_HEADER KnownHeaders_24; + internal HTTP_KNOWN_HEADER KnownHeaders_25; + internal HTTP_KNOWN_HEADER KnownHeaders_26; + internal HTTP_KNOWN_HEADER KnownHeaders_27; + internal HTTP_KNOWN_HEADER KnownHeaders_28; + internal HTTP_KNOWN_HEADER KnownHeaders_29; + internal HTTP_KNOWN_HEADER KnownHeaders_30; + internal HTTP_KNOWN_HEADER KnownHeaders_31; + internal HTTP_KNOWN_HEADER KnownHeaders_32; + internal HTTP_KNOWN_HEADER KnownHeaders_33; + internal HTTP_KNOWN_HEADER KnownHeaders_34; + internal HTTP_KNOWN_HEADER KnownHeaders_35; + internal HTTP_KNOWN_HEADER KnownHeaders_36; + internal HTTP_KNOWN_HEADER KnownHeaders_37; + internal HTTP_KNOWN_HEADER KnownHeaders_38; + internal HTTP_KNOWN_HEADER KnownHeaders_39; + internal HTTP_KNOWN_HEADER KnownHeaders_40; + internal HTTP_KNOWN_HEADER KnownHeaders_41; + } + + internal enum HTTP_VERB : int + { + HttpVerbUnparsed = 0, + HttpVerbUnknown = 1, + HttpVerbInvalid = 2, + HttpVerbOPTIONS = 3, + HttpVerbGET = 4, + HttpVerbHEAD = 5, + HttpVerbPOST = 6, + HttpVerbPUT = 7, + HttpVerbDELETE = 8, + HttpVerbTRACE = 9, + HttpVerbCONNECT = 10, + HttpVerbTRACK = 11, + HttpVerbMOVE = 12, + HttpVerbCOPY = 13, + HttpVerbPROPFIND = 14, + HttpVerbPROPPATCH = 15, + HttpVerbMKCOL = 16, + HttpVerbLOCK = 17, + HttpVerbUNLOCK = 18, + HttpVerbSEARCH = 19, + HttpVerbMaximum = 20, + } + + internal static readonly string[] HttpVerbs = new string[] + { + null, + "Unknown", + "Invalid", + "OPTIONS", + "GET", + "HEAD", + "POST", + "PUT", + "DELETE", + "TRACE", + "CONNECT", + "TRACK", + "MOVE", + "COPY", + "PROPFIND", + "PROPPATCH", + "MKCOL", + "LOCK", + "UNLOCK", + "SEARCH", + }; + + internal enum HTTP_DATA_CHUNK_TYPE : int + { + HttpDataChunkFromMemory = 0, + HttpDataChunkFromFileHandle = 1, + HttpDataChunkFromFragmentCache = 2, + HttpDataChunkMaximum = 3, + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_RESPONSE_INFO + { + internal HTTP_RESPONSE_INFO_TYPE Type; + internal uint Length; + internal void* pInfo; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_RESPONSE + { + internal uint Flags; + internal HTTP_VERSION Version; + internal ushort StatusCode; + internal ushort ReasonLength; + internal sbyte* pReason; + internal HTTP_RESPONSE_HEADERS Headers; + internal ushort EntityChunkCount; + internal HTTP_DATA_CHUNK* pEntityChunks; + internal ushort ResponseInfoCount; + internal HTTP_RESPONSE_INFO* pResponseInfo; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_AUTH_INFO + { + internal HTTP_AUTH_STATUS AuthStatus; + internal uint SecStatus; + internal uint Flags; + internal HTTP_REQUEST_AUTH_TYPE AuthType; + internal IntPtr AccessToken; + internal uint ContextAttributes; + internal uint PackedContextLength; + internal uint PackedContextType; + internal IntPtr PackedContext; + internal uint MutualAuthDataLength; + internal char* pMutualAuthData; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_INFO + { + internal HTTP_REQUEST_INFO_TYPE InfoType; + internal uint InfoLength; + internal HTTP_REQUEST_AUTH_INFO* pInfo; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST + { + internal uint Flags; + internal ulong ConnectionId; + internal ulong RequestId; + internal ulong UrlContext; + internal HTTP_VERSION Version; + internal HTTP_VERB Verb; + internal ushort UnknownVerbLength; + internal ushort RawUrlLength; + internal sbyte* pUnknownVerb; + internal sbyte* pRawUrl; + internal HTTP_COOKED_URL CookedUrl; + internal HTTP_TRANSPORT_ADDRESS Address; + internal HTTP_REQUEST_HEADERS Headers; + internal ulong BytesReceived; + internal ushort EntityChunkCount; + internal HTTP_DATA_CHUNK* pEntityChunks; + internal ulong RawConnectionId; + internal HTTP_SSL_INFO* pSslInfo; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_V2 + { + internal HTTP_REQUEST Request; + internal ushort RequestInfoCount; + internal HTTP_REQUEST_INFO* pRequestInfo; + } + + internal enum HTTP_AUTH_STATUS + { + HttpAuthStatusSuccess, + HttpAuthStatusNotAuthenticated, + HttpAuthStatusFailure, + } + + internal enum HTTP_REQUEST_AUTH_TYPE + { + HttpRequestAuthTypeNone = 0, + HttpRequestAuthTypeBasic, + HttpRequestAuthTypeDigest, + HttpRequestAuthTypeNTLM, + HttpRequestAuthTypeNegotiate, + HttpRequestAuthTypeKerberos + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SERVER_AUTHENTICATION_INFO + { + internal HTTP_FLAGS Flags; + internal HTTP_AUTH_TYPES AuthSchemes; + internal bool ReceiveMutualAuth; + internal bool ReceiveContextHandle; + internal bool DisableNTLMCredentialCaching; + internal ulong ExFlags; + HTTP_SERVER_AUTHENTICATION_DIGEST_PARAMS DigestParams; + HTTP_SERVER_AUTHENTICATION_BASIC_PARAMS BasicParams; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SERVER_AUTHENTICATION_DIGEST_PARAMS + { + internal ushort DomainNameLength; + internal char* DomainName; + internal ushort RealmLength; + internal char* Realm; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SERVER_AUTHENTICATION_BASIC_PARAMS + { + ushort RealmLength; + char* Realm; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_TIMEOUT_LIMIT_INFO + { + internal HTTP_FLAGS Flags; + internal ushort EntityBody; + internal ushort DrainEntityBody; + internal ushort RequestQueue; + internal ushort IdleConnection; + internal ushort HeaderWait; + internal uint MinSendRate; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_BINDING_INFO + { + internal HTTP_FLAGS Flags; + internal IntPtr RequestQueueHandle; + } + + // see http.w for definitions + [Flags] + internal enum HTTP_FLAGS : uint + { + NONE = 0x00000000, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY = 0x00000001, + HTTP_RECEIVE_SECURE_CHANNEL_TOKEN = 0x00000001, + HTTP_SEND_RESPONSE_FLAG_DISCONNECT = 0x00000001, + HTTP_SEND_RESPONSE_FLAG_MORE_DATA = 0x00000002, + HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA = 0x00000004, + HTTP_SEND_RESPONSE_FLAG_RAW_HEADER = 0x00000004, + HTTP_SEND_REQUEST_FLAG_MORE_DATA = 0x00000001, + HTTP_PROPERTY_FLAG_PRESENT = 0x00000001, + HTTP_INITIALIZE_SERVER = 0x00000001, + HTTP_INITIALIZE_CBT = 0x00000004, + HTTP_SEND_RESPONSE_FLAG_OPAQUE = 0x00000040, + } + + [Flags] + internal enum HTTP_AUTH_TYPES : uint + { + NONE = 0x00000000, + HTTP_AUTH_ENABLE_BASIC = 0x00000001, + HTTP_AUTH_ENABLE_DIGEST = 0x00000002, + HTTP_AUTH_ENABLE_NTLM = 0x00000004, + HTTP_AUTH_ENABLE_NEGOTIATE = 0x00000008, + HTTP_AUTH_ENABLE_KERBEROS = 0x00000010, + } + + private const int HttpHeaderRequestMaximum = (int)HttpSysRequestHeader.UserAgent + 1; + private const int HttpHeaderResponseMaximum = (int)HttpSysResponseHeader.WwwAuthenticate + 1; + + internal static class HTTP_REQUEST_HEADER_ID + { + internal static string ToString(int position) + { + return _strings[position]; + } + + private static string[] _strings = + { + "Cache-Control", + "Connection", + "Date", + "Keep-Alive", + "Pragma", + "Trailer", + "Transfer-Encoding", + "Upgrade", + "Via", + "Warning", + + "Allow", + "Content-Length", + "Content-Type", + "Content-Encoding", + "Content-Language", + "Content-Location", + "Content-MD5", + "Content-Range", + "Expires", + "Last-Modified", + + "Accept", + "Accept-Charset", + "Accept-Encoding", + "Accept-Language", + "Authorization", + "Cookie", + "Expect", + "From", + "Host", + "If-Match", + + "If-Modified-Since", + "If-None-Match", + "If-Range", + "If-Unmodified-Since", + "Max-Forwards", + "Proxy-Authorization", + "Referer", + "Range", + "Te", + "Translate", + "User-Agent", + }; + } + + internal static class HTTP_RESPONSE_HEADER_ID + { + private static string[] _strings = + { + "Cache-Control", + "Connection", + "Date", + "Keep-Alive", + "Pragma", + "Trailer", + "Transfer-Encoding", + "Upgrade", + "Via", + "Warning", + + "Allow", + "Content-Length", + "Content-Type", + "Content-Encoding", + "Content-Language", + "Content-Location", + "Content-MD5", + "Content-Range", + "Expires", + "Last-Modified", + + "Accept-Ranges", + "Age", + "ETag", + "Location", + "Proxy-Authenticate", + "Retry-After", + "Server", + "Set-Cookie", + "Vary", + "WWW-Authenticate", + }; + + private static Dictionary _lookupTable = CreateLookupTable(); + + private static Dictionary CreateLookupTable() + { + Dictionary lookupTable = new Dictionary((int)Enum.HttpHeaderResponseMaximum, StringComparer.OrdinalIgnoreCase); + for (int i = 0; i < (int)Enum.HttpHeaderResponseMaximum; i++) + { + lookupTable.Add(_strings[i], i); + } + return lookupTable; + } + + internal static int IndexOfKnownHeader(string HeaderName) + { + int index; + return _lookupTable.TryGetValue(HeaderName, out index) ? index : -1; + } + + internal static string ToString(int position) + { + return _strings[position]; + } + + internal enum Enum + { + HttpHeaderCacheControl = 0, // general-header [section 4.5] + HttpHeaderConnection = 1, // general-header [section 4.5] + HttpHeaderDate = 2, // general-header [section 4.5] + HttpHeaderKeepAlive = 3, // general-header [not in rfc] + HttpHeaderPragma = 4, // general-header [section 4.5] + HttpHeaderTrailer = 5, // general-header [section 4.5] + HttpHeaderTransferEncoding = 6, // general-header [section 4.5] + HttpHeaderUpgrade = 7, // general-header [section 4.5] + HttpHeaderVia = 8, // general-header [section 4.5] + HttpHeaderWarning = 9, // general-header [section 4.5] + + HttpHeaderAllow = 10, // entity-header [section 7.1] + HttpHeaderContentLength = 11, // entity-header [section 7.1] + HttpHeaderContentType = 12, // entity-header [section 7.1] + HttpHeaderContentEncoding = 13, // entity-header [section 7.1] + HttpHeaderContentLanguage = 14, // entity-header [section 7.1] + HttpHeaderContentLocation = 15, // entity-header [section 7.1] + HttpHeaderContentMd5 = 16, // entity-header [section 7.1] + HttpHeaderContentRange = 17, // entity-header [section 7.1] + HttpHeaderExpires = 18, // entity-header [section 7.1] + HttpHeaderLastModified = 19, // entity-header [section 7.1] + + // Response Headers + + HttpHeaderAcceptRanges = 20, // response-header [section 6.2] + HttpHeaderAge = 21, // response-header [section 6.2] + HttpHeaderEtag = 22, // response-header [section 6.2] + HttpHeaderLocation = 23, // response-header [section 6.2] + HttpHeaderProxyAuthenticate = 24, // response-header [section 6.2] + HttpHeaderRetryAfter = 25, // response-header [section 6.2] + HttpHeaderServer = 26, // response-header [section 6.2] + HttpHeaderSetCookie = 27, // response-header [not in rfc] + HttpHeaderVary = 28, // response-header [section 6.2] + HttpHeaderWwwAuthenticate = 29, // response-header [section 6.2] + + HttpHeaderResponseMaximum = 30, + + HttpHeaderMaximum = 41 + } + } + + private static HTTPAPI_VERSION version; + + // This property is used by HttpListener to pass the version structure to the native layer in API + // calls. + + internal static HTTPAPI_VERSION Version + { + get + { + return version; + } + } + + // This property is used by HttpListener to get the Api version in use so that it uses appropriate + // Http APIs. + + internal static HTTP_API_VERSION ApiVersion + { + get + { + if (version.HttpApiMajorVersion == 2 && version.HttpApiMinorVersion == 0) + { + return HTTP_API_VERSION.Version20; + } + else if (version.HttpApiMajorVersion == 1 && version.HttpApiMinorVersion == 0) + { + return HTTP_API_VERSION.Version10; + } + else + { + return HTTP_API_VERSION.Invalid; + } + } + } + + static HttpApi() + { + InitHttpApi(2, 0); + } + + private static void InitHttpApi(ushort majorVersion, ushort minorVersion) + { + version.HttpApiMajorVersion = majorVersion; + version.HttpApiMinorVersion = minorVersion; + + // For pre-Win7 OS versions, we need to check whether http.sys contains the CBT patch. + // We do so by passing HTTP_INITIALIZE_CBT flag to HttpInitialize. If the flag is not + // supported, http.sys is not patched. Note that http.sys will return invalid parameter + // also on Win7, even though it shipped with CBT support. Therefore we must not pass + // the flag on Win7 and later. + uint statusCode = ErrorCodes.ERROR_SUCCESS; + + // on Win7 and later, we don't pass the CBT flag. CBT is always supported. + statusCode = HttpApi.HttpInitialize(version, (uint)HTTP_FLAGS.HTTP_INITIALIZE_SERVER, null); + + supported = statusCode == ErrorCodes.ERROR_SUCCESS; + } + + private static volatile bool supported; + internal static bool Supported + { + get + { + return supported; + } + } + + // Server API + + internal static void GetUnknownHeaders(IDictionary unknownHeaders, byte[] memoryBlob, IntPtr originalAddress) + { + // Return value. + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + long fixup = pMemoryBlob - (byte*)originalAddress; + int index; + + // unknown headers + if (request->Headers.UnknownHeaderCount != 0) + { + HTTP_UNKNOWN_HEADER* pUnknownHeader = (HTTP_UNKNOWN_HEADER*)(fixup + (byte*)request->Headers.pUnknownHeaders); + for (index = 0; index < request->Headers.UnknownHeaderCount; index++) + { + // For unknown headers, when header value is empty, RawValueLength will be 0 and + // pRawValue will be null. + if (pUnknownHeader->pName != null && pUnknownHeader->NameLength > 0) + { + string headerName = HeaderEncoding.GetString(pUnknownHeader->pName + fixup, pUnknownHeader->NameLength); + string headerValue; + if (pUnknownHeader->pRawValue != null && pUnknownHeader->RawValueLength > 0) + { + headerValue = HeaderEncoding.GetString(pUnknownHeader->pRawValue + fixup, pUnknownHeader->RawValueLength); + } + else + { + headerValue = string.Empty; + } + // Note that Http.Sys currently collapses all headers of the same name to a single string, so + // append will just set. + unknownHeaders.Append(headerName, headerValue); + } + pUnknownHeader++; + } + } + } + } + + private static string GetKnownHeader(HTTP_REQUEST* request, long fixup, int headerIndex) + { + string header = null; + + HTTP_KNOWN_HEADER* pKnownHeader = (&request->Headers.KnownHeaders) + headerIndex; + // For known headers, when header value is empty, RawValueLength will be 0 and + // pRawValue will point to empty string ("\0") + if (pKnownHeader->pRawValue != null) + { + header = HeaderEncoding.GetString(pKnownHeader->pRawValue + fixup, pKnownHeader->RawValueLength); + } + + return header; + } + + internal static string GetKnownHeader(byte[] memoryBlob, IntPtr originalAddress, int headerIndex) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + return GetKnownHeader((HTTP_REQUEST*)pMemoryBlob, pMemoryBlob - (byte*)originalAddress, headerIndex); + } + } + + private static unsafe string GetVerb(HTTP_REQUEST* request, long fixup) + { + string verb = null; + + if ((int)request->Verb > (int)HTTP_VERB.HttpVerbUnknown && (int)request->Verb < (int)HTTP_VERB.HttpVerbMaximum) + { + verb = HttpVerbs[(int)request->Verb]; + } + else if (request->Verb == HTTP_VERB.HttpVerbUnknown && request->pUnknownVerb != null) + { + verb = HeaderEncoding.GetString(request->pUnknownVerb + fixup, request->UnknownVerbLength); + } + + return verb; + } + + internal static unsafe string GetVerb(byte[] memoryBlob, IntPtr originalAddress) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + return GetVerb((HTTP_REQUEST*)pMemoryBlob, pMemoryBlob - (byte*)originalAddress); + } + } + + internal static HTTP_VERB GetKnownVerb(byte[] memoryBlob, IntPtr originalAddress) + { + // Return value. + HTTP_VERB verb = HTTP_VERB.HttpVerbUnknown; + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + if ((int)request->Verb > (int)HTTP_VERB.HttpVerbUnparsed && (int)request->Verb < (int)HTTP_VERB.HttpVerbMaximum) + { + verb = request->Verb; + } + } + + return verb; + } + + internal static uint GetChunks(byte[] memoryBlob, IntPtr originalAddress, ref int dataChunkIndex, ref uint dataChunkOffset, byte[] buffer, int offset, int size) + { + // Return value. + uint dataRead = 0; + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + long fixup = pMemoryBlob - (byte*)originalAddress; + + if (request->EntityChunkCount > 0 && dataChunkIndex < request->EntityChunkCount && dataChunkIndex != -1) + { + HTTP_DATA_CHUNK* pDataChunk = (HTTP_DATA_CHUNK*)(fixup + (byte*)&request->pEntityChunks[dataChunkIndex]); + + fixed (byte* pReadBuffer = buffer) + { + byte* pTo = &pReadBuffer[offset]; + + while (dataChunkIndex < request->EntityChunkCount && dataRead < size) + { + if (dataChunkOffset >= pDataChunk->fromMemory.BufferLength) + { + dataChunkOffset = 0; + dataChunkIndex++; + pDataChunk++; + } + else + { + byte* pFrom = (byte*)pDataChunk->fromMemory.pBuffer + dataChunkOffset + fixup; + + uint bytesToRead = pDataChunk->fromMemory.BufferLength - (uint)dataChunkOffset; + if (bytesToRead > (uint)size) + { + bytesToRead = (uint)size; + } + for (uint i = 0; i < bytesToRead; i++) + { + *(pTo++) = *(pFrom++); + } + dataRead += bytesToRead; + dataChunkOffset += bytesToRead; + } + } + } + } + // we're finished. + if (dataChunkIndex == request->EntityChunkCount) + { + dataChunkIndex = -1; + } + } + + return dataRead; + } + + internal static SocketAddress GetRemoteEndPoint(byte[] memoryBlob, IntPtr originalAddress) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + return GetEndPoint(memoryBlob, originalAddress, (byte*)request->Address.pRemoteAddress); + } + } + + internal static SocketAddress GetLocalEndPoint(byte[] memoryBlob, IntPtr originalAddress) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + return GetEndPoint(memoryBlob, originalAddress, (byte*)request->Address.pLocalAddress); + } + } + + internal static SocketAddress GetEndPoint(byte[] memoryBlob, IntPtr originalAddress, byte* source) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + IntPtr address = source != null ? + (IntPtr)(pMemoryBlob - (byte*)originalAddress + source) : IntPtr.Zero; + return CopyOutAddress(address); + } + } + + private static SocketAddress CopyOutAddress(IntPtr address) + { + if (address != IntPtr.Zero) + { + ushort addressFamily = *((ushort*)address); + if (addressFamily == (ushort)AddressFamily.InterNetwork) + { + SocketAddress v4address = new SocketAddress(AddressFamily.InterNetwork, SocketAddress.IPv4AddressSize); + fixed (byte* pBuffer = v4address.Buffer) + { + for (int index = 2; index < SocketAddress.IPv4AddressSize; index++) + { + pBuffer[index] = ((byte*)address)[index]; + } + } + return v4address; + } + if (addressFamily == (ushort)AddressFamily.InterNetworkV6) + { + SocketAddress v6address = new SocketAddress(AddressFamily.InterNetworkV6, SocketAddress.IPv6AddressSize); + fixed (byte* pBuffer = v6address.Buffer) + { + for (int index = 2; index < SocketAddress.IPv6AddressSize; index++) + { + pBuffer[index] = ((byte*)address)[index]; + } + } + return v6address; + } + } + + return null; + } + } + + // DACL related stuff + + [SuppressMessage("Microsoft.Performance", "CA1812:AvoidUninstantiatedInternalClasses", Justification = "Instantiated natively")] + [SuppressMessage("Microsoft.Design", "CA1001:TypesThatOwnDisposableFieldsShouldBeDisposable", + Justification = "Does not own the resource.")] + [StructLayout(LayoutKind.Sequential)] + internal class SECURITY_ATTRIBUTES + { + public int nLength = 12; + public SafeLocalMemHandle lpSecurityDescriptor = new SafeLocalMemHandle(IntPtr.Zero, false); + public bool bInheritHandle = false; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/OwinServerFactory.cs b/src/Microsoft.AspNet.Server.WebListener/OwinServerFactory.cs new file mode 100644 index 0000000000..6b0b205db7 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/OwinServerFactory.cs @@ -0,0 +1,108 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- +// Copyright 2011-2012 Katana contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + using AppFunc = Func, Task>; + using LoggerFactoryFunc = Func, bool>>; + + /// + /// Implements the Katana setup pattern for this server. + /// + public static class OwinServerFactory + { + /// + /// Populates the server capabilities. + /// Also included is a configurable instance of the server. + /// + /// + [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Disposed by caller")] + public static void Initialize(IDictionary properties) + { + if (properties == null) + { + throw new ArgumentNullException("properties"); + } + + properties[Constants.VersionKey] = Constants.OwinVersion; + + IDictionary capabilities = + properties.Get>(Constants.ServerCapabilitiesKey) + ?? new Dictionary(); + properties[Constants.ServerCapabilitiesKey] = capabilities; + + // SendFile + capabilities[Constants.SendFileVersionKey] = Constants.SendFileVersion; + IDictionary sendfileSupport = new Dictionary(); + sendfileSupport[Constants.SendFileConcurrencyKey] = Constants.Overlapped; + capabilities[Constants.SendFileSupportKey] = sendfileSupport; + + // Opaque + if (ComNetOS.IsWin8orLater) + { + capabilities[Constants.OpaqueVersionKey] = Constants.OpaqueVersion; + } + + // Directly expose the server for advanced configuration. + properties[typeof(OwinWebListener).FullName] = new OwinWebListener(); + } + + /// + /// Creates a server and starts listening on the given addresses. + /// + /// The application entry point. + /// The configuration. + /// The server. Invoke Dispose to shut down. + [SuppressMessage("Microsoft.Design", "CA1006:DoNotNestGenericTypesInMemberSignatures", Justification = "By design")] + [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Disposed by caller")] + public static IDisposable Create(AppFunc app, IDictionary properties) + { + if (app == null) + { + throw new ArgumentNullException("app"); + } + if (properties == null) + { + throw new ArgumentNullException("properties"); + } + + var addresses = properties.Get>>("host.Addresses") + ?? new List>(); + + OwinWebListener server = properties.Get(typeof(OwinWebListener).FullName) + ?? new OwinWebListener(); + + var capabilities = + properties.Get>(Constants.ServerCapabilitiesKey) + ?? new Dictionary(); + + var loggerFactory = properties.Get(Constants.ServerLoggerFactoryKey); + + server.Start(app, addresses, capabilities, loggerFactory); + return server; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/OwinWebListener.cs b/src/Microsoft.AspNet.Server.WebListener/OwinWebListener.cs new file mode 100644 index 0000000000..d6c1cd916a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/OwinWebListener.cs @@ -0,0 +1,1107 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Runtime.InteropServices; +using System.Security.Authentication.ExtendedProtection; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + using AppFunc = Func, Task>; + using LoggerFactoryFunc = Func, bool>>; + using LoggerFunc = Func, bool>; + + /// + /// An HTTP server wrapping the Http.Sys APIs that accepts requests and passes them on to the given OWIN application. + /// + public sealed class OwinWebListener : IDisposable + { + private const long DefaultRequestQueueLength = 1000; // Http.sys default. +#if NET45 + private static readonly Type ChannelBindingStatusType = typeof(UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_CHANNEL_BIND_STATUS); + private static readonly int RequestChannelBindStatusSize = + Marshal.SizeOf(typeof(UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_CHANNEL_BIND_STATUS)); + private static readonly int BindingInfoSize = + Marshal.SizeOf(typeof(UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO)); +#else + private static readonly int RequestChannelBindStatusSize = + Marshal.SizeOf(); + private static readonly int BindingInfoSize = + Marshal.SizeOf(); +#endif + private static readonly int DefaultMaxAccepts = 5 * Environment.ProcessorCount; + private static readonly int DefaultMaxRequests = Int32.MaxValue; + + // Win8# 559317 fixed a bug in Http.sys's HttpReceiveClientCertificate method. + // Without this fix IOCP callbacks were not being called although ERROR_IO_PENDING was + // returned from HttpReceiveClientCertificate when using the + // FileCompletionNotificationModes.SkipCompletionPortOnSuccess flag. + // This bug was only hit when the buffer passed into HttpReceiveClientCertificate + // (1500 bytes initially) is tool small for the certificate. + // Due to this bug in downlevel operating systems the FileCompletionNotificationModes.SkipCompletionPortOnSuccess + // flag is only used on Win8 and later. + internal static readonly bool SkipIOCPCallbackOnSuccess = ComNetOS.IsWin8orLater; + + // Mitigate potential DOS attacks by limiting the number of unknown headers we accept. Numerous header names + // with hash collisions will cause the server to consume excess CPU. 1000 headers limits CPU time to under + // 0.5 seconds per request. Respond with a 400 Bad Request. + private const int UnknownHeaderLimit = 1000; + + private readonly ConcurrentDictionary _connectionCancellationTokens; + + private AppFunc _appFunc; + private IDictionary _capabilities; + private LoggerFunc _logger; + + private SafeHandle _requestQueueHandle; + private volatile State _state; // m_State is set only within lock blocks, but often read outside locks. + private bool _ignoreWriteExceptions; + private HttpServerSessionHandle _serverSessionHandle; + private ulong _urlGroupId; + private TimeoutManager _timeoutManager; + private AuthenticationManager _authManager; + private bool _v2Initialized; + + private object _internalLock; + + private List _uriPrefixes = new List(); + + private PumpLimits _pumpLimits; + private int _currentOutstandingAccepts; + private int _currentOutstandingRequests; + private Action _offloadListenForNextRequest; + + // The native request queue + private long? _requestQueueLength; + + internal OwinWebListener() + { + if (!UnsafeNclNativeMethods.HttpApi.Supported) + { + throw new PlatformNotSupportedException(); + } + + Debug.Assert(UnsafeNclNativeMethods.HttpApi.ApiVersion == + UnsafeNclNativeMethods.HttpApi.HTTP_API_VERSION.Version20, "Invalid Http api version"); + + _state = State.Stopped; + _internalLock = new object(); + + _timeoutManager = new TimeoutManager(this); + _authManager = new AuthenticationManager(this); + _connectionCancellationTokens = new ConcurrentDictionary(); + + _offloadListenForNextRequest = new Action(ListenForNextRequestAsync); + + _pumpLimits = new PumpLimits(DefaultMaxAccepts, DefaultMaxRequests); + } + + internal enum State + { + Stopped, + Started, + Disposed, + } + + internal LoggerFunc Logger + { + get { return _logger; } + } + + internal List UriPrefixes + { + get { return _uriPrefixes; } + } + + internal IDictionary Capabilities + { + get { return _capabilities; } + } + + internal SafeHandle RequestQueueHandle + { + get + { + return _requestQueueHandle; + } + } + + /// + /// Exposes the Http.Sys timeout configurations. These may also be configured in the registry. + /// + public TimeoutManager TimeoutManager + { + get + { + ValidateV2Property(); + Debug.Assert(_timeoutManager != null, "Timeout manager is not assigned"); + return _timeoutManager; + } + } + + public AuthenticationManager AuthenticationManager + { + get + { + ValidateV2Property(); + Debug.Assert(_authManager != null, "Auth manager is not assigned"); + return _authManager; + } + } + + internal static bool IsSupported + { + get + { + return UnsafeNclNativeMethods.HttpApi.Supported; + } + } + + internal bool IsListening + { + get + { + return _state == State.Started; + } + } + + internal bool IgnoreWriteExceptions + { + get + { + return _ignoreWriteExceptions; + } + set + { + CheckDisposed(); + _ignoreWriteExceptions = value; + } + } + + private bool CanAcceptMoreRequests + { + get + { + PumpLimits limits = _pumpLimits; + return (_currentOutstandingAccepts < limits.MaxOutstandingAccepts + && _currentOutstandingRequests < limits.MaxOutstandingRequests - _currentOutstandingAccepts); + } + } + + /// + /// These are merged as one operation because they should be swapped out atomically. + /// This controls how many requests the server attempts to process concurrently. + /// + /// The maximum number of pending accepts. + /// The maximum number of outstanding requests. + public void SetRequestProcessingLimits(int maxAccepts, int maxRequests) + { + _pumpLimits = new PumpLimits(maxAccepts, maxRequests); + + // Kick the pump in case we went from zero to non-zero limits. + OffloadListenForNextRequestAsync(); + } + + /// + /// Gets the request processing limits. + /// + /// The maximum number of pending accepts. + /// The maximum number of outstanding requests. + [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "0#", Justification = "By design")] + [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "1#", Justification = "By design")] + public void GetRequestProcessingLimits(out int maxAccepts, out int maxRequests) + { + PumpLimits limits = _pumpLimits; + maxAccepts = limits.MaxOutstandingAccepts; + maxRequests = limits.MaxOutstandingRequests; + } + + /// + /// Sets the maximum number of requests that will be queued up in Http.Sys. + /// + /// + public void SetRequestQueueLimit(long limit) + { + if (limit <= 0) + { + throw new ArgumentOutOfRangeException("limit", limit, string.Empty); + } + if ((!_requestQueueLength.HasValue && limit == DefaultRequestQueueLength) + || (_requestQueueLength.HasValue && limit == _requestQueueLength.Value)) + { + return; + } + + _requestQueueLength = limit; + + SetRequestQueueLimit(); + } + + private unsafe void SetRequestQueueLimit() + { + // The listener must be active for this to work. Call from Start after activating. + if (!IsListening || !_requestQueueLength.HasValue) + { + return; + } + + long length = _requestQueueLength.Value; + uint result = UnsafeNclNativeMethods.HttpApi.HttpSetRequestQueueProperty(_requestQueueHandle, + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerQueueLengthProperty, + new IntPtr((void*)&length), (uint)Marshal.SizeOf(length), 0, IntPtr.Zero); + + if (result != 0) + { + throw new WebListenerException((int)result); + } + } + + private void ValidateV2Property() + { + // Make sure that calling CheckDisposed and SetupV2Config is an atomic operation. This + // avoids race conditions if the listener is aborted/closed after CheckDisposed(), but + // before SetupV2Config(). + lock (_internalLock) + { + CheckDisposed(); + SetupV2Config(); + } + } + + internal void SetUrlGroupProperty(UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY property, IntPtr info, uint infosize) + { + ValidateV2Property(); + uint statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + + Debug.Assert(_urlGroupId != 0, "SetUrlGroupProperty called with invalid url group id"); + Debug.Assert(info != IntPtr.Zero, "SetUrlGroupProperty called with invalid pointer"); + + // Set the url group property using Http Api. + + statusCode = UnsafeNclNativeMethods.HttpApi.HttpSetUrlGroupProperty( + _urlGroupId, property, info, infosize); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + WebListenerException exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_logger, "SetUrlGroupProperty", exception); + throw exception; + } + } + + internal void RemoveAll(bool clear) + { + CheckDisposed(); + // go through the uri list and unregister for each one of them + if (_uriPrefixes.Count > 0) + { + LogHelper.LogInfo(_logger, "RemoveAll"); + if (_state == State.Started) + { + foreach (Prefix registeredPrefix in _uriPrefixes) + { + // ignore possible failures + InternalRemovePrefix(registeredPrefix.Whole); + } + } + + if (clear) + { + _uriPrefixes.Clear(); + } + } + } + + private IntPtr DangerousGetHandle() + { + return _requestQueueHandle.DangerousGetHandle(); + } + + private unsafe void SetupV2Config() + { + uint statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + ulong id = 0; + + // If we have already initialized V2 config, then nothing to do. + + if (_v2Initialized) + { + return; + } + + // V2 initialization sequence: + // 1. Create server session + // 2. Create url group + // 3. Create request queue - Done in Start() + // 4. Add urls to url group - Done in Start() + // 5. Attach request queue to url group - Done in Start() + + try + { + statusCode = UnsafeNclNativeMethods.HttpApi.HttpCreateServerSession( + UnsafeNclNativeMethods.HttpApi.Version, &id, 0); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + throw new WebListenerException((int)statusCode); + } + + Debug.Assert(id != 0, "Invalid id returned by HttpCreateServerSession"); + + _serverSessionHandle = new HttpServerSessionHandle(id); + + id = 0; + statusCode = UnsafeNclNativeMethods.HttpApi.HttpCreateUrlGroup( + _serverSessionHandle.DangerousGetServerSessionId(), &id, 0); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + throw new WebListenerException((int)statusCode); + } + + Debug.Assert(id != 0, "Invalid id returned by HttpCreateUrlGroup"); + _urlGroupId = id; + + _v2Initialized = true; + } + catch (Exception exception) + { + // If V2 initialization fails, we mark object as unusable. + _state = State.Disposed; + // If Url group or request queue creation failed, close server session before throwing. + if (_serverSessionHandle != null) + { + _serverSessionHandle.Dispose(); + } + LogHelper.LogException(_logger, "SetupV2Config", exception); + throw; + } + } + + internal void Start(AppFunc app, IList> addresses, IDictionary capabilities, LoggerFactoryFunc loggerFactory) + { + CheckDisposed(); + // Can't call Start twice + Contract.Assert(_appFunc == null); + + Contract.Assert(app != null); + Contract.Assert(addresses != null); + Contract.Assert(capabilities != null); + + _appFunc = app; + _capabilities = capabilities; + _logger = LogHelper.CreateLogger(loggerFactory, typeof(OwinWebListener)); + LogHelper.LogInfo(_logger, "Start"); + + foreach (var address in addresses) + { + // Build addresses from parts + var scheme = address.Get("scheme") ?? Constants.HttpScheme; + var host = address.Get("host") ?? "localhost"; + var port = address.Get("port") ?? "5000"; + var path = address.Get("path") ?? string.Empty; + + Prefix prefix = Prefix.Create(scheme, host, port, path); + _uriPrefixes.Add(prefix); + } + + // Make sure there are no race conditions between Start/Stop/Abort/Close/Dispose and + // calls to SetupV2Config: Start needs to setup all resources (esp. in V2 where besides + // the request handle, there is also a server session and a Url group. Abort/Stop must + // not interfere while Start is allocating those resources. The lock also makes sure + // all methods changing state can read and change the state in an atomic way. + lock (_internalLock) + { + try + { + CheckDisposed(); + if (_state == State.Started) + { + return; + } + + // SetupV2Config() is not called in the ctor, because it may throw. This would + // be a regression since in v1 the ctor never threw. Besides, ctors should do + // minimal work according to the framework design guidelines. + SetupV2Config(); + CreateRequestQueueHandle(); + AttachRequestQueueToUrlGroup(); + + // All resources are set up correctly. Now add all prefixes. + try + { + AddAllPrefixes(); + } + catch (WebListenerException) + { + // If an error occurred while adding prefixes, free all resources allocated by previous steps. + DetachRequestQueueFromUrlGroup(); + throw; + } + + _state = State.Started; + + SetRequestQueueLimit(); + + OffloadListenForNextRequestAsync(); + } + catch (Exception exception) + { + // Make sure the HttpListener instance can't be used if Start() failed. + _state = State.Disposed; + CloseRequestQueueHandle(); + CleanupV2Config(); + LogHelper.LogException(_logger, "Start", exception); + throw; + } + } + } + + // Make sure the next request is processed on another thread as to not recursively + // block the request we just received. + private void OffloadListenForNextRequestAsync() + { + if (IsListening && CanAcceptMoreRequests) + { + Task offloadTask = Task.Run(_offloadListenForNextRequest); + } + } + + // The message pump. + // When we start listening for the next request on one thread, we may need to be sure that the + // completion continues on another thread as to not block the current request processing. + // The awaits will manage stack depth for us. + private async void ListenForNextRequestAsync() + { + while (IsListening && CanAcceptMoreRequests) + { + // Receive a request + RequestContext requestContext; + Interlocked.Increment(ref _currentOutstandingAccepts); + try + { + requestContext = await GetContextAsync().SupressContext(); + Interlocked.Decrement(ref _currentOutstandingAccepts); + } + catch (Exception exception) + { + LogHelper.LogException(_logger, "ListenForNextRequestAsync", exception); + // Assume the server has stopped. + Interlocked.Decrement(ref _currentOutstandingAccepts); + Contract.Assert(!IsListening); + return; + } + + Interlocked.Increment(ref _currentOutstandingRequests); + OffloadListenForNextRequestAsync(); + await ProcessRequestAsync(requestContext).SupressContext(); + Interlocked.Decrement(ref _currentOutstandingRequests); + } + } + + private async Task ProcessRequestAsync(RequestContext requestContext) + { + try + { + try + { + // TODO: Make disconnect registration lazy + RegisterForDisconnectNotification(requestContext); + await _appFunc(requestContext.Environment).SupressContext(); + await requestContext.ProcessResponseAsync().SupressContext(); + } + catch (Exception ex) + { + LogHelper.LogException(_logger, "ProcessRequestAsync", ex); + if (requestContext.Response.SentHeaders) + { + requestContext.Abort(); + } + else + { + // We haven't sent a response yet, try to send a 500 Internal Server Error + requestContext.SetFatalResponse(); + } + } + requestContext.Dispose(); + } + catch (Exception ex) + { + LogHelper.LogException(_logger, "ProcessRequestAsync", ex); + requestContext.Abort(); + requestContext.Dispose(); + } + } + + private void CleanupV2Config() + { + // If we never setup V2, just return. + if (!_v2Initialized) + { + return; + } + + // V2 stopping sequence: + // 1. Detach request queue from url group - Done in Stop()/Abort() + // 2. Remove urls from url group - Done in Stop() + // 3. Close request queue - Done in Stop()/Abort() + // 4. Close Url group. + // 5. Close server session. + + Debug.Assert(_urlGroupId != 0, "HttpCloseUrlGroup called with invalid url group id"); + + uint statusCode = UnsafeNclNativeMethods.HttpApi.HttpCloseUrlGroup(_urlGroupId); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + LogHelper.LogError(_logger, "CleanupV2Config", "Result: " + statusCode); + } + _urlGroupId = 0; + + Debug.Assert(_serverSessionHandle != null, "ServerSessionHandle is null in CloseV2Config"); + Debug.Assert(!_serverSessionHandle.IsInvalid, "ServerSessionHandle is invalid in CloseV2Config"); + + _serverSessionHandle.Dispose(); + } + + private unsafe void AttachRequestQueueToUrlGroup() + { + // Set the association between request queue and url group. After this, requests for registered urls will + // get delivered to this request queue. + + UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO info = new UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO(); + info.Flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_PROPERTY_FLAG_PRESENT; + info.RequestQueueHandle = DangerousGetHandle(); + + IntPtr infoptr = new IntPtr(&info); + + SetUrlGroupProperty(UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerBindingProperty, + infoptr, (uint)BindingInfoSize); + } + + private unsafe void DetachRequestQueueFromUrlGroup() + { + Debug.Assert(_urlGroupId != 0, "DetachRequestQueueFromUrlGroup can't detach using Url group id 0."); + + // Break the association between request queue and url group. After this, requests for registered urls + // will get 503s. + // Note that this method may be called multiple times (Stop() and then Abort()). This + // is fine since http.sys allows to set HttpServerBindingProperty multiple times for valid + // Url groups. + + UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO info = new UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO(); + info.Flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + info.RequestQueueHandle = IntPtr.Zero; + + IntPtr infoptr = new IntPtr(&info); + + uint statusCode = UnsafeNclNativeMethods.HttpApi.HttpSetUrlGroupProperty(_urlGroupId, + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerBindingProperty, + infoptr, (uint)BindingInfoSize); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + LogHelper.LogError(_logger, "DetachRequestQueueFromUrlGroup", "Result: " + statusCode); + } + } + + internal void Stop() + { + try + { + lock (_internalLock) + { + CheckDisposed(); + if (_state == State.Stopped) + { + return; + } + LogHelper.LogInfo(_logger, "Stop"); + + RemoveAll(false); + + _state = State.Stopped; + + DetachRequestQueueFromUrlGroup(); + + // Even though it would be enough to just detach the request queue in v2, in order to + // keep app compat with earlier versions of the framework, we need to close the request queue. + // This will make sure that pending GetContext() calls will complete and throw an exception. Just + // detaching the url group from the request queue would not cause GetContext() to return. + CloseRequestQueueHandle(); + } + } + catch (Exception exception) + { + LogHelper.LogException(_logger, "Stop", exception); + throw; + } + } + + private unsafe void CreateRequestQueueHandle() + { + uint statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + + HttpRequestQueueV2Handle requestQueueHandle = null; + statusCode = + UnsafeNclNativeMethods.SafeNetHandles.HttpCreateRequestQueue( + UnsafeNclNativeMethods.HttpApi.Version, null, null, 0, out requestQueueHandle); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + throw new WebListenerException((int)statusCode); + } + + // Disabling callbacks when IO operation completes synchronously (returns ErrorCodes.ERROR_SUCCESS) + if (SkipIOCPCallbackOnSuccess && + !UnsafeNclNativeMethods.SetFileCompletionNotificationModes( + requestQueueHandle, + UnsafeNclNativeMethods.FileCompletionNotificationModes.SkipCompletionPortOnSuccess | + UnsafeNclNativeMethods.FileCompletionNotificationModes.SkipSetEventOnHandle)) + { + throw new WebListenerException(Marshal.GetLastWin32Error()); + } + + _requestQueueHandle = requestQueueHandle; + ThreadPool.BindHandle(_requestQueueHandle); + } + + private unsafe void CloseRequestQueueHandle() + { + if ((_requestQueueHandle != null) && (!_requestQueueHandle.IsInvalid)) + { + _requestQueueHandle.Dispose(); + } + } + + /// + /// Stop the server and clean up. + /// + public void Dispose() + { + Dispose(true); + } + + // old API, now private, and helper methods + private void Dispose(bool disposing) + { + if (!disposing) + { + return; + } + + lock (_internalLock) + { + try + { + if (_state == State.Disposed) + { + return; + } + LogHelper.LogInfo(_logger, "Dispose"); + + Stop(); + CleanupV2Config(); + } + catch (Exception exception) + { + LogHelper.LogException(_logger, "Dispose", exception); + throw; + } + finally + { + _state = State.Disposed; + } + } + } + + private uint InternalAddPrefix(string uriPrefix, int contextId) + { + uint statusCode = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpAddUrlToUrlGroup( + _urlGroupId, + uriPrefix, + (ulong)contextId, + 0); + + return statusCode; + } + + private bool InternalRemovePrefix(string uriPrefix) + { + uint statusCode = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpRemoveUrlFromUrlGroup( + _urlGroupId, + uriPrefix, + 0); + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_NOT_FOUND) + { + return false; + } + return true; + } + + private void AddAllPrefixes() + { + // go through the uri list and register for each one of them + if (_uriPrefixes.Count > 0) + { + for (int i = 0; i < _uriPrefixes.Count; i++) + { + // We'll get this index back on each request and use it to look up the prefix to calculate PathBase. + Prefix registeredPrefix = _uriPrefixes[i]; + uint statusCode = InternalAddPrefix(registeredPrefix.Whole, i); + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_ALREADY_EXISTS) + { + throw new WebListenerException((int)statusCode, String.Format(Resources.Exception_PrefixAlreadyRegistered, registeredPrefix.Whole)); + } + else + { + throw new WebListenerException((int)statusCode); + } + } + } + } + } + + internal unsafe bool ValidateRequest(NativeRequestContext requestMemory) + { + // Block potential DOS attacks + if (requestMemory.RequestBlob->Headers.UnknownHeaderCount > UnknownHeaderLimit) + { + SendError(requestMemory.RequestBlob->RequestId, HttpStatusCode.BadRequest); + return false; + } + return true; + } + + [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Disposed by callback")] + internal Task GetContextAsync() + { + AsyncAcceptContext asyncResult = null; + try + { + CheckDisposed(); + Debug.Assert(_state != State.Stopped, "Listener has been stopped."); + // prepare the ListenerAsyncResult object (this will have it's own + // event that the user can wait on for IO completion - which means we + // need to signal it when IO completes) + asyncResult = new AsyncAcceptContext(this); + uint statusCode = asyncResult.QueueBeginGetContext(); + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + // someother bad error, possible(?) return values are: + // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED + asyncResult.Dispose(); + throw new WebListenerException((int)statusCode); + } + } + catch (Exception exception) + { + LogHelper.LogException(_logger, "GetContextAsync", exception); + throw; + } + + return asyncResult.Task; + } + + private void RegisterForDisconnectNotification(RequestContext requestContext) + { + try + { + // Create exactly one CancellationToken per connection. + ulong connectionId = requestContext.Request.ConnectionId; + CancellationToken ct = GetConnectionCancellation(connectionId); + requestContext.Request.RegisterForDisconnect(ct); + requestContext.Environment.ConnectionDisconnect = ct; + } + catch (Win32Exception exception) + { + LogHelper.LogException(_logger, "RegisterForDisconnectNotification", exception); + } + } + + private CancellationToken GetConnectionCancellation(ulong connectionId) + { + // Read case is performance senstive + ConnectionCancellation cancellation; + if (!_connectionCancellationTokens.TryGetValue(connectionId, out cancellation)) + { + cancellation = GetCreatedConnectionCancellation(connectionId); + } + return cancellation.GetCancellationToken(connectionId); + } + + private ConnectionCancellation GetCreatedConnectionCancellation(ulong connectionId) + { + // Race condition on creation has no side effects + ConnectionCancellation cancellation = new ConnectionCancellation(this); + return _connectionCancellationTokens.GetOrAdd(connectionId, cancellation); + } + + private unsafe CancellationToken CreateDisconnectToken(ulong connectionId) + { + // Debug.WriteLine("Server: Registering connection for disconnect for connection ID: " + connectionId); + + // Create a nativeOverlapped callback so we can register for disconnect callback + var overlapped = new Overlapped(); + var cts = new CancellationTokenSource(); + + SafeNativeOverlapped nativeOverlapped = null; + nativeOverlapped = new SafeNativeOverlapped(overlapped.UnsafePack( + (errorCode, numBytes, overlappedPtr) => + { + // Debug.WriteLine("Server: http.sys disconnect callback fired for connection ID: " + connectionId); + + // Free the overlapped + nativeOverlapped.Dispose(); + + // Pull the token out of the list and Cancel it. + ConnectionCancellation token; + _connectionCancellationTokens.TryRemove(connectionId, out token); + try + { + cts.Cancel(); + } + catch (AggregateException exception) + { + LogHelper.LogException(_logger, "CreateDisconnectToken::Disconnected", exception); + } + + cts.Dispose(); + }, + null)); + + uint statusCode; + try + { + statusCode = UnsafeNclNativeMethods.HttpApi.HttpWaitForDisconnect(_requestQueueHandle, connectionId, nativeOverlapped); + } + catch (Win32Exception exception) + { + statusCode = (uint)exception.NativeErrorCode; + LogHelper.LogException(_logger, "CreateDisconnectToken", exception); + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + // We got an unknown result so return a None + // TODO: return a canceled token? + return CancellationToken.None; + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + // TODO: return a canceled token? + return CancellationToken.None; + } + + return cts.Token; + } + + private unsafe void SendError(ulong requestId, HttpStatusCode httpStatusCode) + { + UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE httpResponse = new UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE(); + httpResponse.Version = new UnsafeNclNativeMethods.HttpApi.HTTP_VERSION(); + httpResponse.Version.MajorVersion = (ushort)1; + httpResponse.Version.MinorVersion = (ushort)1; + httpResponse.StatusCode = (ushort)httpStatusCode; + string statusDescription = HttpReasonPhrase.Get(httpStatusCode); + uint dataWritten = 0; + uint statusCode; + byte[] byteReason = HeaderEncoding.GetBytes(statusDescription); + fixed (byte* pReason = byteReason) + { + httpResponse.pReason = (sbyte*)pReason; + httpResponse.ReasonLength = (ushort)byteReason.Length; + + byte[] byteContentLength = new byte[] { (byte)'0' }; + fixed (byte* pContentLength = byteContentLength) + { + (&httpResponse.Headers.KnownHeaders)[(int)HttpSysResponseHeader.ContentLength].pRawValue = (sbyte*)pContentLength; + (&httpResponse.Headers.KnownHeaders)[(int)HttpSysResponseHeader.ContentLength].RawValueLength = (ushort)byteContentLength.Length; + httpResponse.Headers.UnknownHeaderCount = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendHttpResponse( + _requestQueueHandle, + requestId, + 0, + &httpResponse, + null, + &dataWritten, + SafeLocalFree.Zero, + 0, + SafeNativeOverlapped.Zero, + IntPtr.Zero); + } + } + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + // if we fail to send a 401 something's seriously wrong, abort the request + RequestContext.CancelRequest(_requestQueueHandle, requestId); + } + } + + private static int GetTokenOffsetFromBlob(IntPtr blob) + { + Debug.Assert(blob != IntPtr.Zero); +#if NET45 + IntPtr tokenPointer = Marshal.ReadIntPtr(blob, (int)Marshal.OffsetOf(ChannelBindingStatusType, "ChannelToken")); +#else + IntPtr tokenPointer = Marshal.ReadIntPtr(blob, (int)Marshal.OffsetOf("ChannelToken")); +#endif + Debug.Assert(tokenPointer != IntPtr.Zero); + return (int)IntPtrHelper.Subtract(tokenPointer, blob); + } + + private static int GetTokenSizeFromBlob(IntPtr blob) + { + Debug.Assert(blob != IntPtr.Zero); +#if NET45 + return Marshal.ReadInt32(blob, (int)Marshal.OffsetOf(ChannelBindingStatusType, "ChannelTokenSize")); +#else + return Marshal.ReadInt32(blob, (int)Marshal.OffsetOf("ChannelTokenSize")); +#endif + } + + internal ChannelBinding GetChannelBinding(ulong connectionId, bool isSecureConnection) + { + if (!isSecureConnection) + { + LogHelper.LogInfo(_logger, "Channel binding is not supported for HTTP."); + return null; + } + + ChannelBinding result = GetChannelBindingFromTls(connectionId); + + Debug.Assert(result != null, "GetChannelBindingFromTls returned null even though OS supposedly supports Extended Protection"); + LogHelper.LogInfo(_logger, "Channel binding retrieved."); + return result; + } + + private unsafe ChannelBinding GetChannelBindingFromTls(ulong connectionId) + { + // +128 since a CBT is usually <128 thus we need to call HRCC just once. If the CBT + // is >128 we will get ERROR_MORE_DATA and call again + int size = RequestChannelBindStatusSize + 128; + + Debug.Assert(size >= 0); + + byte[] blob = null; + SafeLocalFreeChannelBinding token = null; + + uint bytesReceived = 0; + uint statusCode; + + do + { + blob = new byte[size]; + fixed (byte* blobPtr = blob) + { + // Http.sys team: ServiceName will always be null if + // HTTP_RECEIVE_SECURE_CHANNEL_TOKEN flag is set. + statusCode = UnsafeNclNativeMethods.HttpApi.HttpReceiveClientCertificate( + RequestQueueHandle, + connectionId, + (uint)UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_RECEIVE_SECURE_CHANNEL_TOKEN, + blobPtr, + (uint)size, + &bytesReceived, + SafeNativeOverlapped.Zero); + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + int tokenOffset = GetTokenOffsetFromBlob((IntPtr)blobPtr); + int tokenSize = GetTokenSizeFromBlob((IntPtr)blobPtr); + Debug.Assert(tokenSize < Int32.MaxValue); + + token = SafeLocalFreeChannelBinding.LocalAlloc(tokenSize); + + Marshal.Copy(blob, tokenOffset, token.DangerousGetHandle(), tokenSize); + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + int tokenSize = GetTokenSizeFromBlob((IntPtr)blobPtr); + Debug.Assert(tokenSize < Int32.MaxValue); + + size = RequestChannelBindStatusSize + tokenSize; + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_INVALID_PARAMETER) + { + LogHelper.LogError(_logger, "GetChannelBindingFromTls", "Channel binding is not supported."); + return null; // old schannel library which doesn't support CBT + } + else + { + // It's up to the consumer to fail if the missing ChannelBinding matters to them. + LogHelper.LogException(_logger, "GetChannelBindingFromTls", new WebListenerException((int)statusCode)); + break; + } + } + } + while (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS); + + return token; + } + + internal void CheckDisposed() + { + if (_state == State.Disposed) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + } + + private class ConnectionCancellation + { + private readonly OwinWebListener _parent; + private volatile bool _initialized; // Must be volatile because initialization is synchronized + private CancellationToken _cancellationToken; + + public ConnectionCancellation(OwinWebListener parent) + { + _parent = parent; + } + + internal CancellationToken GetCancellationToken(ulong connectionId) + { + // Initialized case is performance sensitive + if (_initialized) + { + return _cancellationToken; + } + return InitializeCancellationToken(connectionId); + } + + private CancellationToken InitializeCancellationToken(ulong connectionId) + { + object syncObject = this; +#pragma warning disable 420 // Disable warning about volatile by reference since EnsureInitialized does volatile operations + return LazyInitializer.EnsureInitialized(ref _cancellationToken, ref _initialized, ref syncObject, () => _parent.CreateDisconnectToken(connectionId)); +#pragma warning restore 420 + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Prefix.cs b/src/Microsoft.AspNet.Server.WebListener/Prefix.cs new file mode 100644 index 0000000000..08a0fe61c5 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Prefix.cs @@ -0,0 +1,102 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Globalization; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class Prefix + { + private Prefix(bool isHttps, string scheme, string host, string port, int portValue, string path) + { + IsHttps = isHttps; + Scheme = scheme; + Host = host; + Port = port; + PortValue = portValue; + Path = path; + Whole = string.Format(CultureInfo.InvariantCulture, "{0}://{1}:{2}{3}", Scheme, Host, Port, Path); + } + + /// + /// http://msdn.microsoft.com/en-us/library/windows/desktop/aa364698(v=vs.85).aspx + /// + /// http or https. Will be normalized to lower case. + /// +, *, IPv4, [IPv6], or a dns name. Http.Sys does not permit punycode (xn--), use Unicode instead. + /// If empty, the default port for the given scheme will be used (80 or 443). + /// Should start and end with a '/', though a missing trailing slash will be added. This value must be un-escaped. + public static Prefix Create(string scheme, string host, string port, string path) + { + bool isHttps; + if (string.Equals(Constants.HttpScheme, scheme, StringComparison.OrdinalIgnoreCase)) + { + scheme = Constants.HttpScheme; // Always use a lower case scheme + isHttps = false; + } + else if (string.Equals(Constants.HttpsScheme, scheme, StringComparison.OrdinalIgnoreCase)) + { + scheme = Constants.HttpsScheme; // Always use a lower case scheme + isHttps = true; + } + else + { + throw new ArgumentOutOfRangeException("scheme", scheme, Resources.Exception_UnsupportedScheme); + } + + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentNullException("host"); + } + + int portValue; + if (string.IsNullOrWhiteSpace(port)) + { + port = isHttps ? "443" : "80"; + portValue = isHttps ? 443 : 80; + } + else + { + portValue = int.Parse(port, NumberStyles.None, CultureInfo.InvariantCulture); + } + + // Http.Sys requires the path end with a slash. + if (string.IsNullOrWhiteSpace(path)) + { + path = "/"; + } + else if (!path.EndsWith("/", StringComparison.Ordinal)) + { + path += "/"; + } + + return new Prefix(isHttps, scheme, host, port, portValue, path); + } + + public bool IsHttps { get; private set; } + public string Scheme { get; private set; } + public string Host { get; private set; } + public string Port { get; private set; } + public int PortValue { get; private set; } + public string Path { get; private set; } + public string Whole { get; private set; } + + public override bool Equals(object obj) + { + return string.Equals(Whole, obj as string, StringComparison.OrdinalIgnoreCase); + } + + public override int GetHashCode() + { + return StringComparer.OrdinalIgnoreCase.GetHashCode(Whole); + } + + public override string ToString() + { + return Whole; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Properties/AssemblyInfo.cs b/src/Microsoft.AspNet.Server.WebListener/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..ed361b4f1c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Properties/AssemblyInfo.cs @@ -0,0 +1,44 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.Server.WebListener")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.Server.WebListener")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("1f471909-581f-4060-a147-430891e9c3c1")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] +[assembly: CLSCompliant(true)] diff --git a/src/Microsoft.AspNet.Server.WebListener/PumpLimits.cs b/src/Microsoft.AspNet.Server.WebListener/PumpLimits.cs new file mode 100644 index 0000000000..28b99c11cc --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/PumpLimits.cs @@ -0,0 +1,31 @@ +// +// Copyright 2011-2012 Katana contributors +// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class PumpLimits + { + internal PumpLimits(int maxAccepts, int maxRequests) + { + MaxOutstandingAccepts = maxAccepts; + MaxOutstandingRequests = maxRequests; + } + + internal int MaxOutstandingAccepts { get; private set; } + + internal int MaxOutstandingRequests { get; private set; } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/BoundaryType.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/BoundaryType.cs new file mode 100644 index 0000000000..6d9ca0b4e2 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/BoundaryType.cs @@ -0,0 +1,18 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum BoundaryType + { + ContentLength = 0, // Content-Length: XXX + Chunked = 1, // Transfer-Encoding: chunked + Raw = 2, // the app is responsible for sending the correct headers and body encoding + Multipart = 3, + None = 4, + Invalid = 5, + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.cs new file mode 100644 index 0000000000..d0fe8d05ea --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.cs @@ -0,0 +1,1913 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Katana Contributors. All rights reserved. +// +//----------------------------------------------------------------------- +// + +using System; +using System.CodeDom.Compiler; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Security.Authentication.ExtendedProtection; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + using OpaqueUpgrade = Action, Func, Task>>; + + [GeneratedCode("TextTemplatingFileGenerator", "")] + internal partial class CallEnvironment + { + // Mark all fields with delay initialization support as set. + private UInt32 _flag0 = 0x43e00200u; + private UInt32 _flag1 = 0x1u; + // Mark all fields with delay initialization support as requiring initialization. + private UInt32 _initFlag0 = 0x43e00200u; + private UInt32 _initFlag1 = 0x1u; + + internal interface IPropertySource + { + Stream GetRequestBody(); + string GetRemoteIpAddress(); + string GetRemotePort(); + string GetLocalIpAddress(); + string GetLocalPort(); + bool GetIsLocal(); + bool TryGetChannelBinding(ref ChannelBinding value); + bool TryGetOpaqueUpgrade(ref OpaqueUpgrade value); + } + + private string _OwinVersion; + private CancellationToken _CallCancelled; + private string _RequestProtocol; + private string _RequestMethod; + private string _RequestScheme; + private string _RequestPathBase; + private string _RequestPath; + private string _RequestQueryString; + private IDictionary _RequestHeaders; + private Stream _RequestBody; + private IDictionary _ResponseHeaders; + private Stream _ResponseBody; + private int? _ResponseStatusCode; + private string _ResponseReasonPhrase; + private TextWriter _HostTraceOutput; + private string _HostAppName; + private string _HostAppMode; + private CancellationToken _OnAppDisposing; + private System.Security.Principal.IPrincipal _User; + private Action, object> _OnSendingHeaders; + private IDictionary _ServerCapabilities; + private string _RemoteIpAddress; + private string _RemotePort; + private string _LocalIpAddress; + private string _LocalPort; + private bool _IsLocal; + private object _ConnectionId; + private CancellationToken _ConnectionDisconnect; + private object _ClientCert; + private Func _LoadClientCert; + private ChannelBinding _ChannelBinding; + private Func _SendFileAsync; + private OpaqueUpgrade _OpaqueUpgrade; + private OwinWebListener _Listener; + + bool InitPropertyChannelBinding() + { + if (!_propertySource.TryGetChannelBinding(ref _ChannelBinding)) + { + _flag0 &= ~0x40000000u; + _initFlag0 &= ~0x40000000u; + return false; + } + _initFlag0 &= ~0x40000000u; + return true; + } + + bool InitPropertyOpaqueUpgrade() + { + if (!_propertySource.TryGetOpaqueUpgrade(ref _OpaqueUpgrade)) + { + _flag1 &= ~0x1u; + _initFlag1 &= ~0x1u; + return false; + } + _initFlag1 &= ~0x1u; + return true; + } + + internal string OwinVersion + { + get + { + return _OwinVersion; + } + set + { + _flag0 |= 0x1u; + _OwinVersion = value; + } + } + + internal CancellationToken CallCancelled + { + get + { + return _CallCancelled; + } + set + { + _flag0 |= 0x2u; + _CallCancelled = value; + } + } + + internal string RequestProtocol + { + get + { + return _RequestProtocol; + } + set + { + _flag0 |= 0x4u; + _RequestProtocol = value; + } + } + + internal string RequestMethod + { + get + { + return _RequestMethod; + } + set + { + _flag0 |= 0x8u; + _RequestMethod = value; + } + } + + internal string RequestScheme + { + get + { + return _RequestScheme; + } + set + { + _flag0 |= 0x10u; + _RequestScheme = value; + } + } + + internal string RequestPathBase + { + get + { + return _RequestPathBase; + } + set + { + _flag0 |= 0x20u; + _RequestPathBase = value; + } + } + + internal string RequestPath + { + get + { + return _RequestPath; + } + set + { + _flag0 |= 0x40u; + _RequestPath = value; + } + } + + internal string RequestQueryString + { + get + { + return _RequestQueryString; + } + set + { + _flag0 |= 0x80u; + _RequestQueryString = value; + } + } + + internal IDictionary RequestHeaders + { + get + { + return _RequestHeaders; + } + set + { + _flag0 |= 0x100u; + _RequestHeaders = value; + } + } + + internal Stream RequestBody + { + get + { + if (((_initFlag0 & 0x200u) != 0)) + { + _RequestBody = _propertySource.GetRequestBody(); + _initFlag0 &= ~0x200u; + } + return _RequestBody; + } + set + { + _initFlag0 &= ~0x200u; + _flag0 |= 0x200u; + _RequestBody = value; + } + } + + internal IDictionary ResponseHeaders + { + get + { + return _ResponseHeaders; + } + set + { + _flag0 |= 0x400u; + _ResponseHeaders = value; + } + } + + internal Stream ResponseBody + { + get + { + return _ResponseBody; + } + set + { + _flag0 |= 0x800u; + _ResponseBody = value; + } + } + + internal int? ResponseStatusCode + { + get + { + return _ResponseStatusCode; + } + set + { + _flag0 |= 0x1000u; + _ResponseStatusCode = value; + } + } + + internal string ResponseReasonPhrase + { + get + { + return _ResponseReasonPhrase; + } + set + { + _flag0 |= 0x2000u; + _ResponseReasonPhrase = value; + } + } + + internal TextWriter HostTraceOutput + { + get + { + return _HostTraceOutput; + } + set + { + _flag0 |= 0x4000u; + _HostTraceOutput = value; + } + } + + internal string HostAppName + { + get + { + return _HostAppName; + } + set + { + _flag0 |= 0x8000u; + _HostAppName = value; + } + } + + internal string HostAppMode + { + get + { + return _HostAppMode; + } + set + { + _flag0 |= 0x10000u; + _HostAppMode = value; + } + } + + internal CancellationToken OnAppDisposing + { + get + { + return _OnAppDisposing; + } + set + { + _flag0 |= 0x20000u; + _OnAppDisposing = value; + } + } + + internal System.Security.Principal.IPrincipal User + { + get + { + return _User; + } + set + { + _flag0 |= 0x40000u; + _User = value; + } + } + + internal Action, object> OnSendingHeaders + { + get + { + return _OnSendingHeaders; + } + set + { + _flag0 |= 0x80000u; + _OnSendingHeaders = value; + } + } + + internal IDictionary ServerCapabilities + { + get + { + return _ServerCapabilities; + } + set + { + _flag0 |= 0x100000u; + _ServerCapabilities = value; + } + } + + internal string RemoteIpAddress + { + get + { + if (((_initFlag0 & 0x200000u) != 0)) + { + _RemoteIpAddress = _propertySource.GetRemoteIpAddress(); + _initFlag0 &= ~0x200000u; + } + return _RemoteIpAddress; + } + set + { + _initFlag0 &= ~0x200000u; + _flag0 |= 0x200000u; + _RemoteIpAddress = value; + } + } + + internal string RemotePort + { + get + { + if (((_initFlag0 & 0x400000u) != 0)) + { + _RemotePort = _propertySource.GetRemotePort(); + _initFlag0 &= ~0x400000u; + } + return _RemotePort; + } + set + { + _initFlag0 &= ~0x400000u; + _flag0 |= 0x400000u; + _RemotePort = value; + } + } + + internal string LocalIpAddress + { + get + { + if (((_initFlag0 & 0x800000u) != 0)) + { + _LocalIpAddress = _propertySource.GetLocalIpAddress(); + _initFlag0 &= ~0x800000u; + } + return _LocalIpAddress; + } + set + { + _initFlag0 &= ~0x800000u; + _flag0 |= 0x800000u; + _LocalIpAddress = value; + } + } + + internal string LocalPort + { + get + { + if (((_initFlag0 & 0x1000000u) != 0)) + { + _LocalPort = _propertySource.GetLocalPort(); + _initFlag0 &= ~0x1000000u; + } + return _LocalPort; + } + set + { + _initFlag0 &= ~0x1000000u; + _flag0 |= 0x1000000u; + _LocalPort = value; + } + } + + internal bool IsLocal + { + get + { + if (((_initFlag0 & 0x2000000u) != 0)) + { + _IsLocal = _propertySource.GetIsLocal(); + _initFlag0 &= ~0x2000000u; + } + return _IsLocal; + } + set + { + _initFlag0 &= ~0x2000000u; + _flag0 |= 0x2000000u; + _IsLocal = value; + } + } + + internal object ConnectionId + { + get + { + return _ConnectionId; + } + set + { + _flag0 |= 0x4000000u; + _ConnectionId = value; + } + } + + internal CancellationToken ConnectionDisconnect + { + get + { + return _ConnectionDisconnect; + } + set + { + _flag0 |= 0x8000000u; + _ConnectionDisconnect = value; + } + } + + internal object ClientCert + { + get + { + return _ClientCert; + } + set + { + _flag0 |= 0x10000000u; + _ClientCert = value; + } + } + + internal Func LoadClientCert + { + get + { + return _LoadClientCert; + } + set + { + _flag0 |= 0x20000000u; + _LoadClientCert = value; + } + } + + internal ChannelBinding ChannelBinding + { + get + { + if (((_initFlag0 & 0x40000000u) != 0)) + { + InitPropertyChannelBinding(); + } + return _ChannelBinding; + } + set + { + _initFlag0 &= ~0x40000000u; + _flag0 |= 0x40000000u; + _ChannelBinding = value; + } + } + + internal Func SendFileAsync + { + get + { + return _SendFileAsync; + } + set + { + _flag0 |= 0x80000000u; + _SendFileAsync = value; + } + } + + internal OpaqueUpgrade OpaqueUpgrade + { + get + { + if (((_initFlag1 & 0x1u) != 0)) + { + InitPropertyOpaqueUpgrade(); + } + return _OpaqueUpgrade; + } + set + { + _initFlag1 &= ~0x1u; + _flag1 |= 0x1u; + _OpaqueUpgrade = value; + } + } + + internal OwinWebListener Listener + { + get + { + return _Listener; + } + set + { + _flag1 |= 0x2u; + _Listener = value; + } + } + + private bool PropertiesContainsKey(string key) + { + switch (key.Length) + { + case 11: + if (((_flag0 & 0x40000u) != 0) && string.Equals(key, "server.User", StringComparison.Ordinal)) + { + return true; + } + break; + case 12: + if (((_flag0 & 0x1u) != 0) && string.Equals(key, "owin.Version", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x8000u) != 0) && string.Equals(key, "host.AppName", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x10000u) != 0) && string.Equals(key, "host.AppMode", StringComparison.Ordinal)) + { + return true; + } + break; + case 14: + if (((_flag0 & 0x2000000u) != 0) && string.Equals(key, "server.IsLocal", StringComparison.Ordinal)) + { + return true; + } + if (((_flag1 & 0x1u) != 0) && string.Equals(key, "opaque.Upgrade", StringComparison.Ordinal)) + { + if (((_initFlag1 & 0x1u) == 0) || InitPropertyOpaqueUpgrade()) + { + return true; + } + } + break; + case 16: + if (((_flag0 & 0x40u) != 0) && string.Equals(key, "owin.RequestPath", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x200u) != 0) && string.Equals(key, "owin.RequestBody", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x4000u) != 0) && string.Equals(key, "host.TraceOutput", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x1000000u) != 0) && string.Equals(key, "server.LocalPort", StringComparison.Ordinal)) + { + return true; + } + break; + case 17: + if (((_flag0 & 0x800u) != 0) && string.Equals(key, "owin.ResponseBody", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x400000u) != 0) && string.Equals(key, "server.RemotePort", StringComparison.Ordinal)) + { + return true; + } + break; + case 18: + if (((_flag0 & 0x2u) != 0) && string.Equals(key, "owin.CallCancelled", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x8u) != 0) && string.Equals(key, "owin.RequestMethod", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x10u) != 0) && string.Equals(key, "owin.RequestScheme", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x40000000u) != 0) && string.Equals(key, "ssl.ChannelBinding", StringComparison.Ordinal)) + { + if (((_initFlag0 & 0x40000000u) == 0) || InitPropertyChannelBinding()) + { + return true; + } + } + if (((_flag0 & 0x80000000u) != 0) && string.Equals(key, "sendfile.SendAsync", StringComparison.Ordinal)) + { + return true; + } + break; + case 19: + if (((_flag0 & 0x100u) != 0) && string.Equals(key, "owin.RequestHeaders", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x20000u) != 0) && string.Equals(key, "host.OnAppDisposing", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x100000u) != 0) && string.Equals(key, "server.Capabilities", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x4000000u) != 0) && string.Equals(key, "server.ConnectionId", StringComparison.Ordinal)) + { + return true; + } + break; + case 20: + if (((_flag0 & 0x4u) != 0) && string.Equals(key, "owin.RequestProtocol", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x20u) != 0) && string.Equals(key, "owin.RequestPathBase", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x400u) != 0) && string.Equals(key, "owin.ResponseHeaders", StringComparison.Ordinal)) + { + return true; + } + break; + case 21: + if (((_flag0 & 0x800000u) != 0) && string.Equals(key, "server.LocalIpAddress", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x10000000u) != 0) && string.Equals(key, "ssl.ClientCertificate", StringComparison.Ordinal)) + { + return true; + } + break; + case 22: + if (((_flag0 & 0x200000u) != 0) && string.Equals(key, "server.RemoteIpAddress", StringComparison.Ordinal)) + { + return true; + } + break; + case 23: + if (((_flag0 & 0x80u) != 0) && string.Equals(key, "owin.RequestQueryString", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x1000u) != 0) && string.Equals(key, "owin.ResponseStatusCode", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x80000u) != 0) && string.Equals(key, "server.OnSendingHeaders", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x20000000u) != 0) && string.Equals(key, "ssl.LoadClientCertAsync", StringComparison.Ordinal)) + { + return true; + } + break; + case 25: + if (((_flag0 & 0x2000u) != 0) && string.Equals(key, "owin.ResponseReasonPhrase", StringComparison.Ordinal)) + { + return true; + } + break; + case 27: + if (((_flag0 & 0x8000000u) != 0) && string.Equals(key, "server.ConnectionDisconnect", StringComparison.Ordinal)) + { + return true; + } + break; + case 51: + if (((_flag1 & 0x2u) != 0) && string.Equals(key, "Microsoft.AspNet.Server.WebListener.OwinWebListener", StringComparison.Ordinal)) + { + return true; + } + break; + } + return false; + } + + private bool PropertiesTryGetValue(string key, out object value) + { + switch (key.Length) + { + case 11: + if (((_flag0 & 0x40000u) != 0) && string.Equals(key, "server.User", StringComparison.Ordinal)) + { + value = User; + return true; + } + break; + case 12: + if (((_flag0 & 0x1u) != 0) && string.Equals(key, "owin.Version", StringComparison.Ordinal)) + { + value = OwinVersion; + return true; + } + if (((_flag0 & 0x8000u) != 0) && string.Equals(key, "host.AppName", StringComparison.Ordinal)) + { + value = HostAppName; + return true; + } + if (((_flag0 & 0x10000u) != 0) && string.Equals(key, "host.AppMode", StringComparison.Ordinal)) + { + value = HostAppMode; + return true; + } + break; + case 14: + if (((_flag0 & 0x2000000u) != 0) && string.Equals(key, "server.IsLocal", StringComparison.Ordinal)) + { + value = IsLocal; + return true; + } + if (((_flag1 & 0x1u) != 0) && string.Equals(key, "opaque.Upgrade", StringComparison.Ordinal)) + { + value = OpaqueUpgrade; + // Delayed initialization in the property getter may determine that the element is not actually present + if (!((_flag1 & 0x1u) != 0)) + { + value = default(OpaqueUpgrade); + return false; + } + return true; + } + break; + case 16: + if (((_flag0 & 0x40u) != 0) && string.Equals(key, "owin.RequestPath", StringComparison.Ordinal)) + { + value = RequestPath; + return true; + } + if (((_flag0 & 0x200u) != 0) && string.Equals(key, "owin.RequestBody", StringComparison.Ordinal)) + { + value = RequestBody; + return true; + } + if (((_flag0 & 0x4000u) != 0) && string.Equals(key, "host.TraceOutput", StringComparison.Ordinal)) + { + value = HostTraceOutput; + return true; + } + if (((_flag0 & 0x1000000u) != 0) && string.Equals(key, "server.LocalPort", StringComparison.Ordinal)) + { + value = LocalPort; + return true; + } + break; + case 17: + if (((_flag0 & 0x800u) != 0) && string.Equals(key, "owin.ResponseBody", StringComparison.Ordinal)) + { + value = ResponseBody; + return true; + } + if (((_flag0 & 0x400000u) != 0) && string.Equals(key, "server.RemotePort", StringComparison.Ordinal)) + { + value = RemotePort; + return true; + } + break; + case 18: + if (((_flag0 & 0x2u) != 0) && string.Equals(key, "owin.CallCancelled", StringComparison.Ordinal)) + { + value = CallCancelled; + return true; + } + if (((_flag0 & 0x8u) != 0) && string.Equals(key, "owin.RequestMethod", StringComparison.Ordinal)) + { + value = RequestMethod; + return true; + } + if (((_flag0 & 0x10u) != 0) && string.Equals(key, "owin.RequestScheme", StringComparison.Ordinal)) + { + value = RequestScheme; + return true; + } + if (((_flag0 & 0x40000000u) != 0) && string.Equals(key, "ssl.ChannelBinding", StringComparison.Ordinal)) + { + value = ChannelBinding; + // Delayed initialization in the property getter may determine that the element is not actually present + if (!((_flag0 & 0x40000000u) != 0)) + { + value = default(ChannelBinding); + return false; + } + return true; + } + if (((_flag0 & 0x80000000u) != 0) && string.Equals(key, "sendfile.SendAsync", StringComparison.Ordinal)) + { + value = SendFileAsync; + return true; + } + break; + case 19: + if (((_flag0 & 0x100u) != 0) && string.Equals(key, "owin.RequestHeaders", StringComparison.Ordinal)) + { + value = RequestHeaders; + return true; + } + if (((_flag0 & 0x20000u) != 0) && string.Equals(key, "host.OnAppDisposing", StringComparison.Ordinal)) + { + value = OnAppDisposing; + return true; + } + if (((_flag0 & 0x100000u) != 0) && string.Equals(key, "server.Capabilities", StringComparison.Ordinal)) + { + value = ServerCapabilities; + return true; + } + if (((_flag0 & 0x4000000u) != 0) && string.Equals(key, "server.ConnectionId", StringComparison.Ordinal)) + { + value = ConnectionId; + return true; + } + break; + case 20: + if (((_flag0 & 0x4u) != 0) && string.Equals(key, "owin.RequestProtocol", StringComparison.Ordinal)) + { + value = RequestProtocol; + return true; + } + if (((_flag0 & 0x20u) != 0) && string.Equals(key, "owin.RequestPathBase", StringComparison.Ordinal)) + { + value = RequestPathBase; + return true; + } + if (((_flag0 & 0x400u) != 0) && string.Equals(key, "owin.ResponseHeaders", StringComparison.Ordinal)) + { + value = ResponseHeaders; + return true; + } + break; + case 21: + if (((_flag0 & 0x800000u) != 0) && string.Equals(key, "server.LocalIpAddress", StringComparison.Ordinal)) + { + value = LocalIpAddress; + return true; + } + if (((_flag0 & 0x10000000u) != 0) && string.Equals(key, "ssl.ClientCertificate", StringComparison.Ordinal)) + { + value = ClientCert; + return true; + } + break; + case 22: + if (((_flag0 & 0x200000u) != 0) && string.Equals(key, "server.RemoteIpAddress", StringComparison.Ordinal)) + { + value = RemoteIpAddress; + return true; + } + break; + case 23: + if (((_flag0 & 0x80u) != 0) && string.Equals(key, "owin.RequestQueryString", StringComparison.Ordinal)) + { + value = RequestQueryString; + return true; + } + if (((_flag0 & 0x1000u) != 0) && string.Equals(key, "owin.ResponseStatusCode", StringComparison.Ordinal)) + { + value = ResponseStatusCode; + return true; + } + if (((_flag0 & 0x80000u) != 0) && string.Equals(key, "server.OnSendingHeaders", StringComparison.Ordinal)) + { + value = OnSendingHeaders; + return true; + } + if (((_flag0 & 0x20000000u) != 0) && string.Equals(key, "ssl.LoadClientCertAsync", StringComparison.Ordinal)) + { + value = LoadClientCert; + return true; + } + break; + case 25: + if (((_flag0 & 0x2000u) != 0) && string.Equals(key, "owin.ResponseReasonPhrase", StringComparison.Ordinal)) + { + value = ResponseReasonPhrase; + return true; + } + break; + case 27: + if (((_flag0 & 0x8000000u) != 0) && string.Equals(key, "server.ConnectionDisconnect", StringComparison.Ordinal)) + { + value = ConnectionDisconnect; + return true; + } + break; + case 51: + if (((_flag1 & 0x2u) != 0) && string.Equals(key, "Microsoft.AspNet.Server.WebListener.OwinWebListener", StringComparison.Ordinal)) + { + value = Listener; + return true; + } + break; + } + value = null; + return false; + } + + private bool PropertiesTrySetValue(string key, object value) + { + switch (key.Length) + { + case 11: + if (string.Equals(key, "server.User", StringComparison.Ordinal)) + { + User = (System.Security.Principal.IPrincipal)value; + return true; + } + break; + case 12: + if (string.Equals(key, "owin.Version", StringComparison.Ordinal)) + { + OwinVersion = (string)value; + return true; + } + if (string.Equals(key, "host.AppName", StringComparison.Ordinal)) + { + HostAppName = (string)value; + return true; + } + if (string.Equals(key, "host.AppMode", StringComparison.Ordinal)) + { + HostAppMode = (string)value; + return true; + } + break; + case 14: + if (string.Equals(key, "server.IsLocal", StringComparison.Ordinal)) + { + IsLocal = (bool)value; + return true; + } + if (string.Equals(key, "opaque.Upgrade", StringComparison.Ordinal)) + { + OpaqueUpgrade = (OpaqueUpgrade)value; + return true; + } + break; + case 16: + if (string.Equals(key, "owin.RequestPath", StringComparison.Ordinal)) + { + RequestPath = (string)value; + return true; + } + if (string.Equals(key, "owin.RequestBody", StringComparison.Ordinal)) + { + RequestBody = (Stream)value; + return true; + } + if (string.Equals(key, "host.TraceOutput", StringComparison.Ordinal)) + { + HostTraceOutput = (TextWriter)value; + return true; + } + if (string.Equals(key, "server.LocalPort", StringComparison.Ordinal)) + { + LocalPort = (string)value; + return true; + } + break; + case 17: + if (string.Equals(key, "owin.ResponseBody", StringComparison.Ordinal)) + { + ResponseBody = (Stream)value; + return true; + } + if (string.Equals(key, "server.RemotePort", StringComparison.Ordinal)) + { + RemotePort = (string)value; + return true; + } + break; + case 18: + if (string.Equals(key, "owin.CallCancelled", StringComparison.Ordinal)) + { + CallCancelled = (CancellationToken)value; + return true; + } + if (string.Equals(key, "owin.RequestMethod", StringComparison.Ordinal)) + { + RequestMethod = (string)value; + return true; + } + if (string.Equals(key, "owin.RequestScheme", StringComparison.Ordinal)) + { + RequestScheme = (string)value; + return true; + } + if (string.Equals(key, "ssl.ChannelBinding", StringComparison.Ordinal)) + { + ChannelBinding = (ChannelBinding)value; + return true; + } + if (string.Equals(key, "sendfile.SendAsync", StringComparison.Ordinal)) + { + SendFileAsync = (Func)value; + return true; + } + break; + case 19: + if (string.Equals(key, "owin.RequestHeaders", StringComparison.Ordinal)) + { + RequestHeaders = (IDictionary)value; + return true; + } + if (string.Equals(key, "host.OnAppDisposing", StringComparison.Ordinal)) + { + OnAppDisposing = (CancellationToken)value; + return true; + } + if (string.Equals(key, "server.Capabilities", StringComparison.Ordinal)) + { + ServerCapabilities = (IDictionary)value; + return true; + } + if (string.Equals(key, "server.ConnectionId", StringComparison.Ordinal)) + { + ConnectionId = (object)value; + return true; + } + break; + case 20: + if (string.Equals(key, "owin.RequestProtocol", StringComparison.Ordinal)) + { + RequestProtocol = (string)value; + return true; + } + if (string.Equals(key, "owin.RequestPathBase", StringComparison.Ordinal)) + { + RequestPathBase = (string)value; + return true; + } + if (string.Equals(key, "owin.ResponseHeaders", StringComparison.Ordinal)) + { + ResponseHeaders = (IDictionary)value; + return true; + } + break; + case 21: + if (string.Equals(key, "server.LocalIpAddress", StringComparison.Ordinal)) + { + LocalIpAddress = (string)value; + return true; + } + if (string.Equals(key, "ssl.ClientCertificate", StringComparison.Ordinal)) + { + ClientCert = (object)value; + return true; + } + break; + case 22: + if (string.Equals(key, "server.RemoteIpAddress", StringComparison.Ordinal)) + { + RemoteIpAddress = (string)value; + return true; + } + break; + case 23: + if (string.Equals(key, "owin.RequestQueryString", StringComparison.Ordinal)) + { + RequestQueryString = (string)value; + return true; + } + if (string.Equals(key, "owin.ResponseStatusCode", StringComparison.Ordinal)) + { + ResponseStatusCode = (int?)value; + return true; + } + if (string.Equals(key, "server.OnSendingHeaders", StringComparison.Ordinal)) + { + OnSendingHeaders = (Action, object>)value; + return true; + } + if (string.Equals(key, "ssl.LoadClientCertAsync", StringComparison.Ordinal)) + { + LoadClientCert = (Func)value; + return true; + } + break; + case 25: + if (string.Equals(key, "owin.ResponseReasonPhrase", StringComparison.Ordinal)) + { + ResponseReasonPhrase = (string)value; + return true; + } + break; + case 27: + if (string.Equals(key, "server.ConnectionDisconnect", StringComparison.Ordinal)) + { + ConnectionDisconnect = (CancellationToken)value; + return true; + } + break; + case 51: + if (string.Equals(key, "Microsoft.AspNet.Server.WebListener.OwinWebListener", StringComparison.Ordinal)) + { + Listener = (OwinWebListener)value; + return true; + } + break; + } + return false; + } + + private bool PropertiesTryRemove(string key) + { + switch (key.Length) + { + case 11: + if (((_flag0 & 0x40000u) != 0) && string.Equals(key, "server.User", StringComparison.Ordinal)) + { + _flag0 &= ~0x40000u; + _User = default(System.Security.Principal.IPrincipal); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 12: + if (((_flag0 & 0x1u) != 0) && string.Equals(key, "owin.Version", StringComparison.Ordinal)) + { + _flag0 &= ~0x1u; + _OwinVersion = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x8000u) != 0) && string.Equals(key, "host.AppName", StringComparison.Ordinal)) + { + _flag0 &= ~0x8000u; + _HostAppName = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x10000u) != 0) && string.Equals(key, "host.AppMode", StringComparison.Ordinal)) + { + _flag0 &= ~0x10000u; + _HostAppMode = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 14: + if (((_flag0 & 0x2000000u) != 0) && string.Equals(key, "server.IsLocal", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x2000000u; + _flag0 &= ~0x2000000u; + _IsLocal = default(bool); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag1 & 0x1u) != 0) && string.Equals(key, "opaque.Upgrade", StringComparison.Ordinal)) + { + _initFlag1 &= ~0x1u; + _flag1 &= ~0x1u; + _OpaqueUpgrade = default(OpaqueUpgrade); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 16: + if (((_flag0 & 0x40u) != 0) && string.Equals(key, "owin.RequestPath", StringComparison.Ordinal)) + { + _flag0 &= ~0x40u; + _RequestPath = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x200u) != 0) && string.Equals(key, "owin.RequestBody", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x200u; + _flag0 &= ~0x200u; + _RequestBody = default(Stream); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x4000u) != 0) && string.Equals(key, "host.TraceOutput", StringComparison.Ordinal)) + { + _flag0 &= ~0x4000u; + _HostTraceOutput = default(TextWriter); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x1000000u) != 0) && string.Equals(key, "server.LocalPort", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x1000000u; + _flag0 &= ~0x1000000u; + _LocalPort = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 17: + if (((_flag0 & 0x800u) != 0) && string.Equals(key, "owin.ResponseBody", StringComparison.Ordinal)) + { + _flag0 &= ~0x800u; + _ResponseBody = default(Stream); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x400000u) != 0) && string.Equals(key, "server.RemotePort", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x400000u; + _flag0 &= ~0x400000u; + _RemotePort = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 18: + if (((_flag0 & 0x2u) != 0) && string.Equals(key, "owin.CallCancelled", StringComparison.Ordinal)) + { + _flag0 &= ~0x2u; + _CallCancelled = default(CancellationToken); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x8u) != 0) && string.Equals(key, "owin.RequestMethod", StringComparison.Ordinal)) + { + _flag0 &= ~0x8u; + _RequestMethod = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x10u) != 0) && string.Equals(key, "owin.RequestScheme", StringComparison.Ordinal)) + { + _flag0 &= ~0x10u; + _RequestScheme = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x40000000u) != 0) && string.Equals(key, "ssl.ChannelBinding", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x40000000u; + _flag0 &= ~0x40000000u; + _ChannelBinding = default(ChannelBinding); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x80000000u) != 0) && string.Equals(key, "sendfile.SendAsync", StringComparison.Ordinal)) + { + _flag0 &= ~0x80000000u; + _SendFileAsync = default(Func); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 19: + if (((_flag0 & 0x100u) != 0) && string.Equals(key, "owin.RequestHeaders", StringComparison.Ordinal)) + { + _flag0 &= ~0x100u; + _RequestHeaders = default(IDictionary); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x20000u) != 0) && string.Equals(key, "host.OnAppDisposing", StringComparison.Ordinal)) + { + _flag0 &= ~0x20000u; + _OnAppDisposing = default(CancellationToken); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x100000u) != 0) && string.Equals(key, "server.Capabilities", StringComparison.Ordinal)) + { + _flag0 &= ~0x100000u; + _ServerCapabilities = default(IDictionary); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x4000000u) != 0) && string.Equals(key, "server.ConnectionId", StringComparison.Ordinal)) + { + _flag0 &= ~0x4000000u; + _ConnectionId = default(object); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 20: + if (((_flag0 & 0x4u) != 0) && string.Equals(key, "owin.RequestProtocol", StringComparison.Ordinal)) + { + _flag0 &= ~0x4u; + _RequestProtocol = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x20u) != 0) && string.Equals(key, "owin.RequestPathBase", StringComparison.Ordinal)) + { + _flag0 &= ~0x20u; + _RequestPathBase = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x400u) != 0) && string.Equals(key, "owin.ResponseHeaders", StringComparison.Ordinal)) + { + _flag0 &= ~0x400u; + _ResponseHeaders = default(IDictionary); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 21: + if (((_flag0 & 0x800000u) != 0) && string.Equals(key, "server.LocalIpAddress", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x800000u; + _flag0 &= ~0x800000u; + _LocalIpAddress = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x10000000u) != 0) && string.Equals(key, "ssl.ClientCertificate", StringComparison.Ordinal)) + { + _flag0 &= ~0x10000000u; + _ClientCert = default(object); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 22: + if (((_flag0 & 0x200000u) != 0) && string.Equals(key, "server.RemoteIpAddress", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x200000u; + _flag0 &= ~0x200000u; + _RemoteIpAddress = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 23: + if (((_flag0 & 0x80u) != 0) && string.Equals(key, "owin.RequestQueryString", StringComparison.Ordinal)) + { + _flag0 &= ~0x80u; + _RequestQueryString = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x1000u) != 0) && string.Equals(key, "owin.ResponseStatusCode", StringComparison.Ordinal)) + { + _flag0 &= ~0x1000u; + _ResponseStatusCode = default(int?); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x80000u) != 0) && string.Equals(key, "server.OnSendingHeaders", StringComparison.Ordinal)) + { + _flag0 &= ~0x80000u; + _OnSendingHeaders = default(Action, object>); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x20000000u) != 0) && string.Equals(key, "ssl.LoadClientCertAsync", StringComparison.Ordinal)) + { + _flag0 &= ~0x20000000u; + _LoadClientCert = default(Func); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 25: + if (((_flag0 & 0x2000u) != 0) && string.Equals(key, "owin.ResponseReasonPhrase", StringComparison.Ordinal)) + { + _flag0 &= ~0x2000u; + _ResponseReasonPhrase = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 27: + if (((_flag0 & 0x8000000u) != 0) && string.Equals(key, "server.ConnectionDisconnect", StringComparison.Ordinal)) + { + _flag0 &= ~0x8000000u; + _ConnectionDisconnect = default(CancellationToken); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 51: + if (((_flag1 & 0x2u) != 0) && string.Equals(key, "Microsoft.AspNet.Server.WebListener.OwinWebListener", StringComparison.Ordinal)) + { + _flag1 &= ~0x2u; + _Listener = default(OwinWebListener); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + } + return false; + } + + private IEnumerable PropertiesKeys() + { + if (((_flag0 & 0x1u) != 0)) + { + yield return "owin.Version"; + } + if (((_flag0 & 0x2u) != 0)) + { + yield return "owin.CallCancelled"; + } + if (((_flag0 & 0x4u) != 0)) + { + yield return "owin.RequestProtocol"; + } + if (((_flag0 & 0x8u) != 0)) + { + yield return "owin.RequestMethod"; + } + if (((_flag0 & 0x10u) != 0)) + { + yield return "owin.RequestScheme"; + } + if (((_flag0 & 0x20u) != 0)) + { + yield return "owin.RequestPathBase"; + } + if (((_flag0 & 0x40u) != 0)) + { + yield return "owin.RequestPath"; + } + if (((_flag0 & 0x80u) != 0)) + { + yield return "owin.RequestQueryString"; + } + if (((_flag0 & 0x100u) != 0)) + { + yield return "owin.RequestHeaders"; + } + if (((_flag0 & 0x200u) != 0)) + { + yield return "owin.RequestBody"; + } + if (((_flag0 & 0x400u) != 0)) + { + yield return "owin.ResponseHeaders"; + } + if (((_flag0 & 0x800u) != 0)) + { + yield return "owin.ResponseBody"; + } + if (((_flag0 & 0x1000u) != 0)) + { + yield return "owin.ResponseStatusCode"; + } + if (((_flag0 & 0x2000u) != 0)) + { + yield return "owin.ResponseReasonPhrase"; + } + if (((_flag0 & 0x4000u) != 0)) + { + yield return "host.TraceOutput"; + } + if (((_flag0 & 0x8000u) != 0)) + { + yield return "host.AppName"; + } + if (((_flag0 & 0x10000u) != 0)) + { + yield return "host.AppMode"; + } + if (((_flag0 & 0x20000u) != 0)) + { + yield return "host.OnAppDisposing"; + } + if (((_flag0 & 0x40000u) != 0)) + { + yield return "server.User"; + } + if (((_flag0 & 0x80000u) != 0)) + { + yield return "server.OnSendingHeaders"; + } + if (((_flag0 & 0x100000u) != 0)) + { + yield return "server.Capabilities"; + } + if (((_flag0 & 0x200000u) != 0)) + { + yield return "server.RemoteIpAddress"; + } + if (((_flag0 & 0x400000u) != 0)) + { + yield return "server.RemotePort"; + } + if (((_flag0 & 0x800000u) != 0)) + { + yield return "server.LocalIpAddress"; + } + if (((_flag0 & 0x1000000u) != 0)) + { + yield return "server.LocalPort"; + } + if (((_flag0 & 0x2000000u) != 0)) + { + yield return "server.IsLocal"; + } + if (((_flag0 & 0x4000000u) != 0)) + { + yield return "server.ConnectionId"; + } + if (((_flag0 & 0x8000000u) != 0)) + { + yield return "server.ConnectionDisconnect"; + } + if (((_flag0 & 0x10000000u) != 0)) + { + yield return "ssl.ClientCertificate"; + } + if (((_flag0 & 0x20000000u) != 0)) + { + yield return "ssl.LoadClientCertAsync"; + } + if (((_flag0 & 0x40000000u) != 0)) + { + if (((_initFlag0 & 0x40000000u) == 0) || InitPropertyChannelBinding()) + { + yield return "ssl.ChannelBinding"; + } + } + if (((_flag0 & 0x80000000u) != 0)) + { + yield return "sendfile.SendAsync"; + } + if (((_flag1 & 0x1u) != 0)) + { + if (((_initFlag1 & 0x1u) == 0) || InitPropertyOpaqueUpgrade()) + { + yield return "opaque.Upgrade"; + } + } + if (((_flag1 & 0x2u) != 0)) + { + yield return "Microsoft.AspNet.Server.WebListener.OwinWebListener"; + } + } + + private IEnumerable PropertiesValues() + { + if (((_flag0 & 0x1u) != 0)) + { + yield return OwinVersion; + } + if (((_flag0 & 0x2u) != 0)) + { + yield return CallCancelled; + } + if (((_flag0 & 0x4u) != 0)) + { + yield return RequestProtocol; + } + if (((_flag0 & 0x8u) != 0)) + { + yield return RequestMethod; + } + if (((_flag0 & 0x10u) != 0)) + { + yield return RequestScheme; + } + if (((_flag0 & 0x20u) != 0)) + { + yield return RequestPathBase; + } + if (((_flag0 & 0x40u) != 0)) + { + yield return RequestPath; + } + if (((_flag0 & 0x80u) != 0)) + { + yield return RequestQueryString; + } + if (((_flag0 & 0x100u) != 0)) + { + yield return RequestHeaders; + } + if (((_flag0 & 0x200u) != 0)) + { + yield return RequestBody; + } + if (((_flag0 & 0x400u) != 0)) + { + yield return ResponseHeaders; + } + if (((_flag0 & 0x800u) != 0)) + { + yield return ResponseBody; + } + if (((_flag0 & 0x1000u) != 0)) + { + yield return ResponseStatusCode; + } + if (((_flag0 & 0x2000u) != 0)) + { + yield return ResponseReasonPhrase; + } + if (((_flag0 & 0x4000u) != 0)) + { + yield return HostTraceOutput; + } + if (((_flag0 & 0x8000u) != 0)) + { + yield return HostAppName; + } + if (((_flag0 & 0x10000u) != 0)) + { + yield return HostAppMode; + } + if (((_flag0 & 0x20000u) != 0)) + { + yield return OnAppDisposing; + } + if (((_flag0 & 0x40000u) != 0)) + { + yield return User; + } + if (((_flag0 & 0x80000u) != 0)) + { + yield return OnSendingHeaders; + } + if (((_flag0 & 0x100000u) != 0)) + { + yield return ServerCapabilities; + } + if (((_flag0 & 0x200000u) != 0)) + { + yield return RemoteIpAddress; + } + if (((_flag0 & 0x400000u) != 0)) + { + yield return RemotePort; + } + if (((_flag0 & 0x800000u) != 0)) + { + yield return LocalIpAddress; + } + if (((_flag0 & 0x1000000u) != 0)) + { + yield return LocalPort; + } + if (((_flag0 & 0x2000000u) != 0)) + { + yield return IsLocal; + } + if (((_flag0 & 0x4000000u) != 0)) + { + yield return ConnectionId; + } + if (((_flag0 & 0x8000000u) != 0)) + { + yield return ConnectionDisconnect; + } + if (((_flag0 & 0x10000000u) != 0)) + { + yield return ClientCert; + } + if (((_flag0 & 0x20000000u) != 0)) + { + yield return LoadClientCert; + } + if (((_flag0 & 0x40000000u) != 0)) + { + if (((_initFlag0 & 0x40000000u) == 0) || InitPropertyChannelBinding()) + { + yield return ChannelBinding; + } + } + if (((_flag0 & 0x80000000u) != 0)) + { + yield return SendFileAsync; + } + if (((_flag1 & 0x1u) != 0)) + { + if (((_initFlag1 & 0x1u) == 0) || InitPropertyOpaqueUpgrade()) + { + yield return OpaqueUpgrade; + } + } + if (((_flag1 & 0x2u) != 0)) + { + yield return Listener; + } + } + + private IEnumerable> PropertiesEnumerable() + { + if (((_flag0 & 0x1u) != 0)) + { + yield return new KeyValuePair("owin.Version", OwinVersion); + } + if (((_flag0 & 0x2u) != 0)) + { + yield return new KeyValuePair("owin.CallCancelled", CallCancelled); + } + if (((_flag0 & 0x4u) != 0)) + { + yield return new KeyValuePair("owin.RequestProtocol", RequestProtocol); + } + if (((_flag0 & 0x8u) != 0)) + { + yield return new KeyValuePair("owin.RequestMethod", RequestMethod); + } + if (((_flag0 & 0x10u) != 0)) + { + yield return new KeyValuePair("owin.RequestScheme", RequestScheme); + } + if (((_flag0 & 0x20u) != 0)) + { + yield return new KeyValuePair("owin.RequestPathBase", RequestPathBase); + } + if (((_flag0 & 0x40u) != 0)) + { + yield return new KeyValuePair("owin.RequestPath", RequestPath); + } + if (((_flag0 & 0x80u) != 0)) + { + yield return new KeyValuePair("owin.RequestQueryString", RequestQueryString); + } + if (((_flag0 & 0x100u) != 0)) + { + yield return new KeyValuePair("owin.RequestHeaders", RequestHeaders); + } + if (((_flag0 & 0x200u) != 0)) + { + yield return new KeyValuePair("owin.RequestBody", RequestBody); + } + if (((_flag0 & 0x400u) != 0)) + { + yield return new KeyValuePair("owin.ResponseHeaders", ResponseHeaders); + } + if (((_flag0 & 0x800u) != 0)) + { + yield return new KeyValuePair("owin.ResponseBody", ResponseBody); + } + if (((_flag0 & 0x1000u) != 0)) + { + yield return new KeyValuePair("owin.ResponseStatusCode", ResponseStatusCode); + } + if (((_flag0 & 0x2000u) != 0)) + { + yield return new KeyValuePair("owin.ResponseReasonPhrase", ResponseReasonPhrase); + } + if (((_flag0 & 0x4000u) != 0)) + { + yield return new KeyValuePair("host.TraceOutput", HostTraceOutput); + } + if (((_flag0 & 0x8000u) != 0)) + { + yield return new KeyValuePair("host.AppName", HostAppName); + } + if (((_flag0 & 0x10000u) != 0)) + { + yield return new KeyValuePair("host.AppMode", HostAppMode); + } + if (((_flag0 & 0x20000u) != 0)) + { + yield return new KeyValuePair("host.OnAppDisposing", OnAppDisposing); + } + if (((_flag0 & 0x40000u) != 0)) + { + yield return new KeyValuePair("server.User", User); + } + if (((_flag0 & 0x80000u) != 0)) + { + yield return new KeyValuePair("server.OnSendingHeaders", OnSendingHeaders); + } + if (((_flag0 & 0x100000u) != 0)) + { + yield return new KeyValuePair("server.Capabilities", ServerCapabilities); + } + if (((_flag0 & 0x200000u) != 0)) + { + yield return new KeyValuePair("server.RemoteIpAddress", RemoteIpAddress); + } + if (((_flag0 & 0x400000u) != 0)) + { + yield return new KeyValuePair("server.RemotePort", RemotePort); + } + if (((_flag0 & 0x800000u) != 0)) + { + yield return new KeyValuePair("server.LocalIpAddress", LocalIpAddress); + } + if (((_flag0 & 0x1000000u) != 0)) + { + yield return new KeyValuePair("server.LocalPort", LocalPort); + } + if (((_flag0 & 0x2000000u) != 0)) + { + yield return new KeyValuePair("server.IsLocal", IsLocal); + } + if (((_flag0 & 0x4000000u) != 0)) + { + yield return new KeyValuePair("server.ConnectionId", ConnectionId); + } + if (((_flag0 & 0x8000000u) != 0)) + { + yield return new KeyValuePair("server.ConnectionDisconnect", ConnectionDisconnect); + } + if (((_flag0 & 0x10000000u) != 0)) + { + yield return new KeyValuePair("ssl.ClientCertificate", ClientCert); + } + if (((_flag0 & 0x20000000u) != 0)) + { + yield return new KeyValuePair("ssl.LoadClientCertAsync", LoadClientCert); + } + if (((_flag0 & 0x40000000u) != 0)) + { + if (((_initFlag0 & 0x40000000u) == 0) || InitPropertyChannelBinding()) + { + yield return new KeyValuePair("ssl.ChannelBinding", ChannelBinding); + } + } + if (((_flag0 & 0x80000000u) != 0)) + { + yield return new KeyValuePair("sendfile.SendAsync", SendFileAsync); + } + if (((_flag1 & 0x1u) != 0)) + { + if (((_initFlag1 & 0x1u) == 0) || InitPropertyOpaqueUpgrade()) + { + yield return new KeyValuePair("opaque.Upgrade", OpaqueUpgrade); + } + } + if (((_flag1 & 0x2u) != 0)) + { + yield return new KeyValuePair("Microsoft.AspNet.Server.WebListener.OwinWebListener", Listener); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.tt b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.tt new file mode 100644 index 0000000000..41a1b3db06 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.tt @@ -0,0 +1,316 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core.dll" #> +<#@ import namespace="System.Linq" #> +<# +var Init = new {Yes = new object(), No = new object(), Maybe = new object()}; + +var props = new[] +{ +// owin standard keys + new {Key="owin.Version", Type="string", Name="OwinVersion", Init=Init.No}, + new {Key="owin.CallCancelled", Type="CancellationToken", Name="CallCancelled", Init=Init.No}, + + new {Key="owin.RequestProtocol", Type="string", Name="RequestProtocol", Init=Init.No}, + new {Key="owin.RequestMethod", Type="string", Name="RequestMethod", Init=Init.No}, + new {Key="owin.RequestScheme", Type="string", Name="RequestScheme", Init=Init.No}, + new {Key="owin.RequestPathBase", Type="string", Name="RequestPathBase", Init=Init.No}, + new {Key="owin.RequestPath", Type="string", Name="RequestPath", Init=Init.No}, + new {Key="owin.RequestQueryString", Type="string", Name="RequestQueryString", Init=Init.No}, + new {Key="owin.RequestHeaders", Type="IDictionary", Name="RequestHeaders", Init=Init.No}, + new {Key="owin.RequestBody", Type="Stream", Name="RequestBody", Init=Init.Yes}, + + new {Key="owin.ResponseHeaders", Type="IDictionary", Name="ResponseHeaders", Init=Init.No}, + new {Key="owin.ResponseBody", Type="Stream", Name="ResponseBody", Init=Init.No}, + new {Key="owin.ResponseStatusCode", Type="int?", Name="ResponseStatusCode", Init=Init.No}, + new {Key="owin.ResponseReasonPhrase", Type="string", Name="ResponseReasonPhrase", Init=Init.No}, + +// defacto host keys + new {Key="host.TraceOutput", Type="TextWriter", Name="HostTraceOutput", Init=Init.No}, + new {Key="host.AppName", Type="string", Name="HostAppName", Init=Init.No}, + new {Key="host.AppMode", Type="string", Name="HostAppMode", Init=Init.No}, + new {Key="host.OnAppDisposing", Type="CancellationToken", Name="OnAppDisposing", Init=Init.No}, + new {Key="server.User", Type="System.Security.Principal.IPrincipal", Name="User", Init=Init.No}, + new {Key="server.OnSendingHeaders", Type="Action, object>", Name="OnSendingHeaders", Init=Init.No}, + new {Key="server.Capabilities", Type="IDictionary", Name="ServerCapabilities", Init=Init.No}, + +// ServerVariable keys + new {Key="server.RemoteIpAddress", Type="string", Name="RemoteIpAddress", Init=Init.Yes}, + new {Key="server.RemotePort", Type="string", Name="RemotePort", Init=Init.Yes}, + new {Key="server.LocalIpAddress", Type="string", Name="LocalIpAddress", Init=Init.Yes}, + new {Key="server.LocalPort", Type="string", Name="LocalPort", Init=Init.Yes}, + new {Key="server.IsLocal", Type="bool", Name="IsLocal", Init=Init.Yes}, + new {Key="server.ConnectionId", Type="object", Name="ConnectionId", Init=Init.No}, + new {Key="server.ConnectionDisconnect", Type="CancellationToken", Name="ConnectionDisconnect", Init=Init.No}, + +// SSL + new { Key="ssl.ClientCertificate", Type="object", Name="ClientCert", Init=Init.No}, + new { Key="ssl.LoadClientCertAsync", Type="Func", Name="LoadClientCert", Init=Init.No }, + new { Key="ssl.ChannelBinding", Type="ChannelBinding", Name="ChannelBinding", Init=Init.Maybe }, + +// SendFile keys + new {Key="sendfile.SendAsync", Type="Func", Name="SendFileAsync", Init=Init.No}, + +// Opaque keys + new {Key="opaque.Upgrade", Type="OpaqueUpgrade", Name="OpaqueUpgrade", Init=Init.Maybe}, + +// Server specific keys + new { Key="Microsoft.AspNet.Server.WebListener.OwinWebListener", Type="OwinWebListener", Name="Listener", Init=Init.No}, +}.Select((prop, Index)=>new {prop.Key, prop.Type, prop.Name, prop.Init, Index}); + +var lengths = props.OrderBy(prop=>prop.Key.Length).GroupBy(prop=>prop.Key.Length); + +Func IsSet = Index => "((_flag" + (Index / 32) + " & 0x" + (1<<(Index % 32)).ToString("x") + "u) != 0)"; +Func Set = Index => "_flag" + (Index / 32) + " |= 0x" + (1<<(Index % 32)).ToString("x") + "u"; +Func Clear = Index => "_flag" + (Index / 32) + " &= ~0x" + (1<<(Index % 32)).ToString("x") + "u"; + +Func IsInitRequired = Index => "((_initFlag" + (Index / 32) + " & 0x" + (1<<(Index % 32)).ToString("x") + "u) != 0)"; +Func IsInitCompleted = Index => "((_initFlag" + (Index / 32) + " & 0x" + (1<<(Index % 32)).ToString("x") + "u) == 0)"; +Func CompleteInit = Index => "_initFlag" + (Index / 32) + " &= ~0x" + (1<<(Index % 32)).ToString("x") + "u"; + +#> +//----------------------------------------------------------------------- +// +// Copyright (c) Katana Contributors. All rights reserved. +// +//----------------------------------------------------------------------- +// + +using System; +using System.CodeDom.Compiler; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Security.Authentication.ExtendedProtection; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Owin.Host.WebListener +{ + using OpaqueUpgrade = Action, Func, Task>>; + + [GeneratedCode("TextTemplatingFileGenerator", "")] + internal partial class CallEnvironment + { + // Mark all fields with delay initialization support as set. + private UInt32 _flag0 = 0x<#=props.Aggregate(0, (agg,p) => agg | (((p.Init != Init.No) && (p.Index/32==0) ? 1:0)<u; + private UInt32 _flag1 = 0x<#=props.Aggregate(0, (agg,p) => agg | (((p.Init != Init.No) && (p.Index/32==1) ? 1:0)<u; + // Mark all fields with delay initialization support as requiring initialization. + private UInt32 _initFlag0 = 0x<#=props.Aggregate(0, (agg,p) => agg | (((p.Init != Init.No) && (p.Index/32==0) ? 1:0)<u; + private UInt32 _initFlag1 = 0x<#=props.Aggregate(0, (agg,p) => agg | (((p.Init != Init.No) && (p.Index/32==1) ? 1:0)<u; + + internal interface IPropertySource + { +<# foreach(var prop in props) { #> +<# if (prop.Init == Init.Yes) { #> + <#=prop.Type#> Get<#=prop.Name#>(); +<# } #> +<# if (prop.Init == Init.Maybe) { #> + bool TryGet<#=prop.Name#>(ref <#=prop.Type#> value); +<# } #> +<# } #> + } + +<# foreach(var prop in props) { #> + private <#=prop.Type#> _<#=prop.Name#>; +<# } #> + +<# foreach(var prop in props) { #> +<# // call TryGet once if init flag is set, clear value flag if TryGet returns false +if (prop.Init == Init.Maybe) { #> + bool InitProperty<#=prop.Name#>() + { + if (!_propertySource.TryGet<#=prop.Name#>(ref _<#=prop.Name#>)) + { + <#=Clear(prop.Index)#>; + <#=CompleteInit(prop.Index)#>; + return false; + } + <#=CompleteInit(prop.Index)#>; + return true; + } + +<# } #> +<# } #> +<# foreach(var prop in props) { #> + internal <#=prop.Type#> <#=prop.Name#> + { + get + { +<# // call Get once if init flag is set +if (prop.Init == Init.Yes) { #> + if (<#=IsInitRequired(prop.Index)#>) + { + _<#=prop.Name#> = _propertySource.Get<#=prop.Name#>(); + <#=CompleteInit(prop.Index)#>; + } +<# } #> +<# // call TryGet once if init flag is set, clear value flag if TryGet returns false +if (prop.Init == Init.Maybe) { #> + if (<#=IsInitRequired(prop.Index)#>) + { + InitProperty<#=prop.Name#>(); + } +<# } #> + return _<#=prop.Name#>; + } + set + { +<# // clear init flag - the assigned value is definitive +if (prop.Init != Init.No) { #> + <#=CompleteInit(prop.Index)#>; +<# } #> + <#=Set(prop.Index)#>; + _<#=prop.Name#> = value; + } + } + +<# } #> + private bool PropertiesContainsKey(string key) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (<#=IsSet(prop.Index)#> && string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { +<# // variable maybe init might revert +if (prop.Init == Init.Maybe) { #> + if (<#=IsInitCompleted(prop.Index)#> || InitProperty<#=prop.Name#>()) + { + return true; + } +<# } else { #> + return true; +<# } #> + } +<# } #> + break; +<# } #> + } + return false; + } + + private bool PropertiesTryGetValue(string key, out object value) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (<#=IsSet(prop.Index)#> && string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { + value = <#=prop.Name#>; +<# if (prop.Init == Init.Maybe) { #> + // Delayed initialization in the property getter may determine that the element is not actually present + if (!<#=IsSet(prop.Index)#>) + { + value = default(<#=prop.Type#>); + return false; + } +<# } #> + return true; + } +<# } #> + break; +<# } #> + } + value = null; + return false; + } + + private bool PropertiesTrySetValue(string key, object value) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { + <#=prop.Name#> = (<#=prop.Type#>)value; + return true; + } +<# } #> + break; +<# } #> + } + return false; + } + + private bool PropertiesTryRemove(string key) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (<#=IsSet(prop.Index)#> && string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { +<# if (prop.Init != Init.No) { #> + <#=CompleteInit(prop.Index)#>; +<# } #> + <#=Clear(prop.Index)#>; + _<#=prop.Name#> = default(<#=prop.Type#>); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } +<# } #> + break; +<# } #> + } + return false; + } + + private IEnumerable PropertiesKeys() + { +<# foreach(var prop in props) { #> + if (<#=IsSet(prop.Index)#>) + { +<# if (prop.Init == Init.Maybe) { #> + if (<#=IsInitCompleted(prop.Index)#> || InitProperty<#=prop.Name#>()) + { + yield return "<#=prop.Key#>"; + } +<# } else { #> + yield return "<#=prop.Key#>"; +<# } #> + } +<# } #> + } + + private IEnumerable PropertiesValues() + { +<# foreach(var prop in props) { #> + if (<#=IsSet(prop.Index)#>) + { +<# if (prop.Init == Init.Maybe) { #> + if (<#=IsInitCompleted(prop.Index)#> || InitProperty<#=prop.Name#>()) + { + yield return <#=prop.Name#>; + } +<# } else { #> + yield return <#=prop.Name#>; +<# } #> + } +<# } #> + } + + private IEnumerable> PropertiesEnumerable() + { +<# foreach(var prop in props) { #> + if (<#=IsSet(prop.Index)#>) + { +<# if (prop.Init == Init.Maybe) { #> + if (<#=IsInitCompleted(prop.Index)#> || InitProperty<#=prop.Name#>()) + { + yield return new KeyValuePair("<#=prop.Key#>", <#=prop.Name#>); + } +<# } else { #> + yield return new KeyValuePair("<#=prop.Key#>", <#=prop.Name#>); +<# } #> + } +<# } #> + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.cs new file mode 100644 index 0000000000..360a2ea245 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.cs @@ -0,0 +1,165 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- +// Copyright 2011-2012 Katana contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal partial class CallEnvironment : IDictionary + { + private static readonly IDictionary WeakNilEnvironment = new NilEnvDictionary(); + + private readonly IPropertySource _propertySource; + private IDictionary _extra = WeakNilEnvironment; + + internal CallEnvironment(IPropertySource propertySource) + { + _propertySource = propertySource; + } + + private IDictionary Extra + { + get { return _extra; } + } + + private IDictionary StrongExtra + { + get + { + if (_extra == WeakNilEnvironment) + { + Interlocked.CompareExchange(ref _extra, new Dictionary(), WeakNilEnvironment); + } + return _extra; + } + } + + internal bool IsExtraDictionaryCreated + { + get { return _extra != WeakNilEnvironment; } + } + + public object this[string key] + { + get + { + object value; + return PropertiesTryGetValue(key, out value) ? value : Extra[key]; + } + set + { + if (!PropertiesTrySetValue(key, value)) + { + StrongExtra[key] = value; + } + } + } + + public void Add(string key, object value) + { + if (!PropertiesTrySetValue(key, value)) + { + StrongExtra.Add(key, value); + } + } + + public bool ContainsKey(string key) + { + return PropertiesContainsKey(key) || Extra.ContainsKey(key); + } + + public ICollection Keys + { + get { return PropertiesKeys().Concat(Extra.Keys).ToArray(); } + } + + public bool Remove(string key) + { + // Although this is a mutating operation, Extra is used instead of StrongExtra, + // because if a real dictionary has not been allocated the default behavior of the + // nil dictionary is perfectly fine. + return PropertiesTryRemove(key) || Extra.Remove(key); + } + + public bool TryGetValue(string key, out object value) + { + return PropertiesTryGetValue(key, out value) || Extra.TryGetValue(key, out value); + } + + public ICollection Values + { + get { return PropertiesValues().Concat(Extra.Values).ToArray(); } + } + + public void Add(KeyValuePair item) + { + ((IDictionary)this).Add(item.Key, item.Value); + } + + public void Clear() + { + foreach (var key in PropertiesKeys()) + { + PropertiesTryRemove(key); + } + Extra.Clear(); + } + + public bool Contains(KeyValuePair item) + { + object value; + return ((IDictionary)this).TryGetValue(item.Key, out value) && Object.Equals(value, item.Value); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + PropertiesEnumerable().Concat(Extra).ToArray().CopyTo(array, arrayIndex); + } + + public int Count + { + get { return PropertiesKeys().Count() + Extra.Count; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public bool Remove(KeyValuePair item) + { + return ((IDictionary)this).Contains(item) && + ((IDictionary)this).Remove(item.Key); + } + + public IEnumerator> GetEnumerator() + { + return PropertiesEnumerable().Concat(Extra).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IDictionary)this).GetEnumerator(); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ClientCertLoader.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ClientCertLoader.cs new file mode 100644 index 0000000000..3cd17db788 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ClientCertLoader.cs @@ -0,0 +1,347 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +#if NET45 + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Runtime.InteropServices; +using System.Security; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + // This class is used to load the client certificate on-demand. Because client certs are optional, all + // failures are handled internally and reported via ClientCertException or ClientCertError. + internal unsafe sealed class ClientCertLoader : IAsyncResult, IDisposable + { + private const uint CertBoblSize = 1500; + private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(WaitCallback); + + private SafeNativeOverlapped _overlapped; + private byte[] _backingBuffer; + private UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* _memoryBlob; + private uint _size; + private TaskCompletionSource _tcs; + private RequestContext _requestContext; + + private int _clientCertError; + private X509Certificate2 _clientCert; + private Exception _clientCertException; + + internal ClientCertLoader(RequestContext requestContext) + { + _requestContext = requestContext; + _tcs = new TaskCompletionSource(); + // we will use this overlapped structure to issue async IO to ul + // the event handle will be put in by the BeginHttpApi2.ERROR_SUCCESS() method + Reset(CertBoblSize); + } + + internal X509Certificate2 ClientCert + { + get + { + Contract.Assert(Task.IsCompleted); + return _clientCert; + } + } + + internal int ClientCertError + { + get + { + Contract.Assert(Task.IsCompleted); + return _clientCertError; + } + } + + internal Exception ClientCertException + { + get + { + Contract.Assert(Task.IsCompleted); + return _clientCertException; + } + } + + private RequestContext RequestContext + { + get + { + return _requestContext; + } + } + + private Task Task + { + get + { + return _tcs.Task; + } + } + + private SafeNativeOverlapped NativeOverlapped + { + get + { + return _overlapped; + } + } + + private UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* RequestBlob + { + get + { + return _memoryBlob; + } + } + + private void Reset(uint size) + { + if (size == _size) + { + return; + } + if (_size != 0) + { + _overlapped.Dispose(); + } + _size = size; + if (size == 0) + { + _overlapped = null; + _memoryBlob = null; + _backingBuffer = null; + return; + } + _backingBuffer = new byte[checked((int)size)]; + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = this; + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, _backingBuffer)); + _memoryBlob = (UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO*)Marshal.UnsafeAddrOfPinnedArrayElement(_backingBuffer, 0); + } + + // When you use netsh to configure HTTP.SYS with clientcertnegotiation = enable + // which means negotiate client certificates, when the client makes the + // initial SSL connection, the server (HTTP.SYS) requests the client certificate. + // + // Some apps may not want to negotiate the client cert at the beginning, + // perhaps serving the default.htm. In this case the HTTP.SYS is configured + // with clientcertnegotiation = disabled, which means that the client certificate is + // optional so initially when SSL is established HTTP.SYS won't ask for client + // certificate. This works fine for the default.htm in the case above, + // however, if the app wants to demand a client certificate at a later time + // perhaps showing "YOUR ORDERS" page, then the server wants to negotiate + // Client certs. This will in turn makes HTTP.SYS to do the + // SEC_I_RENOGOTIATE through which the client cert demand is made + // + // NOTE: When calling HttpReceiveClientCertificate you can get + // ERROR_NOT_FOUND - which means the client did not provide the cert + // If this is important, the server should respond with 403 forbidden + // HTTP.SYS will not do this for you automatically + internal Task LoadClientCertificateAsync() + { + uint size = CertBoblSize; + bool retry; + do + { + retry = false; + uint bytesReceived = 0; + + uint statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveClientCertificate( + RequestContext.RequestQueueHandle, + RequestContext.Request.ConnectionId, + (uint)UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE, + RequestBlob, + size, + &bytesReceived, + NativeOverlapped); + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = RequestBlob; + size = bytesReceived + pClientCertInfo->CertEncodedSize; + Reset(size); + retry = true; + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_NOT_FOUND) + { + // The client did not send a cert. + Complete(0, null); + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + IOCompleted(statusCode, bytesReceived); + } + else if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + // Some other bad error, possible(?) return values are: + // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED + // Also ERROR_BAD_DATA if we got it twice or it reported smaller size buffer required. + Fail(new WebListenerException((int)statusCode)); + } + } + while (retry); + + return Task; + } + + private void Complete(int certErrors, X509Certificate2 cert) + { + // May be null + _clientCert = cert; + _clientCertError = certErrors; + _tcs.TrySetResult(null); + Dispose(); + } + + private void Fail(Exception ex) + { + // TODO: Log + _clientCertException = ex; + _tcs.TrySetResult(null); + } + + private unsafe void IOCompleted(uint errorCode, uint numBytes) + { + IOCompleted(this, errorCode, numBytes); + } + + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirected to callback")] + private static unsafe void IOCompleted(ClientCertLoader asyncResult, uint errorCode, uint numBytes) + { + RequestContext requestContext = asyncResult.RequestContext; + try + { + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + // There is a bug that has existed in http.sys since w2k3. Bytesreceived will only + // return the size of the initial cert structure. To get the full size, + // we need to add the certificate encoding size as well. + + UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = asyncResult.RequestBlob; + asyncResult.Reset(numBytes + pClientCertInfo->CertEncodedSize); + + uint bytesReceived = 0; + errorCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveClientCertificate( + requestContext.RequestQueueHandle, + requestContext.Request.ConnectionId, + (uint)UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE, + asyncResult._memoryBlob, + asyncResult._size, + &bytesReceived, + asyncResult._overlapped); + + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING || + (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && !OwinWebListener.SkipIOCPCallbackOnSuccess)) + { + return; + } + } + + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_NOT_FOUND) + { + // The client did not send a cert. + asyncResult.Complete(0, null); + } + else if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + asyncResult.Fail(new WebListenerException((int)errorCode)); + } + else + { + UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = asyncResult._memoryBlob; + if (pClientCertInfo == null) + { + asyncResult.Complete(0, null); + } + else + { + if (pClientCertInfo->pCertEncoded != null) + { + try + { + byte[] certEncoded = new byte[pClientCertInfo->CertEncodedSize]; + Marshal.Copy((IntPtr)pClientCertInfo->pCertEncoded, certEncoded, 0, certEncoded.Length); + asyncResult.Complete((int)pClientCertInfo->CertFlags, new X509Certificate2(certEncoded)); + } + catch (CryptographicException exception) + { + // TODO: Log + asyncResult.Fail(exception); + } + catch (SecurityException exception) + { + // TODO: Log + asyncResult.Fail(exception); + } + } + } + } + } + catch (Exception exception) + { + asyncResult.Fail(exception); + } + } + + private static unsafe void WaitCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + Overlapped callbackOverlapped = Overlapped.Unpack(nativeOverlapped); + ClientCertLoader asyncResult = (ClientCertLoader)callbackOverlapped.AsyncResult; + + IOCompleted(asyncResult, errorCode, numBytes); + } + + public void Dispose() + { + Dispose(true); + } + + private void Dispose(bool disposing) + { + if (disposing) + { + if (_overlapped != null) + { + _memoryBlob = null; + _overlapped.Dispose(); + } + } + } + + public object AsyncState + { + get { return _tcs.Task.AsyncState; } + } + + public WaitHandle AsyncWaitHandle + { + get { return ((IAsyncResult)_tcs.Task).AsyncWaitHandle; } + } + + public bool CompletedSynchronously + { + get { return ((IAsyncResult)_tcs.Task).CompletedSynchronously; } + } + + public bool IsCompleted + { + get { return _tcs.Task.IsCompleted; } + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/EntitySendFormat.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/EntitySendFormat.cs new file mode 100644 index 0000000000..f6861f9bd6 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/EntitySendFormat.cs @@ -0,0 +1,17 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum EntitySendFormat + { + ContentLength = 0, // Content-Length: XXX + Chunked = 1, // Transfer-Encoding: chunked + /* + Raw = 2, // the app is responsible for sending the correct headers and body encoding + */ + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HeaderEncoding.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HeaderEncoding.cs new file mode 100644 index 0000000000..24a8df80ff --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HeaderEncoding.cs @@ -0,0 +1,97 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Microsoft.AspNet.Server.WebListener +{ + // we use this static class as a helper class to encode/decode HTTP headers. + // what we need is a 1-1 correspondence between a char in the range U+0000-U+00FF + // and a byte in the range 0x00-0xFF (which is the range that can hit the network). + // The Latin-1 encoding (ISO-88591-1) (GetEncoding(28591)) works for byte[] to string, but is a little slow. + // It doesn't work for string -> byte[] because of best-fit-mapping problems. + internal static class HeaderEncoding + { + internal static unsafe string GetString(byte[] bytes, int byteIndex, int byteCount) + { + fixed (byte* pBytes = bytes) + return GetString(pBytes + byteIndex, byteCount); + } + + internal static unsafe string GetString(sbyte* pBytes, int byteCount) + { + return GetString((byte*)pBytes, byteCount); + } + + internal static unsafe string GetString(byte* pBytes, int byteCount) + { + if (byteCount < 1) + { + return string.Empty; + } + + string s = new String('\0', byteCount); + + fixed (char* pStr = s) + { + char* pString = pStr; + while (byteCount >= 8) + { + pString[0] = (char)pBytes[0]; + pString[1] = (char)pBytes[1]; + pString[2] = (char)pBytes[2]; + pString[3] = (char)pBytes[3]; + pString[4] = (char)pBytes[4]; + pString[5] = (char)pBytes[5]; + pString[6] = (char)pBytes[6]; + pString[7] = (char)pBytes[7]; + pString += 8; + pBytes += 8; + byteCount -= 8; + } + for (int i = 0; i < byteCount; i++) + { + pString[i] = (char)pBytes[i]; + } + } + + return s; + } + + internal static int GetByteCount(string myString) + { + return myString.Length; + } + internal static unsafe void GetBytes(string myString, int charIndex, int charCount, byte[] bytes, int byteIndex) + { + if (myString.Length == 0) + { + return; + } + fixed (byte* bufferPointer = bytes) + { + byte* newBufferPointer = bufferPointer + byteIndex; + int finalIndex = charIndex + charCount; + while (charIndex < finalIndex) + { + *newBufferPointer++ = (byte)myString[charIndex++]; + } + } + } + internal static unsafe byte[] GetBytes(string myString) + { + byte[] bytes = new byte[myString.Length]; + if (myString.Length != 0) + { + GetBytes(myString, 0, myString.Length, bytes, 0); + } + return bytes; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpKnownHeaderNames.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpKnownHeaderNames.cs new file mode 100644 index 0000000000..9885de2fcd --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpKnownHeaderNames.cs @@ -0,0 +1,75 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class HttpKnownHeaderNames + { + internal const string CacheControl = "Cache-Control"; + internal const string Connection = "Connection"; + internal const string Date = "Date"; + internal const string KeepAlive = "Keep-Alive"; + internal const string Pragma = "Pragma"; + internal const string ProxyConnection = "Proxy-Connection"; + internal const string Trailer = "Trailer"; + internal const string TransferEncoding = "Transfer-Encoding"; + internal const string Upgrade = "Upgrade"; + internal const string Via = "Via"; + internal const string Warning = "Warning"; + internal const string ContentLength = "Content-Length"; + internal const string ContentType = "Content-Type"; + internal const string ContentDisposition = "Content-Disposition"; + internal const string ContentEncoding = "Content-Encoding"; + internal const string ContentLanguage = "Content-Language"; + internal const string ContentLocation = "Content-Location"; + internal const string ContentRange = "Content-Range"; + internal const string Expires = "Expires"; + internal const string LastModified = "Last-Modified"; + internal const string Age = "Age"; + internal const string Location = "Location"; + internal const string ProxyAuthenticate = "Proxy-Authenticate"; + internal const string RetryAfter = "Retry-After"; + internal const string Server = "Server"; + internal const string SetCookie = "Set-Cookie"; + internal const string SetCookie2 = "Set-Cookie2"; + internal const string Vary = "Vary"; + internal const string WWWAuthenticate = "WWW-Authenticate"; + internal const string Accept = "Accept"; + internal const string AcceptCharset = "Accept-Charset"; + internal const string AcceptEncoding = "Accept-Encoding"; + internal const string AcceptLanguage = "Accept-Language"; + internal const string Authorization = "Authorization"; + internal const string Cookie = "Cookie"; + internal const string Cookie2 = "Cookie2"; + internal const string Expect = "Expect"; + internal const string From = "From"; + internal const string Host = "Host"; + internal const string IfMatch = "If-Match"; + internal const string IfModifiedSince = "If-Modified-Since"; + internal const string IfNoneMatch = "If-None-Match"; + internal const string IfRange = "If-Range"; + internal const string IfUnmodifiedSince = "If-Unmodified-Since"; + internal const string MaxForwards = "Max-Forwards"; + internal const string ProxyAuthorization = "Proxy-Authorization"; + internal const string Referer = "Referer"; + internal const string Range = "Range"; + internal const string UserAgent = "User-Agent"; + internal const string ContentMD5 = "Content-MD5"; + internal const string ETag = "ETag"; + internal const string TE = "TE"; + internal const string Allow = "Allow"; + internal const string AcceptRanges = "Accept-Ranges"; + internal const string P3P = "P3P"; + internal const string XPoweredBy = "X-Powered-By"; + internal const string XAspNetVersion = "X-AspNet-Version"; + internal const string SecWebSocketKey = "Sec-WebSocket-Key"; + internal const string SecWebSocketExtensions = "Sec-WebSocket-Extensions"; + internal const string SecWebSocketAccept = "Sec-WebSocket-Accept"; + internal const string Origin = "Origin"; + internal const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; + internal const string SecWebSocketVersion = "Sec-WebSocket-Version"; + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpReasonPhrase.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpReasonPhrase.cs new file mode 100644 index 0000000000..27370832ff --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpReasonPhrase.cs @@ -0,0 +1,109 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class HttpReasonPhrase + { + private static readonly string[][] HttpReasonPhrases = new string[][] + { + null, + + new string[] + { + /* 100 */ "Continue", + /* 101 */ "Switching Protocols", + /* 102 */ "Processing" + }, + + new string[] + { + /* 200 */ "OK", + /* 201 */ "Created", + /* 202 */ "Accepted", + /* 203 */ "Non-Authoritative Information", + /* 204 */ "No Content", + /* 205 */ "Reset Content", + /* 206 */ "Partial Content", + /* 207 */ "Multi-Status" + }, + + new string[] + { + /* 300 */ "Multiple Choices", + /* 301 */ "Moved Permanently", + /* 302 */ "Found", + /* 303 */ "See Other", + /* 304 */ "Not Modified", + /* 305 */ "Use Proxy", + /* 306 */ null, + /* 307 */ "Temporary Redirect" + }, + + new string[] + { + /* 400 */ "Bad Request", + /* 401 */ "Unauthorized", + /* 402 */ "Payment Required", + /* 403 */ "Forbidden", + /* 404 */ "Not Found", + /* 405 */ "Method Not Allowed", + /* 406 */ "Not Acceptable", + /* 407 */ "Proxy Authentication Required", + /* 408 */ "Request Timeout", + /* 409 */ "Conflict", + /* 410 */ "Gone", + /* 411 */ "Length Required", + /* 412 */ "Precondition Failed", + /* 413 */ "Request Entity Too Large", + /* 414 */ "Request-Uri Too Long", + /* 415 */ "Unsupported Media Type", + /* 416 */ "Requested Range Not Satisfiable", + /* 417 */ "Expectation Failed", + /* 418 */ null, + /* 419 */ null, + /* 420 */ null, + /* 421 */ null, + /* 422 */ "Unprocessable Entity", + /* 423 */ "Locked", + /* 424 */ "Failed Dependency", + /* 425 */ null, + /* 426 */ "Upgrade Required", // RFC 2817 + }, + + new string[] + { + /* 500 */ "Internal Server Error", + /* 501 */ "Not Implemented", + /* 502 */ "Bad Gateway", + /* 503 */ "Service Unavailable", + /* 504 */ "Gateway Timeout", + /* 505 */ "Http Version Not Supported", + /* 506 */ null, + /* 507 */ "Insufficient Storage" + } + }; + + internal static string Get(HttpStatusCode code) + { + return Get((int)code); + } + + internal static string Get(int code) + { + if (code >= 100 && code < 600) + { + int i = code / 100; + int j = code % 100; + if (j < HttpReasonPhrases[i].Length) + { + return HttpReasonPhrases[i][j]; + } + } + return null; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpStatusCode.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpStatusCode.cs new file mode 100644 index 0000000000..88e6fe4b18 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpStatusCode.cs @@ -0,0 +1,314 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + // Redirect Status code numbers that need to be defined. + + /// + /// Contains the values of status + /// codes defined for the HTTP protocol. + /// + // UEUE : Any int can be cast to a HttpStatusCode to allow checking for non http1.1 codes. + internal enum HttpStatusCode + { + // Informational 1xx + + /// + /// [To be supplied.] + /// + Continue = 100, + + /// + /// [To be supplied.] + /// + SwitchingProtocols = 101, + + // Successful 2xx + + /// + /// [To be supplied.] + /// + OK = 200, + + /// + /// [To be supplied.] + /// + Created = 201, + + /// + /// [To be supplied.] + /// + Accepted = 202, + + /// + /// [To be supplied.] + /// + NonAuthoritativeInformation = 203, + + /// + /// [To be supplied.] + /// + NoContent = 204, + + /// + /// [To be supplied.] + /// + ResetContent = 205, + + /// + /// [To be supplied.] + /// + PartialContent = 206, + + // Redirection 3xx + + /// + /// [To be supplied.] + /// + MultipleChoices = 300, + + /// + /// [To be supplied.] + /// + Ambiguous = 300, + + /// + /// [To be supplied.] + /// + MovedPermanently = 301, + + /// + /// [To be supplied.] + /// + Moved = 301, + + /// + /// [To be supplied.] + /// + Found = 302, + + /// + /// [To be supplied.] + /// + Redirect = 302, + + /// + /// [To be supplied.] + /// + SeeOther = 303, + + /// + /// [To be supplied.] + /// + RedirectMethod = 303, + + /// + /// [To be supplied.] + /// + NotModified = 304, + + /// + /// [To be supplied.] + /// + UseProxy = 305, + + /// + /// [To be supplied.] + /// + Unused = 306, + + /// + /// [To be supplied.] + /// + TemporaryRedirect = 307, + + /// + /// [To be supplied.] + /// + RedirectKeepVerb = 307, + + // Client Error 4xx + + /// + /// [To be supplied.] + /// + BadRequest = 400, + + /// + /// [To be supplied.] + /// + Unauthorized = 401, + + /// + /// [To be supplied.] + /// + PaymentRequired = 402, + + /// + /// [To be supplied.] + /// + Forbidden = 403, + + /// + /// [To be supplied.] + /// + NotFound = 404, + + /// + /// [To be supplied.] + /// + MethodNotAllowed = 405, + + /// + /// [To be supplied.] + /// + NotAcceptable = 406, + + /// + /// [To be supplied.] + /// + ProxyAuthenticationRequired = 407, + + /// + /// [To be supplied.] + /// + RequestTimeout = 408, + + /// + /// [To be supplied.] + /// + Conflict = 409, + + /// + /// [To be supplied.] + /// + Gone = 410, + + /// + /// [To be supplied.] + /// + LengthRequired = 411, + + /// + /// [To be supplied.] + /// + PreconditionFailed = 412, + + /// + /// [To be supplied.] + /// + RequestEntityTooLarge = 413, + + /// + /// [To be supplied.] + /// + RequestUriTooLong = 414, + + /// + /// [To be supplied.] + /// + UnsupportedMediaType = 415, + + /// + /// [To be supplied.] + /// + RequestedRangeNotSatisfiable = 416, + + /// + /// [To be supplied.] + /// + ExpectationFailed = 417, + + UpgradeRequired = 426, + + // Server Error 5xx + + /// + /// [To be supplied.] + /// + InternalServerError = 500, + + /// + /// [To be supplied.] + /// + NotImplemented = 501, + + /// + /// [To be supplied.] + /// + BadGateway = 502, + + /// + /// [To be supplied.] + /// + ServiceUnavailable = 503, + + /// + /// [To be supplied.] + /// + GatewayTimeout = 504, + + /// + /// [To be supplied.] + /// + HttpVersionNotSupported = 505, + } // enum HttpStatusCode + + /* + Fielding, et al. Standards Track [Page 3] + + RFC 2616 HTTP/1.1 June 1999 + + + 10.1 Informational 1xx ...........................................57 + 10.1.1 100 Continue .............................................58 + 10.1.2 101 Switching Protocols ..................................58 + 10.2 Successful 2xx ..............................................58 + 10.2.1 200 OK ...................................................58 + 10.2.2 201 Created ..............................................59 + 10.2.3 202 Accepted .............................................59 + 10.2.4 203 Non-Authoritative Information ........................59 + 10.2.5 204 No Content ...........................................60 + 10.2.6 205 Reset Content ........................................60 + 10.2.7 206 Partial Content ......................................60 + 10.3 Redirection 3xx .............................................61 + 10.3.1 300 Multiple Choices .....................................61 + 10.3.2 301 Moved Permanently ....................................62 + 10.3.3 302 Found ................................................62 + 10.3.4 303 See Other ............................................63 + 10.3.5 304 Not Modified .........................................63 + 10.3.6 305 Use Proxy ............................................64 + 10.3.7 306 (Unused) .............................................64 + 10.3.8 307 Temporary Redirect ...................................65 + 10.4 Client Error 4xx ............................................65 + 10.4.1 400 Bad Request .........................................65 + 10.4.2 401 Unauthorized ........................................66 + 10.4.3 402 Payment Required ....................................66 + 10.4.4 403 Forbidden ...........................................66 + 10.4.5 404 Not Found ...........................................66 + 10.4.6 405 Method Not Allowed ..................................66 + 10.4.7 406 Not Acceptable ......................................67 + 10.4.8 407 Proxy Authentication Required .......................67 + 10.4.9 408 Request Timeout .....................................67 + 10.4.10 409 Conflict ............................................67 + 10.4.11 410 Gone ................................................68 + 10.4.12 411 Length Required .....................................68 + 10.4.13 412 Precondition Failed .................................68 + 10.4.14 413 Request Entity Too Large ............................69 + 10.4.15 414 Request-URI Too Long ................................69 + 10.4.16 415 Unsupported Media Type ..............................69 + 10.4.17 416 Requested Range Not Satisfiable .....................69 + 10.4.18 417 Expectation Failed ..................................70 + 10.5 Server Error 5xx ............................................70 + 10.5.1 500 Internal Server Error ................................70 + 10.5.2 501 Not Implemented ......................................70 + 10.5.3 502 Bad Gateway ..........................................70 + 10.5.4 503 Service Unavailable ..................................70 + 10.5.5 504 Gateway Timeout ......................................71 + 10.5.6 505 HTTP Version Not Supported ...........................71 + */ +} // namespace System.Net diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NativeRequestContext.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NativeRequestContext.cs new file mode 100644 index 0000000000..8ab4f029cb --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NativeRequestContext.cs @@ -0,0 +1,170 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal unsafe class NativeRequestContext : IDisposable + { + private const int DefaultBufferSize = 4096; + private UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* _memoryBlob; + private IntPtr _originalBlobAddress; + private byte[] _backingBuffer; + private SafeNativeOverlapped _nativeOverlapped; + private AsyncAcceptContext _acceptResult; + + internal NativeRequestContext(AsyncAcceptContext result) + { + _acceptResult = result; + UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* requestBlob = Allocate(0); + if (requestBlob == null) + { + GC.SuppressFinalize(this); + } + else + { + _memoryBlob = requestBlob; + } + } + + internal SafeNativeOverlapped NativeOverlapped + { + get + { + return _nativeOverlapped; + } + } + + internal UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* RequestBlob + { + get + { + Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestBlob requested after ReleasePins()."); + return _memoryBlob; + } + } + + internal byte[] RequestBuffer + { + get + { + return _backingBuffer; + } + } + + internal uint Size + { + get + { + return (uint)_backingBuffer.Length; + } + } + + internal IntPtr OriginalBlobAddress + { + get + { + UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* blob = _memoryBlob; + return (blob == null ? _originalBlobAddress : (IntPtr)blob); + } + } + + // ReleasePins() should be called exactly once. It must be called before Dispose() is called, which means it must be called + // before an object (Request) which closes the RequestContext on demand is returned to the application. + internal void ReleasePins() + { + Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestContextBase::ReleasePins()|ReleasePins() called twice."); + _originalBlobAddress = (IntPtr)_memoryBlob; + UnsetBlob(); + OnReleasePins(); + } + + private void OnReleasePins() + { + if (_nativeOverlapped != null) + { + SafeNativeOverlapped nativeOverlapped = _nativeOverlapped; + _nativeOverlapped = null; + nativeOverlapped.Dispose(); + } + } + + public void Dispose() + { + Debug.Assert(_memoryBlob == null, "RequestContextBase::Dispose()|Dispose() called before ReleasePins()."); + Dispose(true); + } + + protected void Dispose(bool disposing) + { + if (_nativeOverlapped != null) + { + Debug.Assert(!disposing, "AsyncRequestContext::Dispose()|Must call ReleasePins() before calling Dispose()."); + _nativeOverlapped.Dispose(); + } + } + + private void SetBlob(UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* requestBlob) + { + Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestContextBase::Dispose()|SetBlob() called after ReleasePins()."); + if (requestBlob == null) + { + UnsetBlob(); + return; + } + + if (_memoryBlob == null) + { + GC.ReRegisterForFinalize(this); + } + _memoryBlob = requestBlob; + } + + private void UnsetBlob() + { + if (_memoryBlob != null) + { + GC.SuppressFinalize(this); + } + _memoryBlob = null; + } + + private void SetBuffer(int size) + { + _backingBuffer = size == 0 ? null : new byte[size]; + } + + private UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* Allocate(uint size) + { + uint newSize = size != 0 ? size : RequestBuffer == null ? DefaultBufferSize : Size; + if (_nativeOverlapped != null && newSize != RequestBuffer.Length) + { + SafeNativeOverlapped nativeOverlapped = _nativeOverlapped; + _nativeOverlapped = null; + nativeOverlapped.Dispose(); + } + if (_nativeOverlapped == null) + { + SetBuffer(checked((int)newSize)); + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = _acceptResult; + _nativeOverlapped = new SafeNativeOverlapped(overlapped.Pack(AsyncAcceptContext.IOCallback, RequestBuffer)); + return (UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST*)Marshal.UnsafeAddrOfPinnedArrayElement(RequestBuffer, 0); + } + return RequestBlob; + } + + internal void Reset(ulong requestId, uint size) + { + SetBlob(Allocate(size)); + RequestBlob->RequestId = requestId; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NilEnvDictionary.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NilEnvDictionary.cs new file mode 100644 index 0000000000..af6cc6aff5 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NilEnvDictionary.cs @@ -0,0 +1,115 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- +// Copyright 2011-2012 Katana contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class NilEnvDictionary : IDictionary + { + private static readonly string[] EmptyKeys = new string[0]; + private static readonly object[] EmptyValues = new object[0]; + private static readonly IEnumerable> EmptyKeyValuePairs = Enumerable.Empty>(); + + public int Count + { + get { return 0; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public ICollection Keys + { + get { return EmptyKeys; } + } + + public ICollection Values + { + get { return EmptyValues; } + } + + [SuppressMessage("Microsoft.Design", "CA1065:DoNotRaiseExceptionsInUnexpectedLocations", Justification = "Not Implemented")] + public object this[string key] + { + get { throw new NotImplementedException(); } + set { throw new NotImplementedException(); } + } + + public IEnumerator> GetEnumerator() + { + return EmptyKeyValuePairs.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return EmptyKeyValuePairs.GetEnumerator(); + } + + public void Add(KeyValuePair item) + { + throw new NotImplementedException(); + } + + public void Clear() + { + } + + public bool Contains(KeyValuePair item) + { + return false; + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + } + + public bool Remove(KeyValuePair item) + { + return false; + } + + public bool ContainsKey(string key) + { + return false; + } + + public void Add(string key, object value) + { + throw new NotImplementedException(); + } + + public bool Remove(string key) + { + return false; + } + + public bool TryGetValue(string key, out object value) + { + value = null; + return false; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/OpaqueStream.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/OpaqueStream.cs new file mode 100644 index 0000000000..6238d9185d --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/OpaqueStream.cs @@ -0,0 +1,168 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + // A duplex wrapper around RequestStream and ResponseStream. + // TODO: Consider merging RequestStream and ResponseStream instead. + internal class OpaqueStream : Stream + { + private readonly Stream _requestStream; + private readonly Stream _responseStream; + + internal OpaqueStream(Stream requestStream, Stream responseStream) + { + _requestStream = requestStream; + _responseStream = responseStream; + } + +#region Properties + + public override bool CanRead + { + get { return _requestStream.CanRead; } + } + + public override bool CanSeek + { + get { return false; } + } + + public override bool CanTimeout + { + get { return _requestStream.CanTimeout || _responseStream.CanTimeout; } + } + + public override bool CanWrite + { + get { return _responseStream.CanWrite; } + } + + public override long Length + { + get { throw new NotSupportedException(Resources.Exception_NoSeek); } + } + + public override long Position + { + get { throw new NotSupportedException(Resources.Exception_NoSeek); } + set { throw new NotSupportedException(Resources.Exception_NoSeek); } + } + + public override int ReadTimeout + { + get { return _requestStream.ReadTimeout; } + set { _requestStream.ReadTimeout = value; } + } + + public override int WriteTimeout + { + get { return _responseStream.WriteTimeout; } + set { _responseStream.WriteTimeout = value; } + } + +#endregion Properties + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + +#region Read + + public override int Read(byte[] buffer, int offset, int count) + { + return _requestStream.Read(buffer, offset, count); + } + + public override int ReadByte() + { + return _requestStream.ReadByte(); + } +#if NET45 + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _requestStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _requestStream.EndRead(asyncResult); + } +#endif + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _requestStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return _requestStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + +#endregion Read + +#region Write + + public override void Write(byte[] buffer, int offset, int count) + { + _responseStream.Write(buffer, offset, count); + } + + public override void WriteByte(byte value) + { + _responseStream.WriteByte(value); + } +#if NET45 + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _responseStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _responseStream.EndWrite(asyncResult); + } +#endif + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _responseStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override void Flush() + { + _responseStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _responseStream.FlushAsync(cancellationToken); + } + +#endregion Write + + protected override void Dispose(bool disposing) + { + // TODO: Suppress dispose? + if (disposing) + { + _requestStream.Dispose(); + _responseStream.Dispose(); + } + base.Dispose(disposing); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Request.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Request.cs new file mode 100644 index 0000000000..86b25185fb --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Request.cs @@ -0,0 +1,456 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Net; +using System.Runtime.InteropServices; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Principal; +using System.Text; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed unsafe class Request : IDisposable + { + private RequestContext _requestContext; + private NativeRequestContext _nativeRequestContext; + + private ulong _requestId; + private ulong _connectionId; + private ulong _contextId; + + private SslStatus _sslStatus; + + private string _httpMethod; + private Version _httpVersion; + + private Uri _requestUri; + private string _rawUrl; + private string _cookedUrlHost; + private string _cookedUrlPath; + private string _cookedUrlQuery; + + private RequestHeaders _headers; + private BoundaryType _contentBoundaryType; + private long _contentLength; + private Stream _requestStream; + private SocketAddress _localEndPoint; + private SocketAddress _remoteEndPoint; + + private IPrincipal _user; + + private bool _isDisposed = false; + private CancellationTokenRegistration _disconnectRegistration; + + internal Request(RequestContext httpContext, NativeRequestContext memoryBlob) + { + // TODO: Verbose log + _requestContext = httpContext; + _nativeRequestContext = memoryBlob; + _contentBoundaryType = BoundaryType.None; + + // Set up some of these now to avoid refcounting on memory blob later. + _requestId = memoryBlob.RequestBlob->RequestId; + _connectionId = memoryBlob.RequestBlob->ConnectionId; + _contextId = memoryBlob.RequestBlob->UrlContext; + _sslStatus = memoryBlob.RequestBlob->pSslInfo == null ? SslStatus.Insecure : + memoryBlob.RequestBlob->pSslInfo->SslClientCertNegotiated == 0 ? SslStatus.NoClientCert : + SslStatus.ClientCert; + if (memoryBlob.RequestBlob->pRawUrl != null && memoryBlob.RequestBlob->RawUrlLength > 0) + { + _rawUrl = Marshal.PtrToStringAnsi((IntPtr)memoryBlob.RequestBlob->pRawUrl, memoryBlob.RequestBlob->RawUrlLength); + } + + UnsafeNclNativeMethods.HttpApi.HTTP_COOKED_URL cookedUrl = memoryBlob.RequestBlob->CookedUrl; + if (cookedUrl.pHost != null && cookedUrl.HostLength > 0) + { + _cookedUrlHost = Marshal.PtrToStringUni((IntPtr)cookedUrl.pHost, cookedUrl.HostLength / 2); + } + if (cookedUrl.pAbsPath != null && cookedUrl.AbsPathLength > 0) + { + _cookedUrlPath = Marshal.PtrToStringUni((IntPtr)cookedUrl.pAbsPath, cookedUrl.AbsPathLength / 2); + } + if (cookedUrl.pQueryString != null && cookedUrl.QueryStringLength > 0) + { + _cookedUrlQuery = Marshal.PtrToStringUni((IntPtr)cookedUrl.pQueryString, cookedUrl.QueryStringLength / 2); + } + + int major = memoryBlob.RequestBlob->Version.MajorVersion; + int minor = memoryBlob.RequestBlob->Version.MinorVersion; + if (major == 1 && minor == 1) + { + _httpVersion = Constants.V1_1; + } + else if (major == 1 && minor == 0) + { + _httpVersion = Constants.V1_0; + } + else + { + _httpVersion = new Version(major, minor); + } + + _httpMethod = UnsafeNclNativeMethods.HttpApi.GetVerb(RequestBuffer, OriginalBlobAddress); + _headers = new RequestHeaders(_nativeRequestContext); + + UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_V2* requestV2 = (UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_V2*)memoryBlob.RequestBlob; + _user = GetUser(requestV2->pRequestInfo); + + // TODO: Verbose log parameters + + // TODO: Verbose log headers + } + + internal SslStatus SslStatus + { + get + { + return _sslStatus; + } + } + + internal ulong ConnectionId + { + get + { + return _connectionId; + } + } + + internal ulong ContextId + { + get { return _contextId; } + } + + internal RequestContext RequestContext + { + get + { + return _requestContext; + } + } + + internal byte[] RequestBuffer + { + get + { + CheckDisposed(); + return _nativeRequestContext.RequestBuffer; + } + } + + internal IntPtr OriginalBlobAddress + { + get + { + CheckDisposed(); + return _nativeRequestContext.OriginalBlobAddress; + } + } + + // Without the leading ? + internal string Query + { + get + { + if (!string.IsNullOrWhiteSpace(_cookedUrlQuery)) + { + return _cookedUrlQuery.Substring(1); + } + return string.Empty; + } + } + + internal ulong RequestId + { + get + { + return _requestId; + } + } + + // TODO: Move this to the constructor, that's where it will be called. + internal long ContentLength64 + { + get + { + if (_contentBoundaryType == BoundaryType.None) + { + string transferEncoding = Headers.Get(HttpKnownHeaderNames.TransferEncoding) ?? string.Empty; + if ("chunked".Equals(transferEncoding.Trim(), StringComparison.OrdinalIgnoreCase)) + { + _contentBoundaryType = BoundaryType.Chunked; + _contentLength = -1; + } + else + { + _contentLength = 0; + _contentBoundaryType = BoundaryType.ContentLength; + string length = Headers.Get(HttpKnownHeaderNames.ContentLength) ?? string.Empty; + if (length != null) + { + if (!long.TryParse(length.Trim(), NumberStyles.None, + CultureInfo.InvariantCulture.NumberFormat, out _contentLength)) + { + _contentLength = 0; + _contentBoundaryType = BoundaryType.Invalid; + } + } + } + } + + return _contentLength; + } + } + + internal IDictionary Headers + { + get + { + return _headers; + } + } + + internal string HttpMethod + { + get + { + return _httpMethod; + } + } + + internal Stream InputStream + { + get + { + if (_requestStream == null) + { + // TODO: Move this to the constructor (or a lazy Env dictionary) + _requestStream = HasEntityBody ? new RequestStream(RequestContext) : Stream.Null; + } + return _requestStream; + } + } + + internal bool IsLocal + { + get + { + return LocalEndPoint.GetIPAddressString().Equals(RemoteEndPoint.GetIPAddressString()); + } + } + + internal bool IsSecureConnection + { + get + { + return _sslStatus != SslStatus.Insecure; + } + } + + internal string RawUrl + { + get + { + return _rawUrl; + } + } + + internal Version ProtocolVersion + { + get + { + return _httpVersion; + } + } + + internal string Protocol + { + get + { + if (_httpVersion.Major == 1) + { + if (_httpVersion.Minor == 1) + { + return "HTTP/1.1"; + } + else if (_httpVersion.Minor == 0) + { + return "HTTP/1.0"; + } + } + return "HTTP/" + _httpVersion.ToString(2); + } + } + + // TODO: Move this to the constructor + internal bool HasEntityBody + { + get + { + // accessing the ContentLength property delay creates m_BoundaryType + return (ContentLength64 > 0 && _contentBoundaryType == BoundaryType.ContentLength) || + _contentBoundaryType == BoundaryType.Chunked || _contentBoundaryType == BoundaryType.Multipart; + } + } + + internal SocketAddress RemoteEndPoint + { + get + { + if (_remoteEndPoint == null) + { + _remoteEndPoint = UnsafeNclNativeMethods.HttpApi.GetRemoteEndPoint(RequestBuffer, OriginalBlobAddress); + } + + return _remoteEndPoint; + } + } + + internal SocketAddress LocalEndPoint + { + get + { + if (_localEndPoint == null) + { + _localEndPoint = UnsafeNclNativeMethods.HttpApi.GetLocalEndPoint(RequestBuffer, OriginalBlobAddress); + } + + return _localEndPoint; + } + } + + internal string RequestScheme + { + get + { + return IsSecureConnection ? Constants.HttpsScheme : Constants.HttpScheme; + } + } + + internal Uri RequestUri + { + get + { + if (_requestUri == null) + { + _requestUri = RequestUriBuilder.GetRequestUri( + _rawUrl, RequestScheme, _cookedUrlHost, _cookedUrlPath, _cookedUrlQuery); + } + + return _requestUri; + } + } + + internal string RequestPath + { + get + { + return RequestUriBuilder.GetRequestPath(_rawUrl, _cookedUrlPath); + } + } + + internal bool IsUpgradable + { + get + { + // HTTP.Sys allows you to upgrade anything to opaque unless content-length > 0 or chunked are specified. + return !HasEntityBody; + } + } + + internal IPrincipal User + { + get { return _user; } + } + + private unsafe IPrincipal GetUser(UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_INFO* requestInfo) + { + if (requestInfo == null + || requestInfo->InfoType != UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_INFO_TYPE.HttpRequestInfoTypeAuth) + { + return null; + } + + if (requestInfo->pInfo->AuthStatus != UnsafeNclNativeMethods.HttpApi.HTTP_AUTH_STATUS.HttpAuthStatusSuccess) + { + return null; + } + +#if NET45 + return new WindowsPrincipal(new WindowsIdentity(requestInfo->pInfo->AccessToken)); +#else + return null; +#endif + } + + // Use this to save the blob from dispose if this object was never used (never given to a user) and is about to be + // disposed. + internal void DetachBlob(NativeRequestContext memoryBlob) + { + if (memoryBlob != null && (object)memoryBlob == (object)_nativeRequestContext) + { + _nativeRequestContext = null; + } + } + + // Finalizes ownership of the memory blob. DetachBlob can't be called after this. + internal void ReleasePins() + { + _nativeRequestContext.ReleasePins(); + } + + // should only be called from RequestContext + public void Dispose() + { + // TODO: Verbose log + _isDisposed = true; + NativeRequestContext memoryBlob = _nativeRequestContext; + if (memoryBlob != null) + { + memoryBlob.Dispose(); + _nativeRequestContext = null; + } + _disconnectRegistration.Dispose(); + if (_requestStream != null) + { + _requestStream.Dispose(); + } + } + + internal void CheckDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + } + + internal void SwitchToOpaqueMode() + { + if (_requestStream == null || _requestStream == Stream.Null) + { + _requestStream = new RequestStream(RequestContext); + } + } + + internal void RegisterForDisconnect(CancellationToken cancellationToken) + { + _disconnectRegistration = cancellationToken.Register(Cancel, this); + } + + private static void Cancel(object obj) + { + Request request = (Request)obj; + // Cancels owin.CallCanceled + request.RequestContext.Abort(); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestContext.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestContext.cs new file mode 100644 index 0000000000..5cb122006c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestContext.cs @@ -0,0 +1,388 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.IO; +using System.Runtime.InteropServices; +using System.Security.Authentication.ExtendedProtection; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + using LoggerFunc = Func, bool>; + using OpaqueFunc = Func, Task>; + + internal sealed class RequestContext : IDisposable, CallEnvironment.IPropertySource + { + private static readonly string[] ZeroContentLength = new[] { "0" }; + + private CallEnvironment _environment; + private OwinWebListener _server; + private Request _request; + private Response _response; + private CancellationTokenSource _cts; + private NativeRequestContext _memoryBlob; + private OpaqueFunc _opaqueCallback; + private bool _disposed; + + internal RequestContext(OwinWebListener httpListener, NativeRequestContext memoryBlob) + { + // TODO: Verbose log + _server = httpListener; + _memoryBlob = memoryBlob; + _request = new Request(this, _memoryBlob); + _response = new Response(this); + _environment = new CallEnvironment(this); + _cts = new CancellationTokenSource(); + + PopulateEnvironment(); + + _request.ReleasePins(); + } + + internal CallEnvironment Environment + { + get { return _environment; } + } + + internal Request Request + { + get + { + return _request; + } + } + + internal Response Response + { + get + { + return _response; + } + } + + internal OwinWebListener Server + { + get + { + return _server; + } + } + + internal LoggerFunc Logger + { + get { return Server.Logger; } + } + + internal SafeHandle RequestQueueHandle + { + get + { + return _server.RequestQueueHandle; + } + } + + internal ulong RequestId + { + get + { + return Request.RequestId; + } + } + + private void PopulateEnvironment() + { + // General + _environment.OwinVersion = Constants.OwinVersion; + _environment.CallCancelled = _cts.Token; + + // Server + _environment.ServerCapabilities = _server.Capabilities; + _environment.Listener = _server; + + // Request + _environment.RequestProtocol = _request.Protocol; + _environment.RequestMethod = _request.HttpMethod; + _environment.RequestScheme = _request.RequestScheme; + _environment.RequestQueryString = _request.Query; + _environment.RequestHeaders = _request.Headers; + + SetPaths(); + + _environment.ConnectionId = _request.ConnectionId; + + if (_request.IsSecureConnection) + { + _environment.LoadClientCert = LoadClientCertificateAsync; + } + + if (_request.User != null) + { + _environment.User = _request.User; + } + + // Response + _environment.ResponseStatusCode = 200; + _environment.ResponseHeaders = _response.Headers; + _environment.ResponseBody = _response.OutputStream; + _environment.SendFileAsync = _response.SendFileAsync; + + _environment.OnSendingHeaders = _response.RegisterForOnSendingHeaders; + + Contract.Assert(!_environment.IsExtraDictionaryCreated, + "All server keys should have a reserved slot in the environment."); + } + + // Find the closest matching prefix and use it to separate the request path in to path and base path. + // Scheme and port must match. Path will use a longest match. Host names are more complicated due to + // wildcards, IP addresses, etc. + private void SetPaths() + { + Prefix prefix = _server.UriPrefixes[(int)Request.ContextId]; + string orriginalPath = _request.RequestPath; + + // These paths are both unescaped already. + if (orriginalPath.Length == prefix.Path.Length - 1) + { + // They matched exactly except for the trailing slash. + _environment.RequestPathBase = orriginalPath; + _environment.RequestPath = string.Empty; + } + else + { + // url: /base/path, prefix: /base/, base: /base, path: /path + // url: /, prefix: /, base: , path: / + _environment.RequestPathBase = orriginalPath.Substring(0, prefix.Path.Length - 1); + _environment.RequestPath = orriginalPath.Substring(prefix.Path.Length - 1); + } + } + + // Lazy environment init + + public Stream GetRequestBody() + { + return _request.InputStream; + } + + public string GetRemoteIpAddress() + { + return _request.RemoteEndPoint.GetIPAddressString(); + } + + public string GetRemotePort() + { + return _request.RemoteEndPoint.GetPort().ToString(CultureInfo.InvariantCulture.NumberFormat); + } + + public string GetLocalIpAddress() + { + return _request.LocalEndPoint.GetIPAddressString(); + } + + public string GetLocalPort() + { + return _request.LocalEndPoint.GetPort().ToString(CultureInfo.InvariantCulture.NumberFormat); + } + + public bool GetIsLocal() + { + return _request.IsLocal; + } + + public bool TryGetOpaqueUpgrade(ref Action, OpaqueFunc> value) + { + if (_request.IsUpgradable) + { + value = OpaqueUpgrade; + return true; + } + return false; + } + + public bool TryGetChannelBinding(ref ChannelBinding value) + { + value = Server.GetChannelBinding(Request.ConnectionId, Request.IsSecureConnection); + return value != null; + } + + public void Dispose() + { + if (_disposed) + { + return; + } + _disposed = true; + + // TODO: Verbose log + try + { + _response.Dispose(); + } + finally + { + _request.Dispose(); + _cts.Dispose(); + } + } + + internal void Abort() + { + // TODO: Verbose log + _disposed = true; + try + { + _cts.Cancel(); + } + catch (ObjectDisposedException) + { + } + catch (AggregateException) + { + } + ForceCancelRequest(RequestQueueHandle, _request.RequestId); + _request.Dispose(); + _cts.Dispose(); + } + + internal UnsafeNclNativeMethods.HttpApi.HTTP_VERB GetKnownMethod() + { + return UnsafeNclNativeMethods.HttpApi.GetKnownVerb(Request.RequestBuffer, Request.OriginalBlobAddress); + } + + // This is only called while processing incoming requests. We don't have to worry about cancelling + // any response writes. + [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = + "It is safe to ignore the return value on a cancel operation because the connection is being closed")] + internal static void CancelRequest(SafeHandle requestQueueHandle, ulong requestId) + { + UnsafeNclNativeMethods.HttpApi.HttpCancelHttpRequest(requestQueueHandle, requestId, + IntPtr.Zero); + } + + // The request is being aborted, but large writes may be in progress. Cancel them. + internal void ForceCancelRequest(SafeHandle requestQueueHandle, ulong requestId) + { + try + { + uint statusCode = UnsafeNclNativeMethods.HttpApi.HttpCancelHttpRequest(requestQueueHandle, requestId, + IntPtr.Zero); + + // Either the connection has already dropped, or the last write is in progress. + // The requestId becomes invalid as soon as the last Content-Length write starts. + // The only way to cancel now is with CancelIoEx. + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_CONNECTION_INVALID) + { + _response.CancelLastWrite(requestQueueHandle); + } + } + catch (ObjectDisposedException) + { + // RequestQueueHandle may have been closed + } + } + + // Populates the environment ClicentCertificate. The result may be null if there is no client cert. + // TODO: Does it make sense for this to be invoked multiple times (e.g. renegotiate)? Client and server code appear to + // enable this, but it's unclear what Http.Sys would do. + private async Task LoadClientCertificateAsync() + { + if (Request.SslStatus == SslStatus.Insecure) + { + // Non-SSL + return; + } + // TODO: Verbose log +#if NET45 + ClientCertLoader certLoader = new ClientCertLoader(this); + try + { + await certLoader.LoadClientCertificateAsync().SupressContext(); + // Populate the environment. + if (certLoader.ClientCert != null) + { + Environment.ClientCert = certLoader.ClientCert; + } + // TODO: Expose errors and exceptions? + } + catch (Exception) + { + if (certLoader != null) + { + certLoader.Dispose(); + } + throw; + } +#else + throw new NotImplementedException(); +#endif + } + + internal void OpaqueUpgrade(IDictionary parameters, OpaqueFunc callback) + { + // Parameters are ignored for now + if (Response.SentHeaders) + { + throw new InvalidOperationException(); + } + if (callback == null) + { + throw new ArgumentNullException("callback"); + } + + // Set the status code and reason phrase + Environment.ResponseStatusCode = (int)HttpStatusCode.SwitchingProtocols; + Environment.ResponseReasonPhrase = HttpReasonPhrase.Get(HttpStatusCode.SwitchingProtocols); + + // Store the callback and process it after the stack unwind. + _opaqueCallback = callback; + } + + // Called after the AppFunc completes for any necessary post-processing. + internal unsafe Task ProcessResponseAsync() + { + // If an upgrade was requested, perform it + if (!Response.SentHeaders && _opaqueCallback != null + && Environment.ResponseStatusCode == (int)HttpStatusCode.SwitchingProtocols) + { + Response.SendOpaqueUpgrade(); + + IDictionary opaqueEnv = CreateOpaqueEnvironment(); + return _opaqueCallback(opaqueEnv); + } + + return Helpers.CompletedTask(); + } + + private IDictionary CreateOpaqueEnvironment() + { + IDictionary opaqueEnv = new Dictionary(); + + opaqueEnv[Constants.OpaqueVersionKey] = Constants.OpaqueVersion; + // TODO: Separate CT? + opaqueEnv[Constants.OpaqueCallCancelledKey] = Environment.CallCancelled; + + Request.SwitchToOpaqueMode(); + Response.SwitchToOpaqueMode(); + opaqueEnv[Constants.OpaqueStreamKey] = new OpaqueStream(Request.InputStream, Response.OutputStream); + + return opaqueEnv; + } + + internal void SetFatalResponse() + { + Environment.ResponseStatusCode = 500; + Environment.ResponseReasonPhrase = string.Empty; + Environment.ResponseHeaders.Clear(); + Environment.ResponseHeaders.Add(HttpKnownHeaderNames.ContentLength, ZeroContentLength); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.cs new file mode 100644 index 0000000000..2e2df172c3 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.cs @@ -0,0 +1,2544 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Katana Contributors. All rights reserved. +// +//----------------------------------------------------------------------- +// + +using System; +using System.CodeDom.Compiler; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + [GeneratedCode("TextTemplatingFileGenerator", "")] + internal partial class RequestHeaders + { + // Tracks if individual fields have been read from native or set directly. + // Once read or set, their presence in the collection is marked by if their string[] is null or not. + private UInt32 _flag0, _flag1; + + private string[] _Accept; + private string[] _AcceptCharset; + private string[] _AcceptEncoding; + private string[] _AcceptLanguage; + private string[] _Allow; + private string[] _Authorization; + private string[] _CacheControl; + private string[] _Connection; + private string[] _ContentEncoding; + private string[] _ContentLanguage; + private string[] _ContentLength; + private string[] _ContentLocation; + private string[] _ContentMd5; + private string[] _ContentRange; + private string[] _ContentType; + private string[] _Cookie; + private string[] _Date; + private string[] _Expect; + private string[] _Expires; + private string[] _From; + private string[] _Host; + private string[] _IfMatch; + private string[] _IfModifiedSince; + private string[] _IfNoneMatch; + private string[] _IfRange; + private string[] _IfUnmodifiedSince; + private string[] _KeepAlive; + private string[] _LastModified; + private string[] _MaxForwards; + private string[] _Pragma; + private string[] _ProxyAuthorization; + private string[] _Range; + private string[] _Referer; + private string[] _Te; + private string[] _Trailer; + private string[] _TransferEncoding; + private string[] _Translate; + private string[] _Upgrade; + private string[] _UserAgent; + private string[] _Via; + private string[] _Warning; + + internal string[] Accept + { + get + { + if (!((_flag0 & 0x1u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Accept); + if (nativeValue != null) + { + _Accept = new string[] { nativeValue }; + } + _flag0 |= 0x1u; + } + return _Accept; + } + set + { + _flag0 |= 0x1u; + _Accept = value; + } + } + + internal string[] AcceptCharset + { + get + { + if (!((_flag0 & 0x2u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.AcceptCharset); + if (nativeValue != null) + { + _AcceptCharset = new string[] { nativeValue }; + } + _flag0 |= 0x2u; + } + return _AcceptCharset; + } + set + { + _flag0 |= 0x2u; + _AcceptCharset = value; + } + } + + internal string[] AcceptEncoding + { + get + { + if (!((_flag0 & 0x4u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.AcceptEncoding); + if (nativeValue != null) + { + _AcceptEncoding = new string[] { nativeValue }; + } + _flag0 |= 0x4u; + } + return _AcceptEncoding; + } + set + { + _flag0 |= 0x4u; + _AcceptEncoding = value; + } + } + + internal string[] AcceptLanguage + { + get + { + if (!((_flag0 & 0x8u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.AcceptLanguage); + if (nativeValue != null) + { + _AcceptLanguage = new string[] { nativeValue }; + } + _flag0 |= 0x8u; + } + return _AcceptLanguage; + } + set + { + _flag0 |= 0x8u; + _AcceptLanguage = value; + } + } + + internal string[] Allow + { + get + { + if (!((_flag0 & 0x10u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Allow); + if (nativeValue != null) + { + _Allow = new string[] { nativeValue }; + } + _flag0 |= 0x10u; + } + return _Allow; + } + set + { + _flag0 |= 0x10u; + _Allow = value; + } + } + + internal string[] Authorization + { + get + { + if (!((_flag0 & 0x20u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Authorization); + if (nativeValue != null) + { + _Authorization = new string[] { nativeValue }; + } + _flag0 |= 0x20u; + } + return _Authorization; + } + set + { + _flag0 |= 0x20u; + _Authorization = value; + } + } + + internal string[] CacheControl + { + get + { + if (!((_flag0 & 0x40u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.CacheControl); + if (nativeValue != null) + { + _CacheControl = new string[] { nativeValue }; + } + _flag0 |= 0x40u; + } + return _CacheControl; + } + set + { + _flag0 |= 0x40u; + _CacheControl = value; + } + } + + internal string[] Connection + { + get + { + if (!((_flag0 & 0x80u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Connection); + if (nativeValue != null) + { + _Connection = new string[] { nativeValue }; + } + _flag0 |= 0x80u; + } + return _Connection; + } + set + { + _flag0 |= 0x80u; + _Connection = value; + } + } + + internal string[] ContentEncoding + { + get + { + if (!((_flag0 & 0x100u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentEncoding); + if (nativeValue != null) + { + _ContentEncoding = new string[] { nativeValue }; + } + _flag0 |= 0x100u; + } + return _ContentEncoding; + } + set + { + _flag0 |= 0x100u; + _ContentEncoding = value; + } + } + + internal string[] ContentLanguage + { + get + { + if (!((_flag0 & 0x200u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentLanguage); + if (nativeValue != null) + { + _ContentLanguage = new string[] { nativeValue }; + } + _flag0 |= 0x200u; + } + return _ContentLanguage; + } + set + { + _flag0 |= 0x200u; + _ContentLanguage = value; + } + } + + internal string[] ContentLength + { + get + { + if (!((_flag0 & 0x400u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentLength); + if (nativeValue != null) + { + _ContentLength = new string[] { nativeValue }; + } + _flag0 |= 0x400u; + } + return _ContentLength; + } + set + { + _flag0 |= 0x400u; + _ContentLength = value; + } + } + + internal string[] ContentLocation + { + get + { + if (!((_flag0 & 0x800u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentLocation); + if (nativeValue != null) + { + _ContentLocation = new string[] { nativeValue }; + } + _flag0 |= 0x800u; + } + return _ContentLocation; + } + set + { + _flag0 |= 0x800u; + _ContentLocation = value; + } + } + + internal string[] ContentMd5 + { + get + { + if (!((_flag0 & 0x1000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentMd5); + if (nativeValue != null) + { + _ContentMd5 = new string[] { nativeValue }; + } + _flag0 |= 0x1000u; + } + return _ContentMd5; + } + set + { + _flag0 |= 0x1000u; + _ContentMd5 = value; + } + } + + internal string[] ContentRange + { + get + { + if (!((_flag0 & 0x2000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentRange); + if (nativeValue != null) + { + _ContentRange = new string[] { nativeValue }; + } + _flag0 |= 0x2000u; + } + return _ContentRange; + } + set + { + _flag0 |= 0x2000u; + _ContentRange = value; + } + } + + internal string[] ContentType + { + get + { + if (!((_flag0 & 0x4000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentType); + if (nativeValue != null) + { + _ContentType = new string[] { nativeValue }; + } + _flag0 |= 0x4000u; + } + return _ContentType; + } + set + { + _flag0 |= 0x4000u; + _ContentType = value; + } + } + + internal string[] Cookie + { + get + { + if (!((_flag0 & 0x8000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Cookie); + if (nativeValue != null) + { + _Cookie = new string[] { nativeValue }; + } + _flag0 |= 0x8000u; + } + return _Cookie; + } + set + { + _flag0 |= 0x8000u; + _Cookie = value; + } + } + + internal string[] Date + { + get + { + if (!((_flag0 & 0x10000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Date); + if (nativeValue != null) + { + _Date = new string[] { nativeValue }; + } + _flag0 |= 0x10000u; + } + return _Date; + } + set + { + _flag0 |= 0x10000u; + _Date = value; + } + } + + internal string[] Expect + { + get + { + if (!((_flag0 & 0x20000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Expect); + if (nativeValue != null) + { + _Expect = new string[] { nativeValue }; + } + _flag0 |= 0x20000u; + } + return _Expect; + } + set + { + _flag0 |= 0x20000u; + _Expect = value; + } + } + + internal string[] Expires + { + get + { + if (!((_flag0 & 0x40000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Expires); + if (nativeValue != null) + { + _Expires = new string[] { nativeValue }; + } + _flag0 |= 0x40000u; + } + return _Expires; + } + set + { + _flag0 |= 0x40000u; + _Expires = value; + } + } + + internal string[] From + { + get + { + if (!((_flag0 & 0x80000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.From); + if (nativeValue != null) + { + _From = new string[] { nativeValue }; + } + _flag0 |= 0x80000u; + } + return _From; + } + set + { + _flag0 |= 0x80000u; + _From = value; + } + } + + internal string[] Host + { + get + { + if (!((_flag0 & 0x100000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Host); + if (nativeValue != null) + { + _Host = new string[] { nativeValue }; + } + _flag0 |= 0x100000u; + } + return _Host; + } + set + { + _flag0 |= 0x100000u; + _Host = value; + } + } + + internal string[] IfMatch + { + get + { + if (!((_flag0 & 0x200000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfMatch); + if (nativeValue != null) + { + _IfMatch = new string[] { nativeValue }; + } + _flag0 |= 0x200000u; + } + return _IfMatch; + } + set + { + _flag0 |= 0x200000u; + _IfMatch = value; + } + } + + internal string[] IfModifiedSince + { + get + { + if (!((_flag0 & 0x400000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfModifiedSince); + if (nativeValue != null) + { + _IfModifiedSince = new string[] { nativeValue }; + } + _flag0 |= 0x400000u; + } + return _IfModifiedSince; + } + set + { + _flag0 |= 0x400000u; + _IfModifiedSince = value; + } + } + + internal string[] IfNoneMatch + { + get + { + if (!((_flag0 & 0x800000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfNoneMatch); + if (nativeValue != null) + { + _IfNoneMatch = new string[] { nativeValue }; + } + _flag0 |= 0x800000u; + } + return _IfNoneMatch; + } + set + { + _flag0 |= 0x800000u; + _IfNoneMatch = value; + } + } + + internal string[] IfRange + { + get + { + if (!((_flag0 & 0x1000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfRange); + if (nativeValue != null) + { + _IfRange = new string[] { nativeValue }; + } + _flag0 |= 0x1000000u; + } + return _IfRange; + } + set + { + _flag0 |= 0x1000000u; + _IfRange = value; + } + } + + internal string[] IfUnmodifiedSince + { + get + { + if (!((_flag0 & 0x2000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfUnmodifiedSince); + if (nativeValue != null) + { + _IfUnmodifiedSince = new string[] { nativeValue }; + } + _flag0 |= 0x2000000u; + } + return _IfUnmodifiedSince; + } + set + { + _flag0 |= 0x2000000u; + _IfUnmodifiedSince = value; + } + } + + internal string[] KeepAlive + { + get + { + if (!((_flag0 & 0x4000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.KeepAlive); + if (nativeValue != null) + { + _KeepAlive = new string[] { nativeValue }; + } + _flag0 |= 0x4000000u; + } + return _KeepAlive; + } + set + { + _flag0 |= 0x4000000u; + _KeepAlive = value; + } + } + + internal string[] LastModified + { + get + { + if (!((_flag0 & 0x8000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.LastModified); + if (nativeValue != null) + { + _LastModified = new string[] { nativeValue }; + } + _flag0 |= 0x8000000u; + } + return _LastModified; + } + set + { + _flag0 |= 0x8000000u; + _LastModified = value; + } + } + + internal string[] MaxForwards + { + get + { + if (!((_flag0 & 0x10000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.MaxForwards); + if (nativeValue != null) + { + _MaxForwards = new string[] { nativeValue }; + } + _flag0 |= 0x10000000u; + } + return _MaxForwards; + } + set + { + _flag0 |= 0x10000000u; + _MaxForwards = value; + } + } + + internal string[] Pragma + { + get + { + if (!((_flag0 & 0x20000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Pragma); + if (nativeValue != null) + { + _Pragma = new string[] { nativeValue }; + } + _flag0 |= 0x20000000u; + } + return _Pragma; + } + set + { + _flag0 |= 0x20000000u; + _Pragma = value; + } + } + + internal string[] ProxyAuthorization + { + get + { + if (!((_flag0 & 0x40000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ProxyAuthorization); + if (nativeValue != null) + { + _ProxyAuthorization = new string[] { nativeValue }; + } + _flag0 |= 0x40000000u; + } + return _ProxyAuthorization; + } + set + { + _flag0 |= 0x40000000u; + _ProxyAuthorization = value; + } + } + + internal string[] Range + { + get + { + if (!((_flag0 & 0x80000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Range); + if (nativeValue != null) + { + _Range = new string[] { nativeValue }; + } + _flag0 |= 0x80000000u; + } + return _Range; + } + set + { + _flag0 |= 0x80000000u; + _Range = value; + } + } + + internal string[] Referer + { + get + { + if (!((_flag1 & 0x1u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Referer); + if (nativeValue != null) + { + _Referer = new string[] { nativeValue }; + } + _flag1 |= 0x1u; + } + return _Referer; + } + set + { + _flag1 |= 0x1u; + _Referer = value; + } + } + + internal string[] Te + { + get + { + if (!((_flag1 & 0x2u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Te); + if (nativeValue != null) + { + _Te = new string[] { nativeValue }; + } + _flag1 |= 0x2u; + } + return _Te; + } + set + { + _flag1 |= 0x2u; + _Te = value; + } + } + + internal string[] Trailer + { + get + { + if (!((_flag1 & 0x4u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Trailer); + if (nativeValue != null) + { + _Trailer = new string[] { nativeValue }; + } + _flag1 |= 0x4u; + } + return _Trailer; + } + set + { + _flag1 |= 0x4u; + _Trailer = value; + } + } + + internal string[] TransferEncoding + { + get + { + if (!((_flag1 & 0x8u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.TransferEncoding); + if (nativeValue != null) + { + _TransferEncoding = new string[] { nativeValue }; + } + _flag1 |= 0x8u; + } + return _TransferEncoding; + } + set + { + _flag1 |= 0x8u; + _TransferEncoding = value; + } + } + + internal string[] Translate + { + get + { + if (!((_flag1 & 0x10u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Translate); + if (nativeValue != null) + { + _Translate = new string[] { nativeValue }; + } + _flag1 |= 0x10u; + } + return _Translate; + } + set + { + _flag1 |= 0x10u; + _Translate = value; + } + } + + internal string[] Upgrade + { + get + { + if (!((_flag1 & 0x20u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Upgrade); + if (nativeValue != null) + { + _Upgrade = new string[] { nativeValue }; + } + _flag1 |= 0x20u; + } + return _Upgrade; + } + set + { + _flag1 |= 0x20u; + _Upgrade = value; + } + } + + internal string[] UserAgent + { + get + { + if (!((_flag1 & 0x40u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.UserAgent); + if (nativeValue != null) + { + _UserAgent = new string[] { nativeValue }; + } + _flag1 |= 0x40u; + } + return _UserAgent; + } + set + { + _flag1 |= 0x40u; + _UserAgent = value; + } + } + + internal string[] Via + { + get + { + if (!((_flag1 & 0x80u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Via); + if (nativeValue != null) + { + _Via = new string[] { nativeValue }; + } + _flag1 |= 0x80u; + } + return _Via; + } + set + { + _flag1 |= 0x80u; + _Via = value; + } + } + + internal string[] Warning + { + get + { + if (!((_flag1 & 0x100u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Warning); + if (nativeValue != null) + { + _Warning = new string[] { nativeValue }; + } + _flag1 |= 0x100u; + } + return _Warning; + } + set + { + _flag1 |= 0x100u; + _Warning = value; + } + } + + private bool PropertiesContainsKey(string key) + { + switch (key.Length) + { + case 2: + if (string.Equals(key, "Te", StringComparison.OrdinalIgnoreCase)) + { + return Te != null; + } + break; + case 3: + if (string.Equals(key, "Via", StringComparison.OrdinalIgnoreCase)) + { + return Via != null; + } + break; + case 4: + if (string.Equals(key, "Date", StringComparison.OrdinalIgnoreCase)) + { + return Date != null; + } + if (string.Equals(key, "From", StringComparison.OrdinalIgnoreCase)) + { + return From != null; + } + if (string.Equals(key, "Host", StringComparison.OrdinalIgnoreCase)) + { + return Host != null; + } + break; + case 5: + if (string.Equals(key, "Allow", StringComparison.OrdinalIgnoreCase)) + { + return Allow != null; + } + if (string.Equals(key, "Range", StringComparison.OrdinalIgnoreCase)) + { + return Range != null; + } + break; + case 6: + if (string.Equals(key, "Accept", StringComparison.OrdinalIgnoreCase)) + { + return Accept != null; + } + if (string.Equals(key, "Cookie", StringComparison.OrdinalIgnoreCase)) + { + return Cookie != null; + } + if (string.Equals(key, "Expect", StringComparison.OrdinalIgnoreCase)) + { + return Expect != null; + } + if (string.Equals(key, "Pragma", StringComparison.OrdinalIgnoreCase)) + { + return Pragma != null; + } + break; + case 7: + if (string.Equals(key, "Expires", StringComparison.OrdinalIgnoreCase)) + { + return Expires != null; + } + if (string.Equals(key, "Referer", StringComparison.OrdinalIgnoreCase)) + { + return Referer != null; + } + if (string.Equals(key, "Trailer", StringComparison.OrdinalIgnoreCase)) + { + return Trailer != null; + } + if (string.Equals(key, "Upgrade", StringComparison.OrdinalIgnoreCase)) + { + return Upgrade != null; + } + if (string.Equals(key, "Warning", StringComparison.OrdinalIgnoreCase)) + { + return Warning != null; + } + break; + case 8: + if (string.Equals(key, "If-Match", StringComparison.OrdinalIgnoreCase)) + { + return IfMatch != null; + } + if (string.Equals(key, "If-Range", StringComparison.OrdinalIgnoreCase)) + { + return IfRange != null; + } + break; + case 9: + if (string.Equals(key, "Translate", StringComparison.OrdinalIgnoreCase)) + { + return Translate != null; + } + break; + case 10: + if (string.Equals(key, "Connection", StringComparison.OrdinalIgnoreCase)) + { + return Connection != null; + } + if (string.Equals(key, "Keep-Alive", StringComparison.OrdinalIgnoreCase)) + { + return KeepAlive != null; + } + if (string.Equals(key, "User-Agent", StringComparison.OrdinalIgnoreCase)) + { + return UserAgent != null; + } + break; + case 11: + if (string.Equals(key, "Content-Md5", StringComparison.OrdinalIgnoreCase)) + { + return ContentMd5 != null; + } + break; + case 12: + if (string.Equals(key, "Content-Type", StringComparison.OrdinalIgnoreCase)) + { + return ContentType != null; + } + if (string.Equals(key, "Max-Forwards", StringComparison.OrdinalIgnoreCase)) + { + return MaxForwards != null; + } + break; + case 13: + if (string.Equals(key, "Authorization", StringComparison.OrdinalIgnoreCase)) + { + return Authorization != null; + } + if (string.Equals(key, "Cache-Control", StringComparison.OrdinalIgnoreCase)) + { + return CacheControl != null; + } + if (string.Equals(key, "Content-Range", StringComparison.OrdinalIgnoreCase)) + { + return ContentRange != null; + } + if (string.Equals(key, "If-None-Match", StringComparison.OrdinalIgnoreCase)) + { + return IfNoneMatch != null; + } + if (string.Equals(key, "Last-Modified", StringComparison.OrdinalIgnoreCase)) + { + return LastModified != null; + } + break; + case 14: + if (string.Equals(key, "Accept-Charset", StringComparison.OrdinalIgnoreCase)) + { + return AcceptCharset != null; + } + if (string.Equals(key, "Content-Length", StringComparison.OrdinalIgnoreCase)) + { + return ContentLength != null; + } + break; + case 15: + if (string.Equals(key, "Accept-Encoding", StringComparison.OrdinalIgnoreCase)) + { + return AcceptEncoding != null; + } + if (string.Equals(key, "Accept-Language", StringComparison.OrdinalIgnoreCase)) + { + return AcceptLanguage != null; + } + break; + case 16: + if (string.Equals(key, "Content-Encoding", StringComparison.OrdinalIgnoreCase)) + { + return ContentEncoding != null; + } + if (string.Equals(key, "Content-Language", StringComparison.OrdinalIgnoreCase)) + { + return ContentLanguage != null; + } + if (string.Equals(key, "Content-Location", StringComparison.OrdinalIgnoreCase)) + { + return ContentLocation != null; + } + break; + case 17: + if (string.Equals(key, "If-Modified-Since", StringComparison.OrdinalIgnoreCase)) + { + return IfModifiedSince != null; + } + if (string.Equals(key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)) + { + return TransferEncoding != null; + } + break; + case 19: + if (string.Equals(key, "If-Unmodified-Since", StringComparison.OrdinalIgnoreCase)) + { + return IfUnmodifiedSince != null; + } + if (string.Equals(key, "Proxy-Authorization", StringComparison.OrdinalIgnoreCase)) + { + return ProxyAuthorization != null; + } + break; + } + return false; + } + + private bool PropertiesTryGetValue(string key, out string[] value) + { + switch (key.Length) + { + case 2: + if (string.Equals(key, "Te", StringComparison.OrdinalIgnoreCase)) + { + value = Te; + return value != null; + } + break; + case 3: + if (string.Equals(key, "Via", StringComparison.OrdinalIgnoreCase)) + { + value = Via; + return value != null; + } + break; + case 4: + if (string.Equals(key, "Date", StringComparison.OrdinalIgnoreCase)) + { + value = Date; + return value != null; + } + if (string.Equals(key, "From", StringComparison.OrdinalIgnoreCase)) + { + value = From; + return value != null; + } + if (string.Equals(key, "Host", StringComparison.OrdinalIgnoreCase)) + { + value = Host; + return value != null; + } + break; + case 5: + if (string.Equals(key, "Allow", StringComparison.OrdinalIgnoreCase)) + { + value = Allow; + return value != null; + } + if (string.Equals(key, "Range", StringComparison.OrdinalIgnoreCase)) + { + value = Range; + return value != null; + } + break; + case 6: + if (string.Equals(key, "Accept", StringComparison.OrdinalIgnoreCase)) + { + value = Accept; + return value != null; + } + if (string.Equals(key, "Cookie", StringComparison.OrdinalIgnoreCase)) + { + value = Cookie; + return value != null; + } + if (string.Equals(key, "Expect", StringComparison.OrdinalIgnoreCase)) + { + value = Expect; + return value != null; + } + if (string.Equals(key, "Pragma", StringComparison.OrdinalIgnoreCase)) + { + value = Pragma; + return value != null; + } + break; + case 7: + if (string.Equals(key, "Expires", StringComparison.OrdinalIgnoreCase)) + { + value = Expires; + return value != null; + } + if (string.Equals(key, "Referer", StringComparison.OrdinalIgnoreCase)) + { + value = Referer; + return value != null; + } + if (string.Equals(key, "Trailer", StringComparison.OrdinalIgnoreCase)) + { + value = Trailer; + return value != null; + } + if (string.Equals(key, "Upgrade", StringComparison.OrdinalIgnoreCase)) + { + value = Upgrade; + return value != null; + } + if (string.Equals(key, "Warning", StringComparison.OrdinalIgnoreCase)) + { + value = Warning; + return value != null; + } + break; + case 8: + if (string.Equals(key, "If-Match", StringComparison.OrdinalIgnoreCase)) + { + value = IfMatch; + return value != null; + } + if (string.Equals(key, "If-Range", StringComparison.OrdinalIgnoreCase)) + { + value = IfRange; + return value != null; + } + break; + case 9: + if (string.Equals(key, "Translate", StringComparison.OrdinalIgnoreCase)) + { + value = Translate; + return value != null; + } + break; + case 10: + if (string.Equals(key, "Connection", StringComparison.OrdinalIgnoreCase)) + { + value = Connection; + return value != null; + } + if (string.Equals(key, "Keep-Alive", StringComparison.OrdinalIgnoreCase)) + { + value = KeepAlive; + return value != null; + } + if (string.Equals(key, "User-Agent", StringComparison.OrdinalIgnoreCase)) + { + value = UserAgent; + return value != null; + } + break; + case 11: + if (string.Equals(key, "Content-Md5", StringComparison.OrdinalIgnoreCase)) + { + value = ContentMd5; + return value != null; + } + break; + case 12: + if (string.Equals(key, "Content-Type", StringComparison.OrdinalIgnoreCase)) + { + value = ContentType; + return value != null; + } + if (string.Equals(key, "Max-Forwards", StringComparison.OrdinalIgnoreCase)) + { + value = MaxForwards; + return value != null; + } + break; + case 13: + if (string.Equals(key, "Authorization", StringComparison.OrdinalIgnoreCase)) + { + value = Authorization; + return value != null; + } + if (string.Equals(key, "Cache-Control", StringComparison.OrdinalIgnoreCase)) + { + value = CacheControl; + return value != null; + } + if (string.Equals(key, "Content-Range", StringComparison.OrdinalIgnoreCase)) + { + value = ContentRange; + return value != null; + } + if (string.Equals(key, "If-None-Match", StringComparison.OrdinalIgnoreCase)) + { + value = IfNoneMatch; + return value != null; + } + if (string.Equals(key, "Last-Modified", StringComparison.OrdinalIgnoreCase)) + { + value = LastModified; + return value != null; + } + break; + case 14: + if (string.Equals(key, "Accept-Charset", StringComparison.OrdinalIgnoreCase)) + { + value = AcceptCharset; + return value != null; + } + if (string.Equals(key, "Content-Length", StringComparison.OrdinalIgnoreCase)) + { + value = ContentLength; + return value != null; + } + break; + case 15: + if (string.Equals(key, "Accept-Encoding", StringComparison.OrdinalIgnoreCase)) + { + value = AcceptEncoding; + return value != null; + } + if (string.Equals(key, "Accept-Language", StringComparison.OrdinalIgnoreCase)) + { + value = AcceptLanguage; + return value != null; + } + break; + case 16: + if (string.Equals(key, "Content-Encoding", StringComparison.OrdinalIgnoreCase)) + { + value = ContentEncoding; + return value != null; + } + if (string.Equals(key, "Content-Language", StringComparison.OrdinalIgnoreCase)) + { + value = ContentLanguage; + return value != null; + } + if (string.Equals(key, "Content-Location", StringComparison.OrdinalIgnoreCase)) + { + value = ContentLocation; + return value != null; + } + break; + case 17: + if (string.Equals(key, "If-Modified-Since", StringComparison.OrdinalIgnoreCase)) + { + value = IfModifiedSince; + return value != null; + } + if (string.Equals(key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)) + { + value = TransferEncoding; + return value != null; + } + break; + case 19: + if (string.Equals(key, "If-Unmodified-Since", StringComparison.OrdinalIgnoreCase)) + { + value = IfUnmodifiedSince; + return value != null; + } + if (string.Equals(key, "Proxy-Authorization", StringComparison.OrdinalIgnoreCase)) + { + value = ProxyAuthorization; + return value != null; + } + break; + } + value = null; + return false; + } + + private bool PropertiesTrySetValue(string key, string[] value) + { + switch (key.Length) + { + case 2: + if (string.Equals(key, "Te", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x2u; + Te = value; + return true; + } + break; + case 3: + if (string.Equals(key, "Via", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x80u; + Via = value; + return true; + } + break; + case 4: + if (string.Equals(key, "Date", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x10000u; + Date = value; + return true; + } + if (string.Equals(key, "From", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x80000u; + From = value; + return true; + } + if (string.Equals(key, "Host", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x100000u; + Host = value; + return true; + } + break; + case 5: + if (string.Equals(key, "Allow", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x10u; + Allow = value; + return true; + } + if (string.Equals(key, "Range", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x80000000u; + Range = value; + return true; + } + break; + case 6: + if (string.Equals(key, "Accept", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x1u; + Accept = value; + return true; + } + if (string.Equals(key, "Cookie", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x8000u; + Cookie = value; + return true; + } + if (string.Equals(key, "Expect", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x20000u; + Expect = value; + return true; + } + if (string.Equals(key, "Pragma", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x20000000u; + Pragma = value; + return true; + } + break; + case 7: + if (string.Equals(key, "Expires", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x40000u; + Expires = value; + return true; + } + if (string.Equals(key, "Referer", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x1u; + Referer = value; + return true; + } + if (string.Equals(key, "Trailer", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x4u; + Trailer = value; + return true; + } + if (string.Equals(key, "Upgrade", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x20u; + Upgrade = value; + return true; + } + if (string.Equals(key, "Warning", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x100u; + Warning = value; + return true; + } + break; + case 8: + if (string.Equals(key, "If-Match", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x200000u; + IfMatch = value; + return true; + } + if (string.Equals(key, "If-Range", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x1000000u; + IfRange = value; + return true; + } + break; + case 9: + if (string.Equals(key, "Translate", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x10u; + Translate = value; + return true; + } + break; + case 10: + if (string.Equals(key, "Connection", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x80u; + Connection = value; + return true; + } + if (string.Equals(key, "Keep-Alive", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x4000000u; + KeepAlive = value; + return true; + } + if (string.Equals(key, "User-Agent", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x40u; + UserAgent = value; + return true; + } + break; + case 11: + if (string.Equals(key, "Content-Md5", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x1000u; + ContentMd5 = value; + return true; + } + break; + case 12: + if (string.Equals(key, "Content-Type", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x4000u; + ContentType = value; + return true; + } + if (string.Equals(key, "Max-Forwards", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x10000000u; + MaxForwards = value; + return true; + } + break; + case 13: + if (string.Equals(key, "Authorization", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x20u; + Authorization = value; + return true; + } + if (string.Equals(key, "Cache-Control", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x40u; + CacheControl = value; + return true; + } + if (string.Equals(key, "Content-Range", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x2000u; + ContentRange = value; + return true; + } + if (string.Equals(key, "If-None-Match", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x800000u; + IfNoneMatch = value; + return true; + } + if (string.Equals(key, "Last-Modified", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x8000000u; + LastModified = value; + return true; + } + break; + case 14: + if (string.Equals(key, "Accept-Charset", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x2u; + AcceptCharset = value; + return true; + } + if (string.Equals(key, "Content-Length", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x400u; + ContentLength = value; + return true; + } + break; + case 15: + if (string.Equals(key, "Accept-Encoding", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x4u; + AcceptEncoding = value; + return true; + } + if (string.Equals(key, "Accept-Language", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x8u; + AcceptLanguage = value; + return true; + } + break; + case 16: + if (string.Equals(key, "Content-Encoding", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x100u; + ContentEncoding = value; + return true; + } + if (string.Equals(key, "Content-Language", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x200u; + ContentLanguage = value; + return true; + } + if (string.Equals(key, "Content-Location", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x800u; + ContentLocation = value; + return true; + } + break; + case 17: + if (string.Equals(key, "If-Modified-Since", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x400000u; + IfModifiedSince = value; + return true; + } + if (string.Equals(key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x8u; + TransferEncoding = value; + return true; + } + break; + case 19: + if (string.Equals(key, "If-Unmodified-Since", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x2000000u; + IfUnmodifiedSince = value; + return true; + } + if (string.Equals(key, "Proxy-Authorization", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x40000000u; + ProxyAuthorization = value; + return true; + } + break; + } + return false; + } + + private bool PropertiesTryRemove(string key) + { + switch (key.Length) + { + case 2: + if (_Te != null + && string.Equals(key, "Te", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x2u) != 0); + Te = null; + return wasSet; + } + break; + case 3: + if (_Via != null + && string.Equals(key, "Via", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x80u) != 0); + Via = null; + return wasSet; + } + break; + case 4: + if (_Date != null + && string.Equals(key, "Date", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x10000u) != 0); + Date = null; + return wasSet; + } + if (_From != null + && string.Equals(key, "From", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x80000u) != 0); + From = null; + return wasSet; + } + if (_Host != null + && string.Equals(key, "Host", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x100000u) != 0); + Host = null; + return wasSet; + } + break; + case 5: + if (_Allow != null + && string.Equals(key, "Allow", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x10u) != 0); + Allow = null; + return wasSet; + } + if (_Range != null + && string.Equals(key, "Range", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x80000000u) != 0); + Range = null; + return wasSet; + } + break; + case 6: + if (_Accept != null + && string.Equals(key, "Accept", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x1u) != 0); + Accept = null; + return wasSet; + } + if (_Cookie != null + && string.Equals(key, "Cookie", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x8000u) != 0); + Cookie = null; + return wasSet; + } + if (_Expect != null + && string.Equals(key, "Expect", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x20000u) != 0); + Expect = null; + return wasSet; + } + if (_Pragma != null + && string.Equals(key, "Pragma", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x20000000u) != 0); + Pragma = null; + return wasSet; + } + break; + case 7: + if (_Expires != null + && string.Equals(key, "Expires", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x40000u) != 0); + Expires = null; + return wasSet; + } + if (_Referer != null + && string.Equals(key, "Referer", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x1u) != 0); + Referer = null; + return wasSet; + } + if (_Trailer != null + && string.Equals(key, "Trailer", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x4u) != 0); + Trailer = null; + return wasSet; + } + if (_Upgrade != null + && string.Equals(key, "Upgrade", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x20u) != 0); + Upgrade = null; + return wasSet; + } + if (_Warning != null + && string.Equals(key, "Warning", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x100u) != 0); + Warning = null; + return wasSet; + } + break; + case 8: + if (_IfMatch != null + && string.Equals(key, "If-Match", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x200000u) != 0); + IfMatch = null; + return wasSet; + } + if (_IfRange != null + && string.Equals(key, "If-Range", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x1000000u) != 0); + IfRange = null; + return wasSet; + } + break; + case 9: + if (_Translate != null + && string.Equals(key, "Translate", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x10u) != 0); + Translate = null; + return wasSet; + } + break; + case 10: + if (_Connection != null + && string.Equals(key, "Connection", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x80u) != 0); + Connection = null; + return wasSet; + } + if (_KeepAlive != null + && string.Equals(key, "Keep-Alive", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x4000000u) != 0); + KeepAlive = null; + return wasSet; + } + if (_UserAgent != null + && string.Equals(key, "User-Agent", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x40u) != 0); + UserAgent = null; + return wasSet; + } + break; + case 11: + if (_ContentMd5 != null + && string.Equals(key, "Content-Md5", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x1000u) != 0); + ContentMd5 = null; + return wasSet; + } + break; + case 12: + if (_ContentType != null + && string.Equals(key, "Content-Type", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x4000u) != 0); + ContentType = null; + return wasSet; + } + if (_MaxForwards != null + && string.Equals(key, "Max-Forwards", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x10000000u) != 0); + MaxForwards = null; + return wasSet; + } + break; + case 13: + if (_Authorization != null + && string.Equals(key, "Authorization", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x20u) != 0); + Authorization = null; + return wasSet; + } + if (_CacheControl != null + && string.Equals(key, "Cache-Control", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x40u) != 0); + CacheControl = null; + return wasSet; + } + if (_ContentRange != null + && string.Equals(key, "Content-Range", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x2000u) != 0); + ContentRange = null; + return wasSet; + } + if (_IfNoneMatch != null + && string.Equals(key, "If-None-Match", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x800000u) != 0); + IfNoneMatch = null; + return wasSet; + } + if (_LastModified != null + && string.Equals(key, "Last-Modified", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x8000000u) != 0); + LastModified = null; + return wasSet; + } + break; + case 14: + if (_AcceptCharset != null + && string.Equals(key, "Accept-Charset", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x2u) != 0); + AcceptCharset = null; + return wasSet; + } + if (_ContentLength != null + && string.Equals(key, "Content-Length", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x400u) != 0); + ContentLength = null; + return wasSet; + } + break; + case 15: + if (_AcceptEncoding != null + && string.Equals(key, "Accept-Encoding", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x4u) != 0); + AcceptEncoding = null; + return wasSet; + } + if (_AcceptLanguage != null + && string.Equals(key, "Accept-Language", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x8u) != 0); + AcceptLanguage = null; + return wasSet; + } + break; + case 16: + if (_ContentEncoding != null + && string.Equals(key, "Content-Encoding", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x100u) != 0); + ContentEncoding = null; + return wasSet; + } + if (_ContentLanguage != null + && string.Equals(key, "Content-Language", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x200u) != 0); + ContentLanguage = null; + return wasSet; + } + if (_ContentLocation != null + && string.Equals(key, "Content-Location", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x800u) != 0); + ContentLocation = null; + return wasSet; + } + break; + case 17: + if (_IfModifiedSince != null + && string.Equals(key, "If-Modified-Since", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x400000u) != 0); + IfModifiedSince = null; + return wasSet; + } + if (_TransferEncoding != null + && string.Equals(key, "Transfer-Encoding", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x8u) != 0); + TransferEncoding = null; + return wasSet; + } + break; + case 19: + if (_IfUnmodifiedSince != null + && string.Equals(key, "If-Unmodified-Since", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x2000000u) != 0); + IfUnmodifiedSince = null; + return wasSet; + } + if (_ProxyAuthorization != null + && string.Equals(key, "Proxy-Authorization", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x40000000u) != 0); + ProxyAuthorization = null; + return wasSet; + } + break; + } + return false; + } + + private IEnumerable PropertiesKeys() + { + if (Accept != null) + { + yield return "Accept"; + } + if (AcceptCharset != null) + { + yield return "Accept-Charset"; + } + if (AcceptEncoding != null) + { + yield return "Accept-Encoding"; + } + if (AcceptLanguage != null) + { + yield return "Accept-Language"; + } + if (Allow != null) + { + yield return "Allow"; + } + if (Authorization != null) + { + yield return "Authorization"; + } + if (CacheControl != null) + { + yield return "Cache-Control"; + } + if (Connection != null) + { + yield return "Connection"; + } + if (ContentEncoding != null) + { + yield return "Content-Encoding"; + } + if (ContentLanguage != null) + { + yield return "Content-Language"; + } + if (ContentLength != null) + { + yield return "Content-Length"; + } + if (ContentLocation != null) + { + yield return "Content-Location"; + } + if (ContentMd5 != null) + { + yield return "Content-Md5"; + } + if (ContentRange != null) + { + yield return "Content-Range"; + } + if (ContentType != null) + { + yield return "Content-Type"; + } + if (Cookie != null) + { + yield return "Cookie"; + } + if (Date != null) + { + yield return "Date"; + } + if (Expect != null) + { + yield return "Expect"; + } + if (Expires != null) + { + yield return "Expires"; + } + if (From != null) + { + yield return "From"; + } + if (Host != null) + { + yield return "Host"; + } + if (IfMatch != null) + { + yield return "If-Match"; + } + if (IfModifiedSince != null) + { + yield return "If-Modified-Since"; + } + if (IfNoneMatch != null) + { + yield return "If-None-Match"; + } + if (IfRange != null) + { + yield return "If-Range"; + } + if (IfUnmodifiedSince != null) + { + yield return "If-Unmodified-Since"; + } + if (KeepAlive != null) + { + yield return "Keep-Alive"; + } + if (LastModified != null) + { + yield return "Last-Modified"; + } + if (MaxForwards != null) + { + yield return "Max-Forwards"; + } + if (Pragma != null) + { + yield return "Pragma"; + } + if (ProxyAuthorization != null) + { + yield return "Proxy-Authorization"; + } + if (Range != null) + { + yield return "Range"; + } + if (Referer != null) + { + yield return "Referer"; + } + if (Te != null) + { + yield return "Te"; + } + if (Trailer != null) + { + yield return "Trailer"; + } + if (TransferEncoding != null) + { + yield return "Transfer-Encoding"; + } + if (Translate != null) + { + yield return "Translate"; + } + if (Upgrade != null) + { + yield return "Upgrade"; + } + if (UserAgent != null) + { + yield return "User-Agent"; + } + if (Via != null) + { + yield return "Via"; + } + if (Warning != null) + { + yield return "Warning"; + } + } + + private IEnumerable PropertiesValues() + { + if (Accept != null) + { + yield return Accept; + } + if (AcceptCharset != null) + { + yield return AcceptCharset; + } + if (AcceptEncoding != null) + { + yield return AcceptEncoding; + } + if (AcceptLanguage != null) + { + yield return AcceptLanguage; + } + if (Allow != null) + { + yield return Allow; + } + if (Authorization != null) + { + yield return Authorization; + } + if (CacheControl != null) + { + yield return CacheControl; + } + if (Connection != null) + { + yield return Connection; + } + if (ContentEncoding != null) + { + yield return ContentEncoding; + } + if (ContentLanguage != null) + { + yield return ContentLanguage; + } + if (ContentLength != null) + { + yield return ContentLength; + } + if (ContentLocation != null) + { + yield return ContentLocation; + } + if (ContentMd5 != null) + { + yield return ContentMd5; + } + if (ContentRange != null) + { + yield return ContentRange; + } + if (ContentType != null) + { + yield return ContentType; + } + if (Cookie != null) + { + yield return Cookie; + } + if (Date != null) + { + yield return Date; + } + if (Expect != null) + { + yield return Expect; + } + if (Expires != null) + { + yield return Expires; + } + if (From != null) + { + yield return From; + } + if (Host != null) + { + yield return Host; + } + if (IfMatch != null) + { + yield return IfMatch; + } + if (IfModifiedSince != null) + { + yield return IfModifiedSince; + } + if (IfNoneMatch != null) + { + yield return IfNoneMatch; + } + if (IfRange != null) + { + yield return IfRange; + } + if (IfUnmodifiedSince != null) + { + yield return IfUnmodifiedSince; + } + if (KeepAlive != null) + { + yield return KeepAlive; + } + if (LastModified != null) + { + yield return LastModified; + } + if (MaxForwards != null) + { + yield return MaxForwards; + } + if (Pragma != null) + { + yield return Pragma; + } + if (ProxyAuthorization != null) + { + yield return ProxyAuthorization; + } + if (Range != null) + { + yield return Range; + } + if (Referer != null) + { + yield return Referer; + } + if (Te != null) + { + yield return Te; + } + if (Trailer != null) + { + yield return Trailer; + } + if (TransferEncoding != null) + { + yield return TransferEncoding; + } + if (Translate != null) + { + yield return Translate; + } + if (Upgrade != null) + { + yield return Upgrade; + } + if (UserAgent != null) + { + yield return UserAgent; + } + if (Via != null) + { + yield return Via; + } + if (Warning != null) + { + yield return Warning; + } + } + + private IEnumerable> PropertiesEnumerable() + { + if (Accept != null) + { + yield return new KeyValuePair("Accept", Accept); + } + if (AcceptCharset != null) + { + yield return new KeyValuePair("Accept-Charset", AcceptCharset); + } + if (AcceptEncoding != null) + { + yield return new KeyValuePair("Accept-Encoding", AcceptEncoding); + } + if (AcceptLanguage != null) + { + yield return new KeyValuePair("Accept-Language", AcceptLanguage); + } + if (Allow != null) + { + yield return new KeyValuePair("Allow", Allow); + } + if (Authorization != null) + { + yield return new KeyValuePair("Authorization", Authorization); + } + if (CacheControl != null) + { + yield return new KeyValuePair("Cache-Control", CacheControl); + } + if (Connection != null) + { + yield return new KeyValuePair("Connection", Connection); + } + if (ContentEncoding != null) + { + yield return new KeyValuePair("Content-Encoding", ContentEncoding); + } + if (ContentLanguage != null) + { + yield return new KeyValuePair("Content-Language", ContentLanguage); + } + if (ContentLength != null) + { + yield return new KeyValuePair("Content-Length", ContentLength); + } + if (ContentLocation != null) + { + yield return new KeyValuePair("Content-Location", ContentLocation); + } + if (ContentMd5 != null) + { + yield return new KeyValuePair("Content-Md5", ContentMd5); + } + if (ContentRange != null) + { + yield return new KeyValuePair("Content-Range", ContentRange); + } + if (ContentType != null) + { + yield return new KeyValuePair("Content-Type", ContentType); + } + if (Cookie != null) + { + yield return new KeyValuePair("Cookie", Cookie); + } + if (Date != null) + { + yield return new KeyValuePair("Date", Date); + } + if (Expect != null) + { + yield return new KeyValuePair("Expect", Expect); + } + if (Expires != null) + { + yield return new KeyValuePair("Expires", Expires); + } + if (From != null) + { + yield return new KeyValuePair("From", From); + } + if (Host != null) + { + yield return new KeyValuePair("Host", Host); + } + if (IfMatch != null) + { + yield return new KeyValuePair("If-Match", IfMatch); + } + if (IfModifiedSince != null) + { + yield return new KeyValuePair("If-Modified-Since", IfModifiedSince); + } + if (IfNoneMatch != null) + { + yield return new KeyValuePair("If-None-Match", IfNoneMatch); + } + if (IfRange != null) + { + yield return new KeyValuePair("If-Range", IfRange); + } + if (IfUnmodifiedSince != null) + { + yield return new KeyValuePair("If-Unmodified-Since", IfUnmodifiedSince); + } + if (KeepAlive != null) + { + yield return new KeyValuePair("Keep-Alive", KeepAlive); + } + if (LastModified != null) + { + yield return new KeyValuePair("Last-Modified", LastModified); + } + if (MaxForwards != null) + { + yield return new KeyValuePair("Max-Forwards", MaxForwards); + } + if (Pragma != null) + { + yield return new KeyValuePair("Pragma", Pragma); + } + if (ProxyAuthorization != null) + { + yield return new KeyValuePair("Proxy-Authorization", ProxyAuthorization); + } + if (Range != null) + { + yield return new KeyValuePair("Range", Range); + } + if (Referer != null) + { + yield return new KeyValuePair("Referer", Referer); + } + if (Te != null) + { + yield return new KeyValuePair("Te", Te); + } + if (Trailer != null) + { + yield return new KeyValuePair("Trailer", Trailer); + } + if (TransferEncoding != null) + { + yield return new KeyValuePair("Transfer-Encoding", TransferEncoding); + } + if (Translate != null) + { + yield return new KeyValuePair("Translate", Translate); + } + if (Upgrade != null) + { + yield return new KeyValuePair("Upgrade", Upgrade); + } + if (UserAgent != null) + { + yield return new KeyValuePair("User-Agent", UserAgent); + } + if (Via != null) + { + yield return new KeyValuePair("Via", Via); + } + if (Warning != null) + { + yield return new KeyValuePair("Warning", Warning); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.tt b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.tt new file mode 100644 index 0000000000..2b339e4640 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.tt @@ -0,0 +1,218 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core.dll" #> +<#@ import namespace="System.Linq" #> +<# +var props = new[] +{ + new { Key = "Accept", Name = "Accept", ID = "HttpSysRequestHeader.Accept" }, + new { Key = "Accept-Charset", Name = "AcceptCharset", ID = "HttpSysRequestHeader.AcceptCharset" }, + new { Key = "Accept-Encoding", Name = "AcceptEncoding", ID = "HttpSysRequestHeader.AcceptEncoding" }, + new { Key = "Accept-Language", Name = "AcceptLanguage", ID = "HttpSysRequestHeader.AcceptLanguage" }, + new { Key = "Allow", Name = "Allow", ID = "HttpSysRequestHeader.Allow" }, + new { Key = "Authorization", Name = "Authorization", ID = "HttpSysRequestHeader.Authorization" }, + new { Key = "Cache-Control", Name = "CacheControl", ID = "HttpSysRequestHeader.CacheControl" }, + new { Key = "Connection", Name = "Connection", ID = "HttpSysRequestHeader.Connection" }, + new { Key = "Content-Encoding", Name = "ContentEncoding", ID = "HttpSysRequestHeader.ContentEncoding" }, + new { Key = "Content-Language", Name = "ContentLanguage", ID = "HttpSysRequestHeader.ContentLanguage" }, + new { Key = "Content-Length", Name = "ContentLength", ID = "HttpSysRequestHeader.ContentLength" }, + new { Key = "Content-Location", Name = "ContentLocation", ID = "HttpSysRequestHeader.ContentLocation" }, + new { Key = "Content-Md5", Name = "ContentMd5", ID = "HttpSysRequestHeader.ContentMd5" }, + new { Key = "Content-Range", Name = "ContentRange", ID = "HttpSysRequestHeader.ContentRange" }, + new { Key = "Content-Type", Name = "ContentType", ID = "HttpSysRequestHeader.ContentType" }, + new { Key = "Cookie", Name = "Cookie", ID = "HttpSysRequestHeader.Cookie" }, + new { Key = "Date", Name = "Date", ID = "HttpSysRequestHeader.Date" }, + new { Key = "Expect", Name = "Expect", ID = "HttpSysRequestHeader.Expect" }, + new { Key = "Expires", Name = "Expires", ID = "HttpSysRequestHeader.Expires" }, + new { Key = "From", Name = "From", ID = "HttpSysRequestHeader.From" }, + new { Key = "Host", Name = "Host", ID = "HttpSysRequestHeader.Host" }, + new { Key = "If-Match", Name = "IfMatch", ID = "HttpSysRequestHeader.IfMatch" }, + new { Key = "If-Modified-Since", Name = "IfModifiedSince", ID = "HttpSysRequestHeader.IfModifiedSince" }, + new { Key = "If-None-Match", Name = "IfNoneMatch", ID = "HttpSysRequestHeader.IfNoneMatch" }, + new { Key = "If-Range", Name = "IfRange", ID = "HttpSysRequestHeader.IfRange" }, + new { Key = "If-Unmodified-Since", Name = "IfUnmodifiedSince", ID = "HttpSysRequestHeader.IfUnmodifiedSince" }, + new { Key = "Keep-Alive", Name = "KeepAlive", ID = "HttpSysRequestHeader.KeepAlive" }, + new { Key = "Last-Modified", Name = "LastModified", ID = "HttpSysRequestHeader.LastModified" }, + new { Key = "Max-Forwards", Name = "MaxForwards", ID = "HttpSysRequestHeader.MaxForwards" }, + new { Key = "Pragma", Name = "Pragma", ID = "HttpSysRequestHeader.Pragma" }, + new { Key = "Proxy-Authorization", Name = "ProxyAuthorization", ID = "HttpSysRequestHeader.ProxyAuthorization" }, + new { Key = "Range", Name = "Range", ID = "HttpSysRequestHeader.Range" }, + new { Key = "Referer", Name = "Referer", ID = "HttpSysRequestHeader.Referer" }, + new { Key = "Te", Name = "Te", ID = "HttpSysRequestHeader.Te" }, + new { Key = "Trailer", Name = "Trailer", ID = "HttpSysRequestHeader.Trailer" }, + new { Key = "Transfer-Encoding", Name = "TransferEncoding", ID = "HttpSysRequestHeader.TransferEncoding" }, + new { Key = "Translate", Name = "Translate", ID = "HttpSysRequestHeader.Translate" }, + new { Key = "Upgrade", Name = "Upgrade", ID = "HttpSysRequestHeader.Upgrade" }, + new { Key = "User-Agent", Name = "UserAgent", ID = "HttpSysRequestHeader.UserAgent" }, + new { Key = "Via", Name = "Via", ID = "HttpSysRequestHeader.Via" }, + new { Key = "Warning", Name = "Warning", ID = "HttpSysRequestHeader.Warning" }, +}.Select((prop, Index)=>new {prop.Key, prop.Name, prop.ID, Index}); + +var lengths = props.GroupBy(prop=>prop.Key.Length).OrderBy(prop=>prop.Key); + + +Func IsRead = Index => "((_flag" + (Index / 32) + " & 0x" + (1<<(Index % 32)).ToString("x") + "u) != 0)"; +Func MarkRead = Index => "_flag" + (Index / 32) + " |= 0x" + (1<<(Index % 32)).ToString("x") + "u"; +Func Clear = Index => "_flag" + (Index / 32) + " &= ~0x" + (1<<(Index % 32)).ToString("x") + "u"; +#> +//----------------------------------------------------------------------- +// +// Copyright (c) Katana Contributors. All rights reserved. +// +//----------------------------------------------------------------------- +// + +using System; +using System.CodeDom.Compiler; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Owin.Host.WebListener +{ + [GeneratedCode("TextTemplatingFileGenerator", "")] + internal partial class RequestHeaders + { + // Tracks if individual fields have been read from native or set directly. + // Once read or set, their presence in the collection is marked by if their string[] is null or not. + private UInt32 _flag0, _flag1; + +<# foreach(var prop in props) { #> + private string[] _<#=prop.Name#>; +<# } #> + +<# foreach(var prop in props) { #> + internal string[] <#=prop.Name#> + { + get + { + if (!<#=IsRead(prop.Index)#>) + { + string nativeValue = GetKnownHeader(<#=prop.ID#>); + if (nativeValue != null) + { + _<#=prop.Name#> = new string[] { nativeValue }; + } + <#=MarkRead(prop.Index)#>; + } + return _<#=prop.Name#>; + } + set + { + <#=MarkRead(prop.Index)#>; + _<#=prop.Name#> = value; + } + } + +<# } #> + private bool PropertiesContainsKey(string key) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (string.Equals(key, "<#=prop.Key#>", StringComparison.OrdinalIgnoreCase)) + { + return <#=prop.Name#> != null; + } +<# } #> + break; +<# } #> + } + return false; + } + + private bool PropertiesTryGetValue(string key, out string[] value) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (string.Equals(key, "<#=prop.Key#>", StringComparison.OrdinalIgnoreCase)) + { + value = <#=prop.Name#>; + return value != null; + } +<# } #> + break; +<# } #> + } + value = null; + return false; + } + + private bool PropertiesTrySetValue(string key, string[] value) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (string.Equals(key, "<#=prop.Key#>", StringComparison.OrdinalIgnoreCase)) + { + <#=MarkRead(prop.Index)#>; + <#=prop.Name#> = value; + return true; + } +<# } #> + break; +<# } #> + } + return false; + } + + private bool PropertiesTryRemove(string key) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (_<#=prop.Name#> != null + && string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { + bool wasSet = <#=IsRead(prop.Index)#>; + <#=prop.Name#> = null; + return wasSet; + } +<# } #> + break; +<# } #> + } + return false; + } + + private IEnumerable PropertiesKeys() + { +<# foreach(var prop in props) { #> + if (<#=prop.Name#> != null) + { + yield return "<#=prop.Key#>"; + } +<# } #> + } + + private IEnumerable PropertiesValues() + { +<# foreach(var prop in props) { #> + if (<#=prop.Name#> != null) + { + yield return <#=prop.Name#>; + } +<# } #> + } + + private IEnumerable> PropertiesEnumerable() + { +<# foreach(var prop in props) { #> + if (<#=prop.Name#> != null) + { + yield return new KeyValuePair("<#=prop.Key#>", <#=prop.Name#>); + } +<# } #> + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.cs new file mode 100644 index 0000000000..9e43ab86e5 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.cs @@ -0,0 +1,167 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- +// Copyright 2011-2012 Katana contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal partial class RequestHeaders : IDictionary + { + private IDictionary _extra; + private NativeRequestContext _requestMemoryBlob; + + internal RequestHeaders(NativeRequestContext requestMemoryBlob) + { + _requestMemoryBlob = requestMemoryBlob; + } + + private IDictionary Extra + { + get + { + if (_extra == null) + { + var newDict = new Dictionary(StringComparer.OrdinalIgnoreCase); + GetUnknownHeaders(newDict); + Interlocked.CompareExchange(ref _extra, newDict, null); + } + return _extra; + } + } + + string[] IDictionary.this[string key] + { + get + { + string[] value; + return PropertiesTryGetValue(key, out value) ? value : Extra[key]; + } + set + { + if (!PropertiesTrySetValue(key, value)) + { + Extra[key] = value; + } + } + } + + private string GetKnownHeader(HttpSysRequestHeader header) + { + return UnsafeNclNativeMethods.HttpApi.GetKnownHeader(_requestMemoryBlob.RequestBuffer, + _requestMemoryBlob.OriginalBlobAddress, (int)header); + } + + private void GetUnknownHeaders(IDictionary extra) + { + UnsafeNclNativeMethods.HttpApi.GetUnknownHeaders(extra, _requestMemoryBlob.RequestBuffer, + _requestMemoryBlob.OriginalBlobAddress); + } + + void IDictionary.Add(string key, string[] value) + { + if (!PropertiesTrySetValue(key, value)) + { + Extra.Add(key, value); + } + } + + bool IDictionary.ContainsKey(string key) + { + return PropertiesContainsKey(key) || Extra.ContainsKey(key); + } + + ICollection IDictionary.Keys + { + get { return PropertiesKeys().Concat(Extra.Keys).ToArray(); } + } + + bool IDictionary.Remove(string key) + { + // Although this is a mutating operation, Extra is used instead of StrongExtra, + // because if a real dictionary has not been allocated the default behavior of the + // nil dictionary is perfectly fine. + return PropertiesTryRemove(key) || Extra.Remove(key); + } + + bool IDictionary.TryGetValue(string key, out string[] value) + { + return PropertiesTryGetValue(key, out value) || Extra.TryGetValue(key, out value); + } + + ICollection IDictionary.Values + { + get { return PropertiesValues().Concat(Extra.Values).ToArray(); } + } + + void ICollection>.Add(KeyValuePair item) + { + ((IDictionary)this).Add(item.Key, item.Value); + } + + void ICollection>.Clear() + { + foreach (var key in PropertiesKeys()) + { + PropertiesTryRemove(key); + } + Extra.Clear(); + } + + bool ICollection>.Contains(KeyValuePair item) + { + object value; + return ((IDictionary)this).TryGetValue(item.Key, out value) && Object.Equals(value, item.Value); + } + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + PropertiesEnumerable().Concat(Extra).ToArray().CopyTo(array, arrayIndex); + } + + int ICollection>.Count + { + get { return PropertiesKeys().Count() + Extra.Count; } + } + + bool ICollection>.IsReadOnly + { + get { return false; } + } + + bool ICollection>.Remove(KeyValuePair item) + { + return ((IDictionary)this).Contains(item) && + ((IDictionary)this).Remove(item.Key); + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return PropertiesEnumerable().Concat(Extra).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IDictionary)this).GetEnumerator(); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestStream.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestStream.cs new file mode 100644 index 0000000000..5c73f0d41c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestStream.cs @@ -0,0 +1,614 @@ +// ------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class RequestStream : Stream + { + private const int MaxReadSize = 0x20000; // http.sys recommends we limit reads to 128k + + private RequestContext _requestContext; + private uint _dataChunkOffset; + private int _dataChunkIndex; + private bool _closed; + + internal RequestStream(RequestContext httpContext) + { + _requestContext = httpContext; + } + + public override bool CanSeek + { + get + { + return false; + } + } + + public override bool CanWrite + { + get + { + return false; + } + } + + public override bool CanRead + { + get + { + return true; + } + } + + public override long Length + { + get + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + set + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override void Flush() + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + + public override unsafe int Read([In, Out] byte[] buffer, int offset, int size) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (size == 0 || _closed) + { + // TODO: zero sized buffer should be invalid. + return 0; + } + // TODO: Verbose log parameters + + uint dataRead = 0; + + if (_dataChunkIndex != -1) + { + dataRead = UnsafeNclNativeMethods.HttpApi.GetChunks(_requestContext.Request.RequestBuffer, _requestContext.Request.OriginalBlobAddress, ref _dataChunkIndex, ref _dataChunkOffset, buffer, offset, size); + } + + if (_dataChunkIndex == -1 && dataRead < size) + { + uint statusCode = 0; + uint extraDataRead = 0; + offset += (int)dataRead; + size -= (int)dataRead; + + // the http.sys team recommends that we limit the size to 128kb + if (size > MaxReadSize) + { + size = MaxReadSize; + } + + fixed (byte* pBuffer = buffer) + { + // issue unmanaged blocking call + + uint flags = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveRequestEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + flags, + (IntPtr)(pBuffer + offset), + (uint)size, + out extraDataRead, + SafeNativeOverlapped.Zero); + + dataRead += extraDataRead; + } + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "Read", exception); + throw exception; + } + UpdateAfterRead(statusCode, dataRead); + } + + // TODO: Verbose log dump data read + return (int)dataRead; + } + + private void UpdateAfterRead(uint statusCode, uint dataRead) + { + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF || dataRead == 0) + { + Dispose(); + } + } + +#if NET45 + public override unsafe IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#else + public unsafe IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#endif + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (size == 0 || _closed) + { + RequestStreamAsyncResult result = new RequestStreamAsyncResult(this, state, callback); + result.Complete(0); + return result; + } + // TODO: Verbose log parameters + + RequestStreamAsyncResult asyncResult = null; + + uint dataRead = 0; + if (_dataChunkIndex != -1) + { + dataRead = UnsafeNclNativeMethods.HttpApi.GetChunks(_requestContext.Request.RequestBuffer, _requestContext.Request.OriginalBlobAddress, ref _dataChunkIndex, ref _dataChunkOffset, buffer, offset, size); + if (_dataChunkIndex != -1 && dataRead == size) + { + asyncResult = new RequestStreamAsyncResult(this, state, callback, buffer, offset, 0); + asyncResult.Complete((int)dataRead); + } + } + + if (_dataChunkIndex == -1 && dataRead < size) + { + uint statusCode = 0; + offset += (int)dataRead; + size -= (int)dataRead; + + // the http.sys team recommends that we limit the size to 128kb + if (size > MaxReadSize) + { + size = MaxReadSize; + } + + asyncResult = new RequestStreamAsyncResult(this, state, callback, buffer, offset, dataRead); + uint bytesReturned; + + try + { + fixed (byte* pBuffer = buffer) + { + uint flags = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveRequestEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + flags, + asyncResult.PinnedBuffer, + (uint)size, + out bytesReturned, + asyncResult.NativeOverlapped); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "BeginRead", e); + asyncResult.Dispose(); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + asyncResult = new RequestStreamAsyncResult(this, state, callback, dataRead); + asyncResult.Complete((int)bytesReturned); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "BeginRead", exception); + throw exception; + } + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode, bytesReturned); + } + } + return asyncResult; + } + +#if NET45 + public override int EndRead(IAsyncResult asyncResult) +#else + public int EndRead(IAsyncResult asyncResult) +#endif + { + if (asyncResult == null) + { + throw new ArgumentNullException("asyncResult"); + } + RequestStreamAsyncResult castedAsyncResult = asyncResult as RequestStreamAsyncResult; + if (castedAsyncResult == null || castedAsyncResult.RequestStream != this) + { + throw new ArgumentException(Resources.Exception_WrongIAsyncResult, "asyncResult"); + } + if (castedAsyncResult.EndCalled) + { + throw new InvalidOperationException(Resources.Exception_EndCalledMultipleTimes); + } + castedAsyncResult.EndCalled = true; + // wait & then check for errors + // Throws on failure + int dataRead = castedAsyncResult.Task.Result; + // TODO: Verbose log #dataRead. + return dataRead; + } + + public override unsafe Task ReadAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + if (size == 0) + { + return Task.FromResult(0); + } + // TODO: Needs full cancellation integration + cancellationToken.ThrowIfCancellationRequested(); + // TODO: Verbose log parameters + + RequestStreamAsyncResult asyncResult = null; + + uint dataRead = 0; + if (_dataChunkIndex != -1) + { + dataRead = UnsafeNclNativeMethods.HttpApi.GetChunks(_requestContext.Request.RequestBuffer, _requestContext.Request.OriginalBlobAddress, ref _dataChunkIndex, ref _dataChunkOffset, buffer, offset, size); + if (_dataChunkIndex != -1 && dataRead == size) + { + UpdateAfterRead(UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS, dataRead); + // TODO: Verbose log #dataRead + return Task.FromResult((int)dataRead); + } + } + + if (_dataChunkIndex == -1 && dataRead < size) + { + uint statusCode = 0; + offset += (int)dataRead; + size -= (int)dataRead; + + // the http.sys team recommends that we limit the size to 128kb + if (size > MaxReadSize) + { + size = MaxReadSize; + } + + asyncResult = new RequestStreamAsyncResult(this, null, null, buffer, offset, dataRead); + uint bytesReturned; + + try + { + uint flags = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveRequestEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + flags, + asyncResult.PinnedBuffer, + (uint)size, + out bytesReturned, + asyncResult.NativeOverlapped); + } + catch (Exception e) + { + asyncResult.Dispose(); + LogHelper.LogException(_requestContext.Logger, "ReadAsync", e); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + uint totalRead = dataRead + bytesReturned; + UpdateAfterRead(statusCode, totalRead); + // TODO: Verbose log totalRead + return Task.FromResult((int)totalRead); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "ReadAsync", exception); + throw exception; + } + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.Dispose(); + uint totalRead = dataRead + bytesReturned; + UpdateAfterRead(statusCode, totalRead); + // TODO: Verbose log + return Task.FromResult((int)totalRead); + } + } + return asyncResult.Task; + } + + public override void Write(byte[] buffer, int offset, int size) + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + +#if NET45 + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#else + public IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#endif + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + +#if NET45 + public override void EndWrite(IAsyncResult asyncResult) +#else + public void EndWrite(IAsyncResult asyncResult) +#endif + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + + protected override void Dispose(bool disposing) + { + try + { + _closed = true; + } + finally + { + base.Dispose(disposing); + } + } + + private unsafe class RequestStreamAsyncResult : IAsyncResult, IDisposable + { + private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(Callback); + + private SafeNativeOverlapped _overlapped; + private IntPtr _pinnedBuffer; + private uint _dataAlreadyRead = 0; + private TaskCompletionSource _tcs; + private RequestStream _requestStream; + private AsyncCallback _callback; + + internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback) + { + _requestStream = requestStream; + _tcs = new TaskCompletionSource(userState); + _callback = callback; + } + + internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback, uint dataAlreadyRead) + : this(requestStream, userState, callback) + { + _dataAlreadyRead = dataAlreadyRead; + } + + internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback, byte[] buffer, int offset, uint dataAlreadyRead) + : this(requestStream, userState, callback) + { + _dataAlreadyRead = dataAlreadyRead; + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = this; + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, buffer)); + _pinnedBuffer = (Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset)); + } + + internal RequestStream RequestStream + { + get { return _requestStream; } + } + + internal SafeNativeOverlapped NativeOverlapped + { + get { return _overlapped; } + } + + internal IntPtr PinnedBuffer + { + get { return _pinnedBuffer; } + } + + internal uint DataAlreadyRead + { + get { return _dataAlreadyRead; } + } + + internal Task Task + { + get { return _tcs.Task; } + } + + internal bool EndCalled { get; set; } + + internal void IOCompleted(uint errorCode, uint numBytes) + { + IOCompleted(this, errorCode, numBytes); + } + + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirecting to callback")] + private static void IOCompleted(RequestStreamAsyncResult asyncResult, uint errorCode, uint numBytes) + { + try + { + if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + asyncResult.Fail(new WebListenerException((int)errorCode)); + } + else + { + // TODO: Verbose log dump data read + asyncResult.Complete((int)numBytes, errorCode); + } + } + catch (Exception e) + { + asyncResult.Fail(e); + } + } + + private static unsafe void Callback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + Overlapped callbackOverlapped = Overlapped.Unpack(nativeOverlapped); + RequestStreamAsyncResult asyncResult = callbackOverlapped.AsyncResult as RequestStreamAsyncResult; + + IOCompleted(asyncResult, errorCode, numBytes); + } + + internal void Complete(int read, uint errorCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + if (_tcs.TrySetResult(read + (int)DataAlreadyRead)) + { + RequestStream.UpdateAfterRead((uint)errorCode, (uint)(read + DataAlreadyRead)); + if (_callback != null) + { + try + { + _callback(this); + } + catch (Exception) + { + // TODO: Exception handling? This may be an IO callback thread and throwing here could crash the app. + } + } + } + } + + internal void Fail(Exception ex) + { + if (_tcs.TrySetException(ex) && _callback != null) + { + try + { + _callback(this); + } + catch (Exception) + { + // TODO: Exception handling? This may be an IO callback thread and throwing here could crash the app. + // TODO: Log + } + } + } + + [SuppressMessage("Microsoft.Usage", "CA2216:DisposableTypesShouldDeclareFinalizer", Justification = "The disposable resource referenced does have a finalizer.")] + public void Dispose() + { + Dispose(true); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + if (_overlapped != null) + { + _overlapped.Dispose(); + } + } + } + + public object AsyncState + { + get { return _tcs.Task.AsyncState; } + } + + public WaitHandle AsyncWaitHandle + { + get { return ((IAsyncResult)_tcs.Task).AsyncWaitHandle; } + } + + public bool CompletedSynchronously + { + get { return ((IAsyncResult)_tcs.Task).CompletedSynchronously; } + } + + public bool IsCompleted + { + get { return _tcs.Task.IsCompleted; } + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestUriBuilder.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestUriBuilder.cs new file mode 100644 index 0000000000..697934fc93 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestUriBuilder.cs @@ -0,0 +1,569 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Text; + +namespace Microsoft.AspNet.Server.WebListener +{ + // We don't use the cooked URL because http.sys unescapes all percent-encoded values. However, + // we also can't just use the raw Uri, since http.sys supports not only Utf-8, but also ANSI/DBCS and + // Unicode code points. System.Uri only supports Utf-8. + // The purpose of this class is to convert all ANSI, DBCS, and Unicode code points into percent encoded + // Utf-8 characters. + internal sealed class RequestUriBuilder + { + private static readonly bool UseCookedRequestUrl; + private static readonly Encoding Utf8Encoding; + private static readonly Encoding AnsiEncoding; + + private readonly string _rawUri; + private readonly string _cookedUriScheme; + private readonly string _cookedUriHost; + private readonly string _cookedUriPath; + private readonly string _cookedUriQuery; + + // This field is used to build the final request Uri string from the Uri parts passed to the ctor. + private StringBuilder _requestUriString; + + // The raw path is parsed by looping through all characters from left to right. 'rawOctets' + // is used to store consecutive percent encoded octets as actual byte values: e.g. for path /pa%C3%84th%2F/ + // rawOctets will be set to { 0xC3, 0x84 } when we reach character 't' and it will be { 0x2F } when + // we reach the final '/'. I.e. after a sequence of percent encoded octets ends, we use rawOctets as + // input to the encoding and percent encode the resulting string into UTF-8 octets. + // + // When parsing ANSI (Latin 1) encoded path '/pa%C4th/', %C4 will be added to rawOctets and when + // we reach 't', the content of rawOctets { 0xC4 } will be fed into the ANSI encoding. The resulting + // string 'Ä' will be percent encoded into UTF-8 octets and appended to requestUriString. The final + // path will be '/pa%C3%84th/', where '%C3%84' is the UTF-8 percent encoded character 'Ä'. + private List _rawOctets; + private string _rawPath; + + // Holds the final request Uri. + private Uri _requestUri; + + static RequestUriBuilder() + { + // TODO: False triggers more detailed/correct parsing, but it's rather slow. + UseCookedRequestUrl = true; // SettingsSectionInternal.Section.HttpListenerUnescapeRequestUrl; + Utf8Encoding = new UTF8Encoding(false, true); +#if NET45 + AnsiEncoding = Encoding.GetEncoding(0, new EncoderExceptionFallback(), new DecoderExceptionFallback()); +#else + AnsiEncoding = Utf8Encoding; +#endif + } + + private RequestUriBuilder(string rawUri, string cookedUriScheme, string cookedUriHost, + string cookedUriPath, string cookedUriQuery) + { + Debug.Assert(!string.IsNullOrEmpty(rawUri), "Empty raw URL."); + Debug.Assert(!string.IsNullOrEmpty(cookedUriScheme), "Empty cooked URL scheme."); + Debug.Assert(!string.IsNullOrEmpty(cookedUriHost), "Empty cooked URL host."); + Debug.Assert(!string.IsNullOrEmpty(cookedUriPath), "Empty cooked URL path."); + + this._rawUri = rawUri; + this._cookedUriScheme = cookedUriScheme; + this._cookedUriHost = cookedUriHost; + this._cookedUriPath = AddSlashToAsteriskOnlyPath(cookedUriPath); + this._cookedUriQuery = cookedUriQuery ?? string.Empty; + } + + private RequestUriBuilder(string rawUri, string cookedUriPath) + { + Debug.Assert(!string.IsNullOrEmpty(rawUri), "Empty raw URL."); + Debug.Assert(!string.IsNullOrEmpty(cookedUriPath), "Empty cooked URL path."); + + this._rawUri = rawUri; + this._cookedUriScheme = string.Empty; + this._cookedUriHost = string.Empty; + this._cookedUriPath = AddSlashToAsteriskOnlyPath(cookedUriPath); + this._cookedUriQuery = string.Empty; + } + + private enum ParsingResult + { + Success, + InvalidString, + EncodingError + } + + private enum EncodingType + { + Primary, + Secondary + } + + public static Uri GetRequestUri(string rawUri, string cookedUriScheme, string cookedUriHost, + string cookedUriPath, string cookedUriQuery) + { + RequestUriBuilder builder = new RequestUriBuilder(rawUri, + cookedUriScheme, cookedUriHost, cookedUriPath, cookedUriQuery); + + return builder.Build(); + } + + private Uri Build() + { + // if the user enabled the "use raw Uri" setting in section, we'll use the raw + // path rather than the cooked path. + if (UseCookedRequestUrl) + { + // corresponds to pre-4.0 behavior: use the cooked URI. + BuildRequestUriUsingCookedPath(); + + if (_requestUri == null) + { + BuildRequestUriUsingRawPath(); + } + } + else + { + BuildRequestUriUsingRawPath(); + + if (_requestUri == null) + { + BuildRequestUriUsingCookedPath(); + } + } + + return _requestUri; + } + + // Process only the path. + internal static string GetRequestPath(string rawUri, string cookedUriPath) + { + RequestUriBuilder builder = new RequestUriBuilder(rawUri, cookedUriPath); + + return builder.GetPath(); + } + + private string GetPath() + { + if (UseCookedRequestUrl) + { + return _cookedUriPath; + } + + // Initialize 'rawPath' only if really needed; i.e. if we build the request Uri from the raw Uri. + _rawPath = GetPath(_rawUri); + + // If HTTP.sys only parses Utf-8, we can safely use the raw path: it must be a valid Utf-8 string. + if (!HttpSysSettings.EnableNonUtf8 || string.IsNullOrEmpty(_rawPath)) + { + if (string.IsNullOrEmpty(_rawPath)) + { + _rawPath = "/"; + } + return _rawPath; + } + + // Try to check the raw path using first the primary encoding (according to http.sys settings); + // if it fails try the secondary encoding. + _rawOctets = new List(); + _requestUriString = new StringBuilder(); + ParsingResult result = ParseRawPath(GetEncoding(EncodingType.Primary)); + if (result == ParsingResult.EncodingError) + { + _rawOctets = new List(); + _requestUriString = new StringBuilder(); + result = ParseRawPath(GetEncoding(EncodingType.Secondary)); + } + + if (result == ParsingResult.Success) + { + return _requestUriString.ToString(); + } + + // Fallback + return _cookedUriPath; + } + + private void BuildRequestUriUsingCookedPath() + { + bool isValid = Uri.TryCreate(_cookedUriScheme + Constants.SchemeDelimiter + _cookedUriHost + _cookedUriPath + + _cookedUriQuery, UriKind.Absolute, out _requestUri); + + // Creating a Uri from the cooked Uri should really always work: If not, we log at least. + if (!isValid) + { + LogWarning("BuildRequestUriUsingCookedPath", "Unable to create URI: " + _cookedUriScheme + Constants.SchemeDelimiter + + _cookedUriHost + _cookedUriPath + _cookedUriQuery); + } + } + + private void BuildRequestUriUsingRawPath() + { + bool isValid = false; + + // Initialize 'rawPath' only if really needed; i.e. if we build the request Uri from the raw Uri. + _rawPath = GetPath(_rawUri); + + // If HTTP.sys only parses Utf-8, we can safely use the raw path: it must be a valid Utf-8 string. + if (!HttpSysSettings.EnableNonUtf8 || string.IsNullOrEmpty(_rawPath)) + { + string path = _rawPath; + if (string.IsNullOrEmpty(path)) + { + path = "/"; + Debug.Assert(string.IsNullOrEmpty(_cookedUriQuery), + "Query is only allowed if there is a non-empty path. At least '/' path required."); + } + + isValid = Uri.TryCreate(_cookedUriScheme + Constants.SchemeDelimiter + _cookedUriHost + path + _cookedUriQuery, + UriKind.Absolute, out _requestUri); + } + else + { + // Try to check the raw path using first the primary encoding (according to http.sys settings); + // if it fails try the secondary encoding. + ParsingResult result = BuildRequestUriUsingRawPath(GetEncoding(EncodingType.Primary)); + if (result == ParsingResult.EncodingError) + { + Encoding secondaryEncoding = GetEncoding(EncodingType.Secondary); + result = BuildRequestUriUsingRawPath(secondaryEncoding); + } + isValid = (result == ParsingResult.Success) ? true : false; + } + + // Log that we weren't able to create a Uri from the raw string. + if (!isValid) + { + LogWarning("BuildRequestUriUsingRawPath", "Unable to create Uri: " + _cookedUriScheme + Constants.SchemeDelimiter + + _cookedUriHost + _rawPath + _cookedUriQuery); + } + } + + private static Encoding GetEncoding(EncodingType type) + { + Debug.Assert(HttpSysSettings.EnableNonUtf8, + "If 'EnableNonUtf8' is false we shouldn't require an encoding. It's always Utf-8."); + /* This is mucking up the profiler for some reason. + Debug.Assert((type == EncodingType.Primary) || (type == EncodingType.Secondary), + "Unknown 'EncodingType' value: " + type.ToString()); + */ + if (((type == EncodingType.Primary) && (!HttpSysSettings.FavorUtf8)) || + ((type == EncodingType.Secondary) && (HttpSysSettings.FavorUtf8))) + { + return AnsiEncoding; + } + else + { + return Utf8Encoding; + } + } + + private ParsingResult BuildRequestUriUsingRawPath(Encoding encoding) + { + Debug.Assert(encoding != null, "'encoding' must be assigned."); + Debug.Assert(!string.IsNullOrEmpty(_rawPath), "'rawPath' must have at least one character."); + + _rawOctets = new List(); + _requestUriString = new StringBuilder(); + _requestUriString.Append(_cookedUriScheme); + _requestUriString.Append(Constants.SchemeDelimiter); + _requestUriString.Append(_cookedUriHost); + + ParsingResult result = ParseRawPath(encoding); + if (result == ParsingResult.Success) + { + _requestUriString.Append(_cookedUriQuery); + + Debug.Assert(_rawOctets.Count == 0, + "Still raw octets left. They must be added to the result path."); + + if (!Uri.TryCreate(_requestUriString.ToString(), UriKind.Absolute, out _requestUri)) + { + // If we can't create a Uri from the string, this is an invalid string and it doesn't make + // sense to try another encoding. + result = ParsingResult.InvalidString; + } + } + + if (result != ParsingResult.Success) + { + LogWarning("BuildRequestUriUsingRawPath", "Can't convert the raw path: " + _rawPath + " Encoding: " + encoding.WebName); + } + + return result; + } + + private ParsingResult ParseRawPath(Encoding encoding) + { + Debug.Assert(encoding != null, "'encoding' must be assigned."); + + int index = 0; + char current = '\0'; + while (index < _rawPath.Length) + { + current = _rawPath[index]; + if (current == '%') + { + // Assert is enough, since http.sys accepted the request string already. This should never happen. + Debug.Assert(index + 2 < _rawPath.Length, "Expected >=2 characters after '%' (e.g. %2F)"); + + index++; + current = _rawPath[index]; + if (current == 'u' || current == 'U') + { + // We found "%u" which means, we have a Unicode code point of the form "%uXXXX". + Debug.Assert(index + 4 < _rawPath.Length, "Expected >=4 characters after '%u' (e.g. %u0062)"); + + // Decode the content of rawOctets into percent encoded UTF-8 characters and append them + // to requestUriString. + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + if (!AppendUnicodeCodePointValuePercentEncoded(_rawPath.Substring(index + 1, 4))) + { + return ParsingResult.InvalidString; + } + index += 5; + } + else + { + // We found '%', but not followed by 'u', i.e. we have a percent encoded octed: %XX + if (!AddPercentEncodedOctetToRawOctetsList(encoding, _rawPath.Substring(index, 2))) + { + return ParsingResult.InvalidString; + } + index += 2; + } + } + else + { + // We found a non-'%' character: decode the content of rawOctets into percent encoded + // UTF-8 characters and append it to the result. + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + // Append the current character to the result. + _requestUriString.Append(current); + index++; + } + } + + // if the raw path ends with a sequence of percent encoded octets, make sure those get added to the + // result (requestUriString). + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + + return ParsingResult.Success; + } + + private bool AppendUnicodeCodePointValuePercentEncoded(string codePoint) + { + // http.sys only supports %uXXXX (4 hex-digits), even though unicode code points could have up to + // 6 hex digits. Therefore we parse always 4 characters after %u and convert them to an int. + int codePointValue; + if (!int.TryParse(codePoint, NumberStyles.HexNumber, null, out codePointValue)) + { + LogWarning("AppendUnicodeCodePointValuePercentEncoded", "Can't convert code point: " + codePoint); + return false; + } + + string unicodeString = null; + try + { + unicodeString = char.ConvertFromUtf32(codePointValue); + AppendOctetsPercentEncoded(_requestUriString, Utf8Encoding.GetBytes(unicodeString)); + + return true; + } + catch (ArgumentOutOfRangeException) + { + LogWarning("AppendUnicodeCodePointValuePercentEncoded", "Can't convert code point: " + codePoint); + } + catch (EncoderFallbackException e) + { + // If utf8Encoding.GetBytes() fails + LogWarning("AppendUnicodeCodePointValuePercentEncoded", "Can't convert code point: " + unicodeString, e.Message); + } + + return false; + } + + private bool AddPercentEncodedOctetToRawOctetsList(Encoding encoding, string escapedCharacter) + { + byte encodedValue; + if (!byte.TryParse(escapedCharacter, NumberStyles.HexNumber, null, out encodedValue)) + { + LogWarning("AddPercentEncodedOctetToRawOctetsList", "Can't convert code point: " + escapedCharacter); + return false; + } + + _rawOctets.Add(encodedValue); + + return true; + } + + private bool EmptyDecodeAndAppendRawOctetsList(Encoding encoding) + { + if (_rawOctets.Count == 0) + { + return true; + } + + string decodedString = null; + try + { + // If the encoding can get a string out of the byte array, this is a valid string in the + // 'encoding' encoding. + byte[] bytes = _rawOctets.ToArray(); + decodedString = encoding.GetString(bytes, 0, bytes.Length); + + if (encoding == Utf8Encoding) + { + AppendOctetsPercentEncoded(_requestUriString, bytes); + } + else + { + AppendOctetsPercentEncoded(_requestUriString, Utf8Encoding.GetBytes(decodedString)); + } + + _rawOctets.Clear(); + + return true; + } + catch (DecoderFallbackException e) + { + LogWarning("EmptyDecodeAndAppendRawOctetsList", "Can't convert bytes: " + GetOctetsAsString(_rawOctets), e.Message); + } + catch (EncoderFallbackException e) + { + // If utf8Encoding.GetBytes() fails + LogWarning("EmptyDecodeAndAppendRawOctetsList", "Can't convert bytes: " + decodedString, e.Message); + } + + return false; + } + + private static void AppendOctetsPercentEncoded(StringBuilder target, IEnumerable octets) + { + foreach (byte octet in octets) + { + target.Append('%'); + target.Append(octet.ToString("X2", CultureInfo.InvariantCulture)); + } + } + + private static string GetOctetsAsString(IEnumerable octets) + { + StringBuilder octetString = new StringBuilder(); + + bool first = true; + foreach (byte octet in octets) + { + if (first) + { + first = false; + } + else + { + octetString.Append(" "); + } + octetString.Append(octet.ToString("X2", CultureInfo.InvariantCulture)); + } + + return octetString.ToString(); + } + + private static string GetPath(string uriString) + { + Debug.Assert(uriString != null, "uriString must not be null"); + Debug.Assert(uriString.Length > 0, "uriString must not be empty"); + + int pathStartIndex = 0; + + // Perf. improvement: nearly all strings are relative Uris. So just look if the + // string starts with '/'. If so, we have a relative Uri and the path starts at position 0. + // (http.sys already trimmed leading whitespaces) + if (uriString[0] != '/') + { + // We can't check against cookedUriScheme, since http.sys allows for request http://myserver/ to + // use a request line 'GET https://myserver/' (note http vs. https). Therefore check if the + // Uri starts with either http:// or https://. + int authorityStartIndex = 0; + if (uriString.StartsWith("http://", StringComparison.OrdinalIgnoreCase)) + { + authorityStartIndex = 7; + } + else if (uriString.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + authorityStartIndex = 8; + } + + if (authorityStartIndex > 0) + { + // we have an absolute Uri. Find out where the authority ends and the path begins. + // Note that Uris like "http://server?query=value/1/2" are invalid according to RFC2616 + // and http.sys behavior: If the Uri contains a query, there must be at least one '/' + // between the authority and the '?' character: It's safe to just look for the first + // '/' after the authority to determine the beginning of the path. + pathStartIndex = uriString.IndexOf('/', authorityStartIndex); + if (pathStartIndex == -1) + { + // e.g. for request lines like: 'GET http://myserver' (no final '/') + pathStartIndex = uriString.Length; + } + } + else + { + // RFC2616: Request-URI = "*" | absoluteURI | abs_path | authority + // 'authority' can only be used with CONNECT which is never received by HttpListener. + // I.e. if we don't have an absolute path (must start with '/') and we don't have + // an absolute Uri (must start with http:// or https://), then 'uriString' must be '*'. + Debug.Assert((uriString.Length == 1) && (uriString[0] == '*'), "Unknown request Uri string format; " + + "Request Uri string is not an absolute Uri, absolute path, or '*': " + uriString); + + // Should we ever get here, be consistent with 2.0/3.5 behavior: just add an initial + // slash to the string and treat it as a path: + uriString = "/" + uriString; + } + } + + // Find end of path: The path is terminated by + // - the first '?' character + // - the first '#' character: This is never the case here, since http.sys won't accept + // Uris containing fragments. Also, RFC2616 doesn't allow fragments in request Uris. + // - end of Uri string + int queryIndex = uriString.IndexOf('?'); + if (queryIndex == -1) + { + queryIndex = uriString.Length; + } + + // will always return a != null string. + return AddSlashToAsteriskOnlyPath(uriString.Substring(pathStartIndex, queryIndex - pathStartIndex)); + } + + private static string AddSlashToAsteriskOnlyPath(string path) + { + Debug.Assert(path != null, "'path' must not be null"); + + // If a request like "OPTIONS * HTTP/1.1" is sent to the listener, then the request Uri + // should be "http[s]://server[:port]/*" to be compatible with pre-4.0 behavior. + if ((path.Length == 1) && (path[0] == '*')) + { + return "/*"; + } + + return path; + } + + private void LogWarning(string methodName, string message, params object[] args) + { + // TODO: Verbose log + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Response.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Response.cs new file mode 100644 index 0000000000..d25af63ed7 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Response.cs @@ -0,0 +1,756 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed unsafe class Response : IDisposable + { + private ResponseState _responseState; + private IDictionary _headers; + private ResponseStream _responseStream; + private long _contentLength; + private BoundaryType _boundaryType; + private UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE _nativeResponse; + private IList, object>> _onSendingHeadersActions; + + private RequestContext _requestContext; + + internal Response(RequestContext httpContext) + { + // TODO: Verbose log + _requestContext = httpContext; + _nativeResponse = new UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE(); + _headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + _boundaryType = BoundaryType.None; + _nativeResponse.StatusCode = (ushort)HttpStatusCode.OK; + _nativeResponse.Version.MajorVersion = 1; + _nativeResponse.Version.MinorVersion = 1; + _responseState = ResponseState.Created; + _onSendingHeadersActions = new List, object>>(); + } + + private enum ResponseState + { + Created, + ComputedHeaders, + SentHeaders, + Closed, + } + + private RequestContext RequestContext + { + get + { + return _requestContext; + } + } + + private Request Request + { + get + { + return RequestContext.Request; + } + } + + internal Stream OutputStream + { + get + { + CheckDisposed(); + EnsureResponseStream(); + return _responseStream; + } + } + + internal int GetStatusCode() + { + int statusCode = _requestContext.Environment.ResponseStatusCode ?? 200; + if (statusCode <= 100 || statusCode > 999) + { + // TODO: Move this validation to the dictionary facade so it throws when the app sets it, rather than durring send? + throw new InvalidOperationException(string.Format(Resources.Exception_InvalidStatusCode, statusCode)); + } + return statusCode; + } + + internal string GetReasonPhrase(int statusCode) + { + // TODO: Validate user input for illegal chars, length limit, etc.? + string reasonPhrase = _requestContext.Environment.ResponseReasonPhrase; + if (string.IsNullOrWhiteSpace(reasonPhrase)) + { + // if the user hasn't set this, generated on the fly, if possible. + // We know this one is safe, no need to verify it as in the setter. + reasonPhrase = HttpReasonPhrase.Get(statusCode) ?? string.Empty; + } + return reasonPhrase; + } + + // We MUST NOT send message-body when we send responses with these Status codes + private static readonly int[] NoResponseBody = { 100, 101, 204, 205, 304 }; + + private static bool CanSendResponseBody(int responseCode) + { + for (int i = 0; i < NoResponseBody.Length; i++) + { + if (responseCode == NoResponseBody[i]) + { + return false; + } + } + return true; + } + + internal EntitySendFormat EntitySendFormat + { + get + { + return (EntitySendFormat)_boundaryType; + } + } + + internal IDictionary Headers + { + get + { + return _headers; + } + } + + internal long ContentLength64 + { + get + { + return _contentLength; + } + } + + private Version GetProtocolVersion() + { + Version requestVersion = Request.ProtocolVersion; + Version responseVersion = requestVersion; + string protocolVersion = RequestContext.Environment.Get(Constants.HttpResponseProtocolKey); + + // Optional + if (!string.IsNullOrWhiteSpace(protocolVersion)) + { + if (string.Equals("HTTP/1.1", protocolVersion, StringComparison.OrdinalIgnoreCase)) + { + responseVersion = Constants.V1_1; + } + if (string.Equals("HTTP/1.0", protocolVersion, StringComparison.OrdinalIgnoreCase)) + { + responseVersion = Constants.V1_0; + } + else + { + // TODO: Just log? It's too late to get this to user code. + throw new ArgumentException(string.Empty, Constants.HttpResponseProtocolKey); + } + } + + if (requestVersion == responseVersion) + { + return requestVersion; + } + + // Return the lesser of the two versions. There are only two, so it it will always be 1.0. + return Constants.V1_0; + } + + public void Dispose() + { + Dispose(true); + } + + private void Dispose(bool disposing) + { + if (disposing) + { + if (_responseState >= ResponseState.Closed) + { + return; + } + // TODO: Verbose log + EnsureResponseStream(); + _responseStream.Dispose(); + _responseState = ResponseState.Closed; + } + } + + // old API, now private, and helper methods + + internal BoundaryType BoundaryType + { + get + { + return _boundaryType; + } + } + + internal bool SentHeaders + { + get + { + return _responseState >= ResponseState.SentHeaders; + } + } + + internal bool ComputedHeaders + { + get + { + return _responseState >= ResponseState.ComputedHeaders; + } + } + + private void EnsureResponseStream() + { + if (_responseStream == null) + { + _responseStream = new ResponseStream(RequestContext); + } + } + + /* + 12.3 + HttpSendHttpResponse() and HttpSendResponseEntityBody() Flag Values. + The following flags can be used on calls to HttpSendHttpResponse() and HttpSendResponseEntityBody() API calls: + + #define HTTP_SEND_RESPONSE_FLAG_DISCONNECT 0x00000001 + #define HTTP_SEND_RESPONSE_FLAG_MORE_DATA 0x00000002 + #define HTTP_SEND_RESPONSE_FLAG_RAW_HEADER 0x00000004 + #define HTTP_SEND_RESPONSE_FLAG_VALID 0x00000007 + + HTTP_SEND_RESPONSE_FLAG_DISCONNECT: + specifies that the network connection should be disconnected immediately after + sending the response, overriding the HTTP protocol's persistent connection features. + HTTP_SEND_RESPONSE_FLAG_MORE_DATA: + specifies that additional entity body data will be sent by the caller. Thus, + the last call HttpSendResponseEntityBody for a RequestId, will have this flag reset. + HTTP_SEND_RESPONSE_RAW_HEADER: + specifies that a caller of HttpSendResponseEntityBody() is intentionally omitting + a call to HttpSendHttpResponse() in order to bypass normal header processing. The + actual HTTP header will be generated by the application and sent as entity body. + This flag should be passed on the first call to HttpSendResponseEntityBody, and + not after. Thus, flag is not applicable to HttpSendHttpResponse. + */ + + // TODO: Consider using HTTP_SEND_RESPONSE_RAW_HEADER with HttpSendResponseEntityBody instead of calling HttpSendHttpResponse. + // This will give us more control of the bytes that hit the wire, including encodings, HTTP 1.0, etc.. + // It may also be faster to do this work in managed code and then pass down only one buffer. + // What would we loose by bypassing HttpSendHttpResponse? + // + // TODO: Consider using the HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA flag for most/all responses rather than just Opaque. + internal unsafe uint SendHeaders(UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK* pDataChunk, + ResponseStreamAsyncResult asyncResult, + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags, + bool isOpaqueUpgrade) + { + Debug.Assert(!SentHeaders, "HttpListenerResponse::SendHeaders()|SentHeaders is true."); + + // TODO: Verbose log headers + _responseState = ResponseState.SentHeaders; + + _nativeResponse.StatusCode = (ushort)GetStatusCode(); + string reasonPhrase = GetReasonPhrase(_nativeResponse.StatusCode); + + /* + if (m_BoundaryType==BoundaryType.Raw) { + use HTTP_SEND_RESPONSE_FLAG_RAW_HEADER; + } + */ + uint statusCode; + uint bytesSent; + List pinnedHeaders = SerializeHeaders(ref _nativeResponse.Headers, isOpaqueUpgrade); + try + { + if (pDataChunk != null) + { + _nativeResponse.EntityChunkCount = 1; + _nativeResponse.pEntityChunks = pDataChunk; + } + else if (asyncResult != null && asyncResult.DataChunks != null) + { + _nativeResponse.EntityChunkCount = asyncResult.DataChunkCount; + _nativeResponse.pEntityChunks = asyncResult.DataChunks; + } + else + { + _nativeResponse.EntityChunkCount = 0; + _nativeResponse.pEntityChunks = null; + } + + if (reasonPhrase.Length > 0) + { + byte[] reasonPhraseBytes = new byte[HeaderEncoding.GetByteCount(reasonPhrase)]; + fixed (byte* pReasonPhrase = reasonPhraseBytes) + { + _nativeResponse.ReasonLength = (ushort)reasonPhraseBytes.Length; + HeaderEncoding.GetBytes(reasonPhrase, 0, reasonPhraseBytes.Length, reasonPhraseBytes, 0); + _nativeResponse.pReason = (sbyte*)pReasonPhrase; + fixed (UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE* pResponse = &_nativeResponse) + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendHttpResponse( + RequestContext.RequestQueueHandle, + Request.RequestId, + (uint)flags, + pResponse, + null, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult == null ? SafeNativeOverlapped.Zero : asyncResult.NativeOverlapped, + IntPtr.Zero); + + if (asyncResult != null && + statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + asyncResult.BytesSent = bytesSent; + // The caller will invoke IOCompleted + } + } + } + } + else + { + fixed (UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE* pResponse = &_nativeResponse) + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendHttpResponse( + RequestContext.RequestQueueHandle, + Request.RequestId, + (uint)flags, + pResponse, + null, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult == null ? SafeNativeOverlapped.Zero : asyncResult.NativeOverlapped, + IntPtr.Zero); + + if (asyncResult != null && + statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + asyncResult.BytesSent = bytesSent; + // The caller will invoke IOCompleted + } + } + } + } + finally + { + FreePinnedHeaders(pinnedHeaders); + } + return statusCode; + } + + internal UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS ComputeHeaders(bool endOfRequest = false) + { + // Notify that this is absolutely the last chance to make changes. + NotifyOnSendingHeaders(); + + // 401 + if (GetStatusCode() == (ushort)HttpStatusCode.Unauthorized) + { + RequestContext.Server.AuthenticationManager.SetAuthenticationChallenge(this); + } + + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + Debug.Assert(!ComputedHeaders, "HttpListenerResponse::ComputeHeaders()|ComputedHeaders is true."); + _responseState = ResponseState.ComputedHeaders; + /* + // here we would check for BoundaryType.Raw, in this case we wouldn't need to do anything + if (m_BoundaryType==BoundaryType.Raw) { + return flags; + } + */ + + // Check the response headers to determine the correct keep alive and boundary type. + Version responseVersion = GetProtocolVersion(); + _nativeResponse.Version.MajorVersion = (ushort)responseVersion.Major; + _nativeResponse.Version.MinorVersion = (ushort)responseVersion.Minor; + bool keepAlive = responseVersion >= Constants.V1_1; + string connectionString = Headers.Get(HttpKnownHeaderNames.Connection); + string keepAliveString = Headers.Get(HttpKnownHeaderNames.KeepAlive); + bool closeSet = false; + bool keepAliveSet = false; + + if (!string.IsNullOrWhiteSpace(connectionString) && string.Equals("close", connectionString.Trim(), StringComparison.OrdinalIgnoreCase)) + { + keepAlive = false; + closeSet = true; + } + else if (!string.IsNullOrWhiteSpace(keepAliveString) && string.Equals("true", keepAliveString.Trim(), StringComparison.OrdinalIgnoreCase)) + { + keepAlive = true; + keepAliveSet = true; + } + + // Content-Length takes priority + string contentLengthString = Headers.Get(HttpKnownHeaderNames.ContentLength); + string transferEncodingString = Headers.Get(HttpKnownHeaderNames.TransferEncoding); + + if (responseVersion == Constants.V1_0 && !string.IsNullOrEmpty(transferEncodingString) + && string.Equals("chunked", transferEncodingString.Trim(), StringComparison.OrdinalIgnoreCase)) + { + // A 1.0 client can't process chunked responses. + Headers.Remove(HttpKnownHeaderNames.TransferEncoding); + transferEncodingString = null; + } + + if (!string.IsNullOrWhiteSpace(contentLengthString)) + { + contentLengthString = contentLengthString.Trim(); + if (string.Equals("0", contentLengthString, StringComparison.Ordinal)) + { + _boundaryType = BoundaryType.ContentLength; + _contentLength = 0; + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + } + else if (long.TryParse(contentLengthString, NumberStyles.None, CultureInfo.InvariantCulture.NumberFormat, out _contentLength)) + { + _boundaryType = BoundaryType.ContentLength; + if (_contentLength == 0) + { + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + } + } + else + { + _boundaryType = BoundaryType.Invalid; + } + } + else if (!string.IsNullOrWhiteSpace(transferEncodingString) + && string.Equals("chunked", transferEncodingString.Trim(), StringComparison.OrdinalIgnoreCase)) + { + // Then Transfer-Encoding: chunked + _boundaryType = BoundaryType.Chunked; + } + else if (endOfRequest) + { + // The request is ending without a body, add a Content-Length: 0 header. + Headers[HttpKnownHeaderNames.ContentLength] = new string[] { "0" }; + _boundaryType = BoundaryType.ContentLength; + _contentLength = 0; + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + } + else + { + // Then fall back to Connection:Close transparent mode. + _boundaryType = BoundaryType.None; + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; // seems like HTTP_SEND_RESPONSE_FLAG_MORE_DATA but this hangs the app; + if (responseVersion == Constants.V1_0) + { + keepAlive = false; + } + else + { + Headers[HttpKnownHeaderNames.TransferEncoding] = new string[] { "chunked" }; + _boundaryType = BoundaryType.Chunked; + } + + if (CanSendResponseBody(_requestContext.Response.GetStatusCode())) + { + _contentLength = -1; + } + else + { + Headers[HttpKnownHeaderNames.ContentLength] = new string[] { "0" }; + _contentLength = 0; + _boundaryType = BoundaryType.ContentLength; + } + } + + // Also, Keep-Alive vs Connection Close + + if (!keepAlive) + { + if (!closeSet) + { + Headers.Append(HttpKnownHeaderNames.Connection, "close"); + } + if (flags == UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE) + { + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_DISCONNECT; + } + } + else + { + if (Request.ProtocolVersion.Minor == 0 && !keepAliveSet) + { + Headers[HttpKnownHeaderNames.KeepAlive] = new string[] { "true" }; + } + } + return flags; + } + + private List SerializeHeaders(ref UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE_HEADERS headers, + bool isOpaqueUpgrade) + { + UnsafeNclNativeMethods.HttpApi.HTTP_UNKNOWN_HEADER[] unknownHeaders = null; + List pinnedHeaders; + GCHandle gcHandle; + /* + // here we would check for BoundaryType.Raw, in this case we wouldn't need to do anything + if (m_BoundaryType==BoundaryType.Raw) { + return null; + } + */ + if (Headers.Count == 0) + { + return null; + } + string headerName; + string headerValue; + int lookup; + byte[] bytes = null; + pinnedHeaders = new List(); + + //--------------------------------------------------- + // DTS Issue: 609383: + // The Set-Cookie headers are being merged into one. + // There are two issues here. + // 1. When Set-Cookie headers are set through SetCookie method on the ListenerResponse, + // there is code in the SetCookie method and the methods it calls to flatten the Set-Cookie + // values. This blindly concatenates the cookies with a comma delimiter. There could be + // a cookie value that contains comma, but we don't escape it with %XX value + // + // As an alternative users can add the Set-Cookie header through the AddHeader method + // like ListenerResponse.Headers.Add("name", "value") + // That way they can add multiple headers - AND They can format the value like they want it. + // + // 2. Now that the header collection contains multiple Set-Cookie name, value pairs + // you would think the problem would go away. However here is an interesting thing. + // For NameValueCollection, when you add + // "Set-Cookie", "value1" + // "Set-Cookie", "value2" + // The NameValueCollection.Count == 1. Because there is only one key + // NameValueCollection.Get("Set-Cookie") would conveniently take these two values + // concatenate them with a comma like + // value1,value2. + // In order to get individual values, you need to use + // string[] values = NameValueCollection.GetValues("Set-Cookie"); + // + // ------------------------------------------------------------- + // So here is the proposed fix here. + // We must first to loop through all the NameValueCollection keys + // and if the name is a unknown header, we must compute the number of + // values it has. Then, we should allocate that many unknown header array + // elements. + // + // Note that a part of the fix here is to treat Set-Cookie as an unknown header + // + // + //----------------------------------------------------------- + int numUnknownHeaders = 0; + foreach (KeyValuePair headerPair in Headers) + { + // See if this is an unknown header + lookup = UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE_HEADER_ID.IndexOfKnownHeader(headerPair.Key); + + // TODO: WWW-Authentiate header + // TODO: HTTP_RESPONSE_V2 has a HTTP_MULTIPLE_KNOWN_HEADERS option where you can supply multiple values. + + // TODO: Consider any 'known' header that has multiple values as 'unknown'? + // Treat Set-Cookie as well as Connection header in opaque mode as unknown + if (lookup == (int)HttpSysResponseHeader.SetCookie || + (isOpaqueUpgrade && lookup == (int)HttpSysResponseHeader.Connection)) + { + lookup = -1; + } + + if (lookup == -1) + { + numUnknownHeaders += headerPair.Value.Length; + } + } + + try + { + fixed (UnsafeNclNativeMethods.HttpApi.HTTP_KNOWN_HEADER* pKnownHeaders = &headers.KnownHeaders) + { + foreach (KeyValuePair headerPair in Headers) + { + headerName = headerPair.Key; + lookup = UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE_HEADER_ID.IndexOfKnownHeader(headerName); + if (lookup == (int)HttpSysResponseHeader.SetCookie || + (isOpaqueUpgrade && lookup == (int)HttpSysResponseHeader.Connection)) + { + lookup = -1; + } + + if (lookup == -1) + { + if (unknownHeaders == null) + { + //---------------------------------------- + // *** This following comment is no longer true *** + // we waste some memory here (up to 32*41=1312 bytes) but we gain speed + // unknownHeaders = new UnsafeNclNativeMethods.HttpApi.HTTP_UNKNOWN_HEADER[Headers.Count-index]; + //-------------------------------------------- + unknownHeaders = new UnsafeNclNativeMethods.HttpApi.HTTP_UNKNOWN_HEADER[numUnknownHeaders]; + gcHandle = GCHandle.Alloc(unknownHeaders, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + headers.pUnknownHeaders = (UnsafeNclNativeMethods.HttpApi.HTTP_UNKNOWN_HEADER*)gcHandle.AddrOfPinnedObject(); + } + + //---------------------------------------- + // FOR UNKNOWN HEADERS + // ALLOW MULTIPLE HEADERS to be added + //--------------------------------------- + string[] headerValues = headerPair.Value; + for (int headerValueIndex = 0; headerValueIndex < headerValues.Length; headerValueIndex++) + { + // Add Name + bytes = new byte[HeaderEncoding.GetByteCount(headerName)]; + unknownHeaders[headers.UnknownHeaderCount].NameLength = (ushort)bytes.Length; + HeaderEncoding.GetBytes(headerName, 0, bytes.Length, bytes, 0); + gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + unknownHeaders[headers.UnknownHeaderCount].pName = (sbyte*)gcHandle.AddrOfPinnedObject(); + + // Add Value + headerValue = headerValues[headerValueIndex]; + bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + unknownHeaders[headers.UnknownHeaderCount].RawValueLength = (ushort)bytes.Length; + HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); + gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + unknownHeaders[headers.UnknownHeaderCount].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); + headers.UnknownHeaderCount++; + } + } + else + { + string[] headerValues = headerPair.Value; + headerValue = headerValues.Length == 1 ? headerValues[0] : string.Join(", ", headerValues); + if (headerValue != null) + { + bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + pKnownHeaders[lookup].RawValueLength = (ushort)bytes.Length; + HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); + gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + pKnownHeaders[lookup].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); + } + } + } + } + } + catch + { + FreePinnedHeaders(pinnedHeaders); + throw; + } + return pinnedHeaders; + } + + private static void FreePinnedHeaders(List pinnedHeaders) + { + if (pinnedHeaders != null) + { + foreach (GCHandle gcHandle in pinnedHeaders) + { + if (gcHandle.IsAllocated) + { + gcHandle.Free(); + } + } + } + } + + // Subset of ComputeHeaders + internal void SendOpaqueUpgrade() + { + // TODO: Should we do this notification earlier when you still have a chance to change the status code to avoid an upgrade? + // Notify that this is absolutely the last chance to make changes. + NotifyOnSendingHeaders(); + + // TODO: Send headers async? + ulong errorCode = SendHeaders(null, null, + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_OPAQUE | + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA | + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA, + true); + + if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + throw new WebListenerException((int)errorCode); + } + } + + private void CheckDisposed() + { + if (_responseState >= ResponseState.Closed) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + } + + internal void CancelLastWrite(SafeHandle requestQueueHandle) + { + if (_responseStream != null) + { + _responseStream.CancelLastWrite(requestQueueHandle); + } + } + + internal Task SendFileAsync(string fileName, long offset, long? count, CancellationToken cancel) + { + EnsureResponseStream(); + return _responseStream.SendFileAsync(fileName, offset, count, cancel); + } + + internal void SwitchToOpaqueMode() + { + EnsureResponseStream(); + _responseStream.SwitchToOpaqueMode(); + } + + internal void RegisterForOnSendingHeaders(Action callback, object state) + { + IList, object>> actions = _onSendingHeadersActions; + if (actions == null) + { + throw new InvalidOperationException("Headers already sent"); + } + + actions.Add(new Tuple, object>(callback, state)); + } + + private void NotifyOnSendingHeaders() + { + var actions = Interlocked.Exchange(ref _onSendingHeadersActions, null); + if (actions == null) + { + // Something threw the first time, do not try again. + return; + } + + // Execute last to first. This mimics a stack unwind. + foreach (var actionPair in actions.Reverse()) + { + actionPair.Item1(actionPair.Item2); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStream.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStream.cs new file mode 100644 index 0000000000..0177312901 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStream.cs @@ -0,0 +1,826 @@ +// ------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class ResponseStream : Stream + { + private static readonly byte[] ChunkTerminator = new byte[] { (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' }; + + private RequestContext _requestContext; + private long _leftToWrite = long.MinValue; + private bool _closed; + private bool _inOpaqueMode; + // The last write needs special handling to cancel. + private ResponseStreamAsyncResult _lastWrite; + + internal ResponseStream(RequestContext requestContext) + { + _requestContext = requestContext; + } + + public override bool CanSeek + { + get + { + return false; + } + } + + public override bool CanWrite + { + get + { + return true; + } + } + + public override bool CanRead + { + get + { + return false; + } + } + + public override long Length + { + get + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + set + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + } + + // Send headers + public override void Flush() + { + if (_closed || _requestContext.Response.SentHeaders) + { + return; + } + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + // TODO: Verbose log + + try + { + uint statusCode; + unsafe + { + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + statusCode = _requestContext.Response.SendHeaders(null, null, flags, false); + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + throw new WebListenerException((int)statusCode); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "Flush", e); + _closed = true; + _requestContext.Abort(); + throw; + } + } + + // Send headers + public override Task FlushAsync(CancellationToken cancellationToken) + { + if (_closed || _requestContext.Response.SentHeaders) + { + return Helpers.CompletedTask(); + } + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + // TODO: Verbose log + + // TODO: Real cancellation + cancellationToken.ThrowIfCancellationRequested(); + + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, null, 0, 0, _requestContext.Response.BoundaryType == BoundaryType.Chunked, false); + + try + { + uint statusCode; + unsafe + { + statusCode = _requestContext.Response.SendHeaders(null, asyncResult, flags, false); + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode); + } + else if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + throw new WebListenerException((int)statusCode); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "FlushAsync", e); + asyncResult.Dispose(); + _closed = true; + _requestContext.Abort(); + throw; + } + + return asyncResult.Task; + } + + #region NotSupported Read/Seek + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override int Read([In, Out] byte[] buffer, int offset, int size) + { + throw new InvalidOperationException(Resources.Exception_WriteOnlyStream); + } + +#if NET45 + public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) + { + throw new InvalidOperationException(Resources.Exception_WriteOnlyStream); + } + + public override int EndRead(IAsyncResult asyncResult) + { + throw new InvalidOperationException(Resources.Exception_WriteOnlyStream); + } +#endif + + #endregion + + private UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS ComputeLeftToWrite(bool endOfRequest = false) + { + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + if (!_requestContext.Response.ComputedHeaders) + { + flags = _requestContext.Response.ComputeHeaders(endOfRequest: endOfRequest); + } + if (_leftToWrite == long.MinValue) + { + UnsafeNclNativeMethods.HttpApi.HTTP_VERB method = _requestContext.GetKnownMethod(); + if (method == UnsafeNclNativeMethods.HttpApi.HTTP_VERB.HttpVerbHEAD) + { + _leftToWrite = 0; + } + else if (_requestContext.Response.EntitySendFormat == EntitySendFormat.ContentLength) + { + _leftToWrite = _requestContext.Response.ContentLength64; + } + else + { + _leftToWrite = -1; // unlimited + } + } + return flags; + } + + public override unsafe void Write(byte[] buffer, int offset, int size) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + // TODO: Verbose log parameters + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + if (size == 0 && _leftToWrite != 0) + { + return; + } + if (_leftToWrite >= 0 && size > _leftToWrite) + { + throw new InvalidOperationException(Resources.Exception_TooMuchWritten); + } + // TODO: Verbose log + + uint statusCode; + uint dataToWrite = (uint)size; + SafeLocalFree bufferAsIntPtr = null; + IntPtr pBufferAsIntPtr = IntPtr.Zero; + bool sentHeaders = _requestContext.Response.SentHeaders; + try + { + if (size == 0) + { + // TODO: Is this code path accessible? Is this like a Flush? + statusCode = _requestContext.Response.SendHeaders(null, null, flags, false); + } + else + { + fixed (byte* pDataBuffer = buffer) + { + byte* pBuffer = pDataBuffer; + if (_requestContext.Response.BoundaryType == BoundaryType.Chunked) + { + // TODO: + // here we need some heuristics, some time it is definitely better to split this in 3 write calls + // but for small writes it is probably good enough to just copy the data internally. + string chunkHeader = size.ToString("x", CultureInfo.InvariantCulture); + dataToWrite = dataToWrite + (uint)(chunkHeader.Length + 4); + bufferAsIntPtr = SafeLocalFree.LocalAlloc((int)dataToWrite); + pBufferAsIntPtr = bufferAsIntPtr.DangerousGetHandle(); + for (int i = 0; i < chunkHeader.Length; i++) + { + Marshal.WriteByte(pBufferAsIntPtr, i, (byte)chunkHeader[i]); + } + Marshal.WriteInt16(pBufferAsIntPtr, chunkHeader.Length, 0x0A0D); + Marshal.Copy(buffer, offset, IntPtrHelper.Add(pBufferAsIntPtr, chunkHeader.Length + 2), size); + Marshal.WriteInt16(pBufferAsIntPtr, (int)(dataToWrite - 2), 0x0A0D); + pBuffer = (byte*)pBufferAsIntPtr; + offset = 0; + } + UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK dataChunk = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + dataChunk.DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + dataChunk.fromMemory.pBuffer = (IntPtr)(pBuffer + offset); + dataChunk.fromMemory.BufferLength = dataToWrite; + + flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(&dataChunk, null, flags, false); + } + else + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + 1, + &dataChunk, + null, + SafeLocalFree.Zero, + 0, + SafeNativeOverlapped.Zero, + IntPtr.Zero); + + if (_requestContext.Server.IgnoreWriteExceptions) + { + statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + } + } + } + } + } + finally + { + if (bufferAsIntPtr != null) + { + // free unmanaged buffer + bufferAsIntPtr.Dispose(); + } + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "Write", exception); + _closed = true; + _requestContext.Abort(); + throw exception; + } + UpdateWritenCount(dataToWrite); + + // TODO: Verbose log data written + } +#if NET45 + public override unsafe IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#else + public unsafe IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#endif + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + if (size == 0 && _leftToWrite != 0) + { + ResponseStreamAsyncResult result = new ResponseStreamAsyncResult(this, state, callback); + result.Complete(); + return result; + } + if (_leftToWrite >= 0 && size > _leftToWrite) + { + throw new InvalidOperationException(Resources.Exception_TooMuchWritten); + } + // TODO: Verbose log parameters + + uint statusCode; + uint bytesSent = 0; + flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + bool sentHeaders = _requestContext.Response.SentHeaders; + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, state, callback, buffer, offset, size, _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders); + + // Update m_LeftToWrite now so we can queue up additional BeginWrite's without waiting for EndWrite. + UpdateWritenCount((uint)((_requestContext.Response.BoundaryType == BoundaryType.Chunked) ? 0 : size)); + + try + { + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(null, asyncResult, flags, false); + bytesSent = asyncResult.BytesSent; + } + else + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + asyncResult.DataChunkCount, + asyncResult.DataChunks, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult.NativeOverlapped, + IntPtr.Zero); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "BeginWrite", e); + asyncResult.Dispose(); + _closed = true; + _requestContext.Abort(); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (_requestContext.Server.IgnoreWriteExceptions && sentHeaders) + { + asyncResult.Complete(); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "BeginWrite", exception); + _closed = true; + _requestContext.Abort(); + throw exception; + } + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode, bytesSent); + } + + // Last write, cache it for special cancelation handling. + if ((flags & UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA) == 0) + { + _lastWrite = asyncResult; + } + + return asyncResult; + } +#if NET45 + public override void EndWrite(IAsyncResult asyncResult) +#else + public void EndWrite(IAsyncResult asyncResult) +#endif + { + if (asyncResult == null) + { + throw new ArgumentNullException("asyncResult"); + } + ResponseStreamAsyncResult castedAsyncResult = asyncResult as ResponseStreamAsyncResult; + if (castedAsyncResult == null || castedAsyncResult.ResponseStream != this) + { + throw new ArgumentException(Resources.Exception_WrongIAsyncResult, "asyncResult"); + } + if (castedAsyncResult.EndCalled) + { + throw new InvalidOperationException(Resources.Exception_EndCalledMultipleTimes); + } + castedAsyncResult.EndCalled = true; + + try + { + // wait & then check for errors + // TODO: Gracefull re-throw + castedAsyncResult.Task.Wait(); + } + catch (Exception exception) + { + LogHelper.LogException(_requestContext.Logger, "EndWrite", exception); + _closed = true; + _requestContext.Abort(); + throw; + } + } + + public override unsafe Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancel) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + if (size == 0 && _leftToWrite != 0) + { + return Helpers.CompletedTask(); + } + if (_leftToWrite >= 0 && size > _leftToWrite) + { + throw new InvalidOperationException(Resources.Exception_TooMuchWritten); + } + // TODO: Verbose log + + // TODO: Real cancelation + cancel.ThrowIfCancellationRequested(); + + uint statusCode; + uint bytesSent = 0; + flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + bool sentHeaders = _requestContext.Response.SentHeaders; + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, buffer, offset, size, _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders); + + // Update m_LeftToWrite now so we can queue up additional BeginWrite's without waiting for EndWrite. + UpdateWritenCount((uint)((_requestContext.Response.BoundaryType == BoundaryType.Chunked) ? 0 : size)); + + try + { + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(null, asyncResult, flags, false); + bytesSent = asyncResult.BytesSent; + } + else + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + asyncResult.DataChunkCount, + asyncResult.DataChunks, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult.NativeOverlapped, + IntPtr.Zero); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "WriteAsync", e); + asyncResult.Dispose(); + _closed = true; + _requestContext.Abort(); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (_requestContext.Server.IgnoreWriteExceptions && sentHeaders) + { + asyncResult.Complete(); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "WriteAsync", exception); + _closed = true; + _requestContext.Abort(); + throw exception; + } + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode, bytesSent); + } + + // Last write, cache it for special cancelation handling. + if ((flags & UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA) == 0) + { + _lastWrite = asyncResult; + } + + return asyncResult.Task; + } + + internal unsafe Task SendFileAsync(string fileName, long offset, long? size, CancellationToken cancel) + { + // It's too expensive to validate the file attributes before opening the file. Open the file and then check the lengths. + // This all happens inside of ResponseStreamAsyncResult. + if (string.IsNullOrWhiteSpace(fileName)) + { + throw new ArgumentNullException("fileName"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + + // TODO: Real cancellation + cancel.ThrowIfCancellationRequested(); + + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + if (size == 0 && _leftToWrite != 0) + { + return Helpers.CompletedTask(); + } + if (_leftToWrite >= 0 && size > _leftToWrite) + { + throw new InvalidOperationException(Resources.Exception_TooMuchWritten); + } + // TODO: Verbose log + + uint statusCode; + uint bytesSent = 0; + flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + bool sentHeaders = _requestContext.Response.SentHeaders; + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, fileName, offset, size, + _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders); + + long bytesWritten; + if (_requestContext.Response.BoundaryType == BoundaryType.Chunked) + { + bytesWritten = 0; + } + else if (size.HasValue) + { + bytesWritten = size.Value; + } + else + { + bytesWritten = asyncResult.FileLength - offset; + } + // Update m_LeftToWrite now so we can queue up additional calls to SendFileAsync. + UpdateWritenCount((uint)bytesWritten); + + try + { + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(null, asyncResult, flags, false); + bytesSent = asyncResult.BytesSent; + } + else + { + // TODO: If opaque then include the buffer data flag. + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + asyncResult.DataChunkCount, + asyncResult.DataChunks, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult.NativeOverlapped, + IntPtr.Zero); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "SendFileAsync", e); + asyncResult.Dispose(); + _closed = true; + _requestContext.Abort(); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (_requestContext.Server.IgnoreWriteExceptions && sentHeaders) + { + asyncResult.Complete(); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "SendFileAsync", exception); + _closed = true; + _requestContext.Abort(); + throw exception; + } + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode, bytesSent); + } + + // Last write, cache it for special cancellation handling. + if ((flags & UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA) == 0) + { + _lastWrite = asyncResult; + } + + return asyncResult.Task; + } + + private void UpdateWritenCount(uint dataWritten) + { + if (!_inOpaqueMode) + { + if (_leftToWrite > 0) + { + // keep track of the data transferred + _leftToWrite -= dataWritten; + } + if (_leftToWrite == 0) + { + // in this case we already passed 0 as the flag, so we don't need to call HttpSendResponseEntityBody() when we Close() + _closed = true; + } + } + } + + protected override unsafe void Dispose(bool disposing) + { + try + { + if (disposing) + { + if (_closed) + { + return; + } + _closed = true; + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(endOfRequest: true); + if (_leftToWrite > 0 && !_inOpaqueMode) + { + _requestContext.Abort(); + // TODO: Reduce this to a logged warning, it is thrown too late to be visible in user code. + LogHelper.LogError(_requestContext.Logger, "ResponseStream::Dispose", "Fewer bytes were written than were specified in the Content-Length."); + return; + } + bool sentHeaders = _requestContext.Response.SentHeaders; + if (sentHeaders && _leftToWrite == 0) + { + return; + } + + uint statusCode = 0; + if ((_requestContext.Response.BoundaryType == BoundaryType.Chunked || _requestContext.Response.BoundaryType == BoundaryType.None) && (String.Compare(_requestContext.Request.HttpMethod, "HEAD", StringComparison.OrdinalIgnoreCase) != 0)) + { + if (_requestContext.Response.BoundaryType == BoundaryType.None) + { + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_DISCONNECT; + } + fixed (void* pBuffer = ChunkTerminator) + { + UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK* pDataChunk = null; + if (_requestContext.Response.BoundaryType == BoundaryType.Chunked) + { + UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK dataChunk = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + dataChunk.DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + dataChunk.fromMemory.pBuffer = (IntPtr)pBuffer; + dataChunk.fromMemory.BufferLength = (uint)ChunkTerminator.Length; + pDataChunk = &dataChunk; + } + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(pDataChunk, null, flags, false); + } + else + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + pDataChunk != null ? (ushort)1 : (ushort)0, + pDataChunk, + null, + SafeLocalFree.Zero, + 0, + SafeNativeOverlapped.Zero, + IntPtr.Zero); + + if (_requestContext.Server.IgnoreWriteExceptions) + { + statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + } + } + } + } + else + { + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(null, null, flags, false); + } + } + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF + // Don't throw for disconnects, we were already finished with the response. + && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_CONNECTION_INVALID) + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "Dispose", exception); + _requestContext.Abort(); + throw exception; + } + _leftToWrite = 0; + } + } + finally + { + base.Dispose(disposing); + } + } + + internal void SwitchToOpaqueMode() + { + _inOpaqueMode = true; + _leftToWrite = long.MaxValue; + } + + // The final Content-Length async write can only be cancelled by CancelIoEx. + // Sync can only be cancelled by CancelSynchronousIo, but we don't attempt this right now. + [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = + "It is safe to ignore the return value on a cancel operation because the connection is being closed")] + internal unsafe void CancelLastWrite(SafeHandle requestQueueHandle) + { + ResponseStreamAsyncResult asyncState = _lastWrite; + if (asyncState != null && !asyncState.IsCompleted) + { + UnsafeNclNativeMethods.CancelIoEx(requestQueueHandle, asyncState.NativeOverlapped); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStreamAsyncResult.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStreamAsyncResult.cs new file mode 100644 index 0000000000..ca38ba6355 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStreamAsyncResult.cs @@ -0,0 +1,441 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal unsafe class ResponseStreamAsyncResult : IAsyncResult, IDisposable + { + private static readonly byte[] CRLF = new byte[] { (byte)'\r', (byte)'\n' }; + private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(Callback); + + private SafeNativeOverlapped _overlapped; + private UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[] _dataChunks; + private bool _sentHeaders; + private FileStream _fileStream; + private ResponseStream _responseStream; + private TaskCompletionSource _tcs; + private AsyncCallback _callback; + private uint _bytesSent; + + internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback) + { + _responseStream = responseStream; + _tcs = new TaskCompletionSource(userState); + _callback = callback; + } + + internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback, + byte[] buffer, int offset, int size, bool chunked, bool sentHeaders) + : this(responseStream, userState, callback) + { + _sentHeaders = sentHeaders; + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = this; + + if (size == 0) + { + _dataChunks = null; + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, null)); + } + else + { + _dataChunks = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[chunked ? 3 : 1]; + + object[] objectsToPin = new object[1 + _dataChunks.Length]; + objectsToPin[_dataChunks.Length] = _dataChunks; + + int chunkHeaderOffset = 0; + byte[] chunkHeaderBuffer = null; + if (chunked) + { + chunkHeaderBuffer = GetChunkHeader(size, out chunkHeaderOffset); + + _dataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[0].fromMemory.BufferLength = (uint)(chunkHeaderBuffer.Length - chunkHeaderOffset); + + objectsToPin[0] = chunkHeaderBuffer; + + _dataChunks[1] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[1].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[1].fromMemory.BufferLength = (uint)size; + + objectsToPin[1] = buffer; + + _dataChunks[2] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[2].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[2].fromMemory.BufferLength = (uint)CRLF.Length; + + objectsToPin[2] = CRLF; + } + else + { + _dataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[0].fromMemory.BufferLength = (uint)size; + + objectsToPin[0] = buffer; + } + + // This call will pin needed memory + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, objectsToPin)); + + if (chunked) + { + _dataChunks[0].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(chunkHeaderBuffer, chunkHeaderOffset); + _dataChunks[1].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset); + _dataChunks[2].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(CRLF, 0); + } + else + { + _dataChunks[0].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset); + } + } + } + + internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback, + string fileName, long offset, long? size, bool chunked, bool sentHeaders) + : this(responseStream, userState, callback) + { + _sentHeaders = sentHeaders; + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = this; + + int bufferSize = 1024 * 64; // TODO: Validate buffer size choice. +#if NET45 + // It's too expensive to validate anything before opening the file. Open the file and then check the lengths. + _fileStream = new FileStream(fileName, FileMode.Open, FileAccess.Read, FileShare.ReadWrite, bufferSize, + FileOptions.Asynchronous | FileOptions.SequentialScan); // Extremely expensive. +#else + _fileStream = new FileStream(fileName, FileMode.Open, FileAccess.Read, FileShare.ReadWrite, bufferSize, useAsync: true); // Extremely expensive. +#endif +#if !NET45 + throw new NotImplementedException(); +#else + long length = _fileStream.Length; // Expensive + if (offset < 0 || offset > length) + { + _fileStream.Dispose(); + throw new ArgumentOutOfRangeException("offset", offset, string.Empty); + } + if (size.HasValue && (size < 0 || size > length - offset)) + { + _fileStream.Dispose(); + throw new ArgumentOutOfRangeException("size", size, string.Empty); + } + + if (size == 0 || (!size.HasValue && _fileStream.Length == 0)) + { + _dataChunks = null; + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, null)); + } + else + { + _dataChunks = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[chunked ? 3 : 1]; + + object[] objectsToPin = new object[_dataChunks.Length]; + objectsToPin[_dataChunks.Length - 1] = _dataChunks; + + int chunkHeaderOffset = 0; + byte[] chunkHeaderBuffer = null; + if (chunked) + { + chunkHeaderBuffer = GetChunkHeader((int)(size ?? _fileStream.Length - offset), out chunkHeaderOffset); + + _dataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[0].fromMemory.BufferLength = (uint)(chunkHeaderBuffer.Length - chunkHeaderOffset); + + objectsToPin[0] = chunkHeaderBuffer; + + _dataChunks[1] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[1].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromFileHandle; + _dataChunks[1].fromFile.offset = (ulong)offset; + _dataChunks[1].fromFile.count = (ulong)(size ?? -1); + _dataChunks[1].fromFile.fileHandle = _fileStream.SafeFileHandle.DangerousGetHandle(); + // Nothing to pin for the file handle. + + _dataChunks[2] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[2].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[2].fromMemory.BufferLength = (uint)CRLF.Length; + + objectsToPin[1] = CRLF; + } + else + { + _dataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromFileHandle; + _dataChunks[0].fromFile.offset = (ulong)offset; + _dataChunks[0].fromFile.count = (ulong)(size ?? -1); + _dataChunks[0].fromFile.fileHandle = _fileStream.SafeFileHandle.DangerousGetHandle(); + } + + // This call will pin needed memory + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, objectsToPin)); + + if (chunked) + { + _dataChunks[0].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(chunkHeaderBuffer, chunkHeaderOffset); + _dataChunks[2].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(CRLF, 0); + } + } +#endif + } + + internal ResponseStream ResponseStream + { + get { return _responseStream; } + } + + internal SafeNativeOverlapped NativeOverlapped + { + get { return _overlapped; } + } + + internal Task Task + { + get { return _tcs.Task; } + } + + internal uint BytesSent + { + get { return _bytesSent; } + set { _bytesSent = value; } + } + + internal ushort DataChunkCount + { + get + { + if (_dataChunks == null) + { + return 0; + } + else + { + return (ushort)_dataChunks.Length; + } + } + } + + internal UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK* DataChunks + { + get + { + if (_dataChunks == null) + { + return null; + } + else + { + return (UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK*)(Marshal.UnsafeAddrOfPinnedArrayElement(_dataChunks, 0)); + } + } + } + + internal long FileLength + { + get { return _fileStream == null ? 0 : _fileStream.Length; } + } + + internal bool EndCalled { get; set; } + + internal void IOCompleted(uint errorCode) + { + IOCompleted(this, errorCode, BytesSent); + } + + internal void IOCompleted(uint errorCode, uint numBytes) + { + IOCompleted(this, errorCode, numBytes); + } + + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirecting to callback")] + private static void IOCompleted(ResponseStreamAsyncResult asyncResult, uint errorCode, uint numBytes) + { + try + { + if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + asyncResult.Fail(new WebListenerException((int)errorCode)); + } + else + { + if (asyncResult._dataChunks == null) + { + // TODO: Verbose log data written + } + else + { + // TODO: Verbose log + // for (int i = 0; i < asyncResult._dataChunks.Length; i++) + // { + // Logging.Dump(Logging.HttpListener, asyncResult, "Callback", (IntPtr)asyncResult._dataChunks[0].fromMemory.pBuffer, (int)asyncResult._dataChunks[0].fromMemory.BufferLength); + // } + } + asyncResult.Complete(); + } + } + catch (Exception e) + { + asyncResult.Fail(e); + } + } + + private static unsafe void Callback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + Overlapped callbackOverlapped = Overlapped.Unpack(nativeOverlapped); + ResponseStreamAsyncResult asyncResult = callbackOverlapped.AsyncResult as ResponseStreamAsyncResult; + + IOCompleted(asyncResult, errorCode, numBytes); + } + + internal void Complete() + { + if (_tcs.TrySetResult(null) && _callback != null) + { + try + { + _callback(this); + } + catch (Exception) + { + // TODO: Exception handling? This may be an IO callback thread and throwing here could crash the app. + // TODO: Log + } + } + Dispose(); + } + + internal void Fail(Exception ex) + { + if (_tcs.TrySetException(ex) && _callback != null) + { + try + { + _callback(this); + } + catch (Exception) + { + // TODO: Exception handling? This may be an IO callback thread and throwing here could crash the app. + } + } + Dispose(); + } + + /*++ + + GetChunkHeader + + A private utility routine to convert an integer to a chunk header, + which is an ASCII hex number followed by a CRLF. The header is retuned + as a byte array. + + Input: + + size - Chunk size to be encoded + offset - Out parameter where we store offset into buffer. + + Returns: + + A byte array with the header in int. + + --*/ + + private static byte[] GetChunkHeader(int size, out int offset) + { + uint mask = 0xf0000000; + byte[] header = new byte[10]; + int i; + offset = -1; + + // Loop through the size, looking at each nibble. If it's not 0 + // convert it to hex. Save the index of the first non-zero + // byte. + + for (i = 0; i < 8; i++, size <<= 4) + { + // offset == -1 means that we haven't found a non-zero nibble + // yet. If we haven't found one, and the current one is zero, + // don't do anything. + + if (offset == -1) + { + if ((size & mask) == 0) + { + continue; + } + } + + // Either we have a non-zero nibble or we're no longer skipping + // leading zeros. Convert this nibble to ASCII and save it. + + uint temp = (uint)size >> 28; + + if (temp < 10) + { + header[i] = (byte)(temp + '0'); + } + else + { + header[i] = (byte)((temp - 10) + 'A'); + } + + // If we haven't found a non-zero nibble yet, we've found one + // now, so remember that. + + if (offset == -1) + { + offset = i; + } + } + + header[8] = (byte)'\r'; + header[9] = (byte)'\n'; + + return header; + } + + public object AsyncState + { + get { return _tcs.Task.AsyncState; } + } + + public WaitHandle AsyncWaitHandle + { + get { return ((IAsyncResult)_tcs.Task).AsyncWaitHandle; } + } + + public bool CompletedSynchronously + { + get { return ((IAsyncResult)_tcs.Task).CompletedSynchronously; } + } + + public bool IsCompleted + { + get { return _tcs.Task.IsCompleted; } + } + + public void Dispose() + { + if (_overlapped != null) + { + _overlapped.Dispose(); + } + if (_fileStream != null) + { + _fileStream.Dispose(); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/SslStatus.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/SslStatus.cs new file mode 100644 index 0000000000..72118c2d46 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/SslStatus.cs @@ -0,0 +1,15 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum SslStatus : byte + { + Insecure, + NoClientCert, + ClientCert + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Resources.Designer.cs b/src/Microsoft.AspNet.Server.WebListener/Resources.Designer.cs new file mode 100644 index 0000000000..aafe87fb8d --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Resources.Designer.cs @@ -0,0 +1,153 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.34006 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class Resources { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal Resources() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.AspNet.Server.WebListener.Resources", System.Reflection.IntrospectionExtensions.GetTypeInfo(typeof(Resources)).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to The destination array is too small.. + /// + internal static string Exception_ArrayTooSmall { + get { + return ResourceManager.GetString("Exception_ArrayTooSmall", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to End has already been called.. + /// + internal static string Exception_EndCalledMultipleTimes { + get { + return ResourceManager.GetString("Exception_EndCalledMultipleTimes", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The status code '{0}' is not supported.. + /// + internal static string Exception_InvalidStatusCode { + get { + return ResourceManager.GetString("Exception_InvalidStatusCode", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The stream is not seekable.. + /// + internal static string Exception_NoSeek { + get { + return ResourceManager.GetString("Exception_NoSeek", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The prefix '{0}' is already registered.. + /// + internal static string Exception_PrefixAlreadyRegistered { + get { + return ResourceManager.GetString("Exception_PrefixAlreadyRegistered", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to This stream only supports read operations.. + /// + internal static string Exception_ReadOnlyStream { + get { + return ResourceManager.GetString("Exception_ReadOnlyStream", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to More data written than specified in the Content-Length header.. + /// + internal static string Exception_TooMuchWritten { + get { + return ResourceManager.GetString("Exception_TooMuchWritten", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Only the http and https schemes are supported.. + /// + internal static string Exception_UnsupportedScheme { + get { + return ResourceManager.GetString("Exception_UnsupportedScheme", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to This stream only supports write operations.. + /// + internal static string Exception_WriteOnlyStream { + get { + return ResourceManager.GetString("Exception_WriteOnlyStream", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The given IAsyncResult does not match this opperation.. + /// + internal static string Exception_WrongIAsyncResult { + get { + return ResourceManager.GetString("Exception_WrongIAsyncResult", resourceCulture); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Resources.resx b/src/Microsoft.AspNet.Server.WebListener/Resources.resx new file mode 100644 index 0000000000..005bafca2f --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Resources.resx @@ -0,0 +1,150 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + The destination array is too small. + + + End has already been called. + + + The status code '{0}' is not supported. + + + The stream is not seekable. + + + The prefix '{0}' is already registered. + + + This stream only supports read operations. + + + More data written than specified in the Content-Length header. + + + Only the http and https schemes are supported. + + + This stream only supports write operations. + + + The given IAsyncResult does not match this opperation. + + \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/TimeoutManager.cs b/src/Microsoft.AspNet.Server.WebListener/TimeoutManager.cs new file mode 100644 index 0000000000..9e74f6563a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/TimeoutManager.cs @@ -0,0 +1,267 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ + // See the native HTTP_TIMEOUT_LIMIT_INFO structure documentation for additional information. + // http://msdn.microsoft.com/en-us/library/aa364661.aspx + + /// + /// Exposes the Http.Sys timeout configurations. These may also be configured in the registry. + /// + public sealed class TimeoutManager + { +#if NET45 + private static readonly int TimeoutLimitSize = + Marshal.SizeOf(typeof(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_LIMIT_INFO)); +#else + private static readonly int TimeoutLimitSize = + Marshal.SizeOf(); +#endif + private OwinWebListener _server; + private int[] _timeouts; + private uint _minSendBytesPerSecond; + + internal TimeoutManager(OwinWebListener context) + { + _server = context; + + // We have to maintain local state since we allow applications to set individual timeouts. Native Http + // API for setting timeouts expects all timeout values in every call so we have remember timeout values + // to fill in the blanks. Except MinSendBytesPerSecond, local state for remaining five timeouts is + // maintained in timeouts array. + // + // No initialization is required because a value of zero indicates that system defaults should be used. + _timeouts = new int[5]; + + LoadConfigurationSettings(); + } + + #region Properties + + /// + /// The time, in seconds, allowed for the request entity body to arrive. The default timer is 2 minutes. + /// + /// The HTTP Server API turns on this timer when the request has an entity body. The timer expiration is + /// initially set to the configured value. When the HTTP Server API receives additional data indications on the + /// request, it resets the timer to give the connection another interval. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan EntityBody + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.EntityBody); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.EntityBody, value); + } + } + + /// + /// The time, in seconds, allowed for the HTTP Server API to drain the entity body on a Keep-Alive connection. + /// The default timer is 2 minutes. + /// + /// On a Keep-Alive connection, after the application has sent a response for a request and before the request + /// entity body has completely arrived, the HTTP Server API starts draining the remainder of the entity body to + /// reach another potentially pipelined request from the client. If the time to drain the remaining entity body + /// exceeds the allowed period the connection is timed out. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan DrainEntityBody + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.DrainEntityBody); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.DrainEntityBody, value); + } + } + + /// + /// The time, in seconds, allowed for the request to remain in the request queue before the application picks + /// it up. The default timer is 2 minutes. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan RequestQueue + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.RequestQueue); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.RequestQueue, value); + } + } + + /// + /// The time, in seconds, allowed for an idle connection. The default timer is 2 minutes. + /// + /// This timeout is only enforced after the first request on the connection is routed to the application. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan IdleConnection + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.IdleConnection); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.IdleConnection, value); + } + } + + /// + /// The time, in seconds, allowed for the HTTP Server API to parse the request header. The default timer is + /// 2 minutes. + /// + /// This timeout is only enforced after the first request on the connection is routed to the application. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan HeaderWait + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.HeaderWait); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.HeaderWait, value); + } + } + + /// + /// The minimum send rate, in bytes-per-second, for the response. The default response send rate is 150 + /// bytes-per-second. + /// + /// To disable this timer set it to UInt32.MaxValue + /// + public long MinSendBytesPerSecond + { + get + { + // Since we maintain local state, GET is local. + return _minSendBytesPerSecond; + } + set + { + // MinSendRate value is ULONG in native layer. + if (value < 0 || value > uint.MaxValue) + { + throw new ArgumentOutOfRangeException("value"); + } + + SetServerTimeout(_timeouts, (uint)value); + _minSendBytesPerSecond = (uint)value; + } + } + + #endregion Properties + + // Initial values come from the config. The values can then be overridden using this public API. + private void LoadConfigurationSettings() + { + long[] configTimeouts = new long[_timeouts.Length + 1]; // SettingsSectionInternal.Section.HttpListenerTimeouts; + Debug.Assert(configTimeouts != null); + Debug.Assert(configTimeouts.Length == (_timeouts.Length + 1)); + + bool setNonDefaults = false; + for (int i = 0; i < _timeouts.Length; i++) + { + if (configTimeouts[i] != 0) + { + Debug.Assert(configTimeouts[i] <= ushort.MaxValue, "Timeout out of range: " + configTimeouts[i]); + _timeouts[i] = (int)configTimeouts[i]; + setNonDefaults = true; + } + } + + if (configTimeouts[5] != 0) + { + Debug.Assert(configTimeouts[5] <= uint.MaxValue, "Timeout out of range: " + configTimeouts[5]); + _minSendBytesPerSecond = (uint)configTimeouts[5]; + setNonDefaults = true; + } + + if (setNonDefaults) + { + SetServerTimeout(_timeouts, _minSendBytesPerSecond); + } + } + + #region Helpers + + private TimeSpan GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE type) + { + // Since we maintain local state, GET is local. + return new TimeSpan(0, 0, (int)_timeouts[(int)type]); + } + + private void SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE type, TimeSpan value) + { + Int64 timeoutValue; + + // All timeouts are defined as USHORT in native layer (except MinSendRate, which is ULONG). Make sure that + // timeout value is within range. + + timeoutValue = Convert.ToInt64(value.TotalSeconds); + + if (timeoutValue < 0 || timeoutValue > ushort.MaxValue) + { + throw new ArgumentOutOfRangeException("value"); + } + + // Use local state to get values for other timeouts. Call into the native layer and if that + // call succeeds, update local state. + + int[] currentTimeouts = _timeouts; + currentTimeouts[(int)type] = (int)timeoutValue; + SetServerTimeout(currentTimeouts, _minSendBytesPerSecond); + _timeouts[(int)type] = (int)timeoutValue; + } + + private unsafe void SetServerTimeout(int[] timeouts, uint minSendBytesPerSecond) + { + UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_LIMIT_INFO timeoutinfo = + new UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_LIMIT_INFO(); + + timeoutinfo.Flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_PROPERTY_FLAG_PRESENT; + timeoutinfo.DrainEntityBody = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.DrainEntityBody]; + timeoutinfo.EntityBody = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.EntityBody]; + timeoutinfo.RequestQueue = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.RequestQueue]; + timeoutinfo.IdleConnection = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.IdleConnection]; + timeoutinfo.HeaderWait = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.HeaderWait]; + timeoutinfo.MinSendRate = minSendBytesPerSecond; + + IntPtr infoptr = new IntPtr(&timeoutinfo); + + _server.SetUrlGroupProperty( + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerTimeoutsProperty, + infoptr, (uint)TimeoutLimitSize); + } + + #endregion Helpers + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/ValidationHelper.cs b/src/Microsoft.AspNet.Server.WebListener/ValidationHelper.cs new file mode 100644 index 0000000000..104e49e320 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/ValidationHelper.cs @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Globalization; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class ValidationHelper + { + public static string ExceptionMessage(Exception exception) + { + if (exception == null) + { + return string.Empty; + } + if (exception.InnerException == null) + { + return exception.Message; + } + return exception.Message + " (" + ExceptionMessage(exception.InnerException) + ")"; + } + + public static string ToString(object objectValue) + { + if (objectValue == null) + { + return "(null)"; + } + else if (objectValue is string && ((string)objectValue).Length == 0) + { + return "(string.empty)"; + } + else if (objectValue is Exception) + { + return ExceptionMessage(objectValue as Exception); + } + else if (objectValue is IntPtr) + { + return "0x" + ((IntPtr)objectValue).ToString("x"); + } + else + { + return objectValue.ToString(); + } + } + + public static string HashString(object objectValue) + { + if (objectValue == null) + { + return "(null)"; + } + else if (objectValue is string && ((string)objectValue).Length == 0) + { + return "(string.empty)"; + } + else + { + return objectValue.GetHashCode().ToString(NumberFormatInfo.InvariantInfo); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/WebListenerException.cs b/src/Microsoft.AspNet.Server.WebListener/WebListenerException.cs new file mode 100644 index 0000000000..19a7b78a6e --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/WebListenerException.cs @@ -0,0 +1,44 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ +#if NET45 + [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2237:MarkISerializableTypesWithSerializable")] +#endif + internal class WebListenerException : Win32Exception + { + internal WebListenerException() + : base(Marshal.GetLastWin32Error()) + { + } + + internal WebListenerException(int errorCode) + : base(errorCode) + { + } + + internal WebListenerException(int errorCode, string message) + : base(errorCode, message) + { + } + + public override int ErrorCode + { + // the base class returns the HResult with this property + // we need the Win32 Error Code, hence the override. + + get + { + return NativeErrorCode; + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/CriticalHandleZeroOrMinusOneIsInvalid.cs b/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/CriticalHandleZeroOrMinusOneIsInvalid.cs new file mode 100644 index 0000000000..43897343aa --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/CriticalHandleZeroOrMinusOneIsInvalid.cs @@ -0,0 +1,32 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== + +#if !NET45 + +namespace Microsoft.Win32.SafeHandles +{ + using System; + using System.Runtime.InteropServices; + using System.Runtime.CompilerServices; + + // Class of critical handle which uses 0 or -1 as an invalid handle. + [System.Security.SecurityCritical] // auto-generated_required + internal abstract class CriticalHandleZeroOrMinusOneIsInvalid : CriticalHandle + { + protected CriticalHandleZeroOrMinusOneIsInvalid() + : base(IntPtr.Zero) + { + } + + public override bool IsInvalid + { + [System.Security.SecurityCritical] + get { return handle == new IntPtr(0) || handle == new IntPtr(-1); } + } + } +} + +#endif diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs b/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs new file mode 100644 index 0000000000..861c9e31ab --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs @@ -0,0 +1,31 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== + +#if !NET45 + +namespace Microsoft.Win32.SafeHandles +{ + using System; + using System.Runtime.InteropServices; + using System.Runtime.CompilerServices; + + // Class of safe handle which uses 0 or -1 as an invalid handle. + [System.Security.SecurityCritical] // auto-generated_required + internal abstract class SafeHandleZeroOrMinusOneIsInvalid : SafeHandle + { + protected SafeHandleZeroOrMinusOneIsInvalid(bool ownsHandle) + : base(IntPtr.Zero, ownsHandle) + { + } + + public override bool IsInvalid + { + [System.Security.SecurityCritical] + get { return handle == new IntPtr(0) || handle == new IntPtr(-1); } + } + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/ComponentModel/Win32Exception.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/ComponentModel/Win32Exception.cs new file mode 100644 index 0000000000..108303ee2c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/ComponentModel/Win32Exception.cs @@ -0,0 +1,112 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +using System.Runtime.InteropServices; +using System.Text; + +namespace System.ComponentModel +{ + internal class Win32Exception : ExternalException + { + /// + /// Represents the Win32 error code associated with this exception. This + /// field is read-only. + /// + private readonly int nativeErrorCode; + + /// + /// Initializes a new instance of the class with the last Win32 error + /// that occured. + /// + public Win32Exception() + : this(Marshal.GetLastWin32Error()) + { + } + /// + /// Initializes a new instance of the class with the specified error. + /// + public Win32Exception(int error) + : this(error, GetErrorMessage(error)) + { + } + /// + /// Initializes a new instance of the class with the specified error and the + /// specified detailed description. + /// + public Win32Exception(int error, string message) + : base(message) + { + nativeErrorCode = error; + } + + /// + /// Initializes a new instance of the Exception class with a specified error message. + /// FxCop CA1032: Multiple constructors are required to correctly implement a custom exception. + /// + public Win32Exception(string message) + : this(Marshal.GetLastWin32Error(), message) + { + } + + /// + /// Initializes a new instance of the Exception class with a specified error message and a + /// reference to the inner exception that is the cause of this exception. + /// FxCop CA1032: Multiple constructors are required to correctly implement a custom exception. + /// + public Win32Exception(string message, Exception innerException) + : base(message, innerException) + { + nativeErrorCode = Marshal.GetLastWin32Error(); + } + + + /// + /// Represents the Win32 error code associated with this exception. This + /// field is read-only. + /// + public int NativeErrorCode + { + get + { + return nativeErrorCode; + } + } + + private static string GetErrorMessage(int error) + { + //get the system error message... + string errorMsg = ""; + StringBuilder sb = new StringBuilder(256); + int result = SafeNativeMethods.FormatMessage( + SafeNativeMethods.FORMAT_MESSAGE_IGNORE_INSERTS | + SafeNativeMethods.FORMAT_MESSAGE_FROM_SYSTEM | + SafeNativeMethods.FORMAT_MESSAGE_ARGUMENT_ARRAY, + IntPtr.Zero, (uint)error, 0, sb, sb.Capacity + 1, + null); + if (result != 0) + { + int i = sb.Length; + while (i > 0) + { + char ch = sb[i - 1]; + if (ch > 32 && ch != '.') break; + i--; + } + errorMsg = sb.ToString(0, i); + } + else + { + errorMsg = "Unknown error (0x" + Convert.ToString(error, 16) + ")"; + } + + return errorMsg; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/Diagnostics/TraceEventType.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/Diagnostics/TraceEventType.cs new file mode 100644 index 0000000000..36d0b93d25 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/Diagnostics/TraceEventType.cs @@ -0,0 +1,35 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +using System; +using System.ComponentModel; + +namespace System.Diagnostics +{ + internal enum TraceEventType + { + Critical = 0x01, + Error = 0x02, + Warning = 0x04, + Information = 0x08, + Verbose = 0x10, + + [EditorBrowsable(EditorBrowsableState.Advanced)] + Start = 0x0100, + [EditorBrowsable(EditorBrowsableState.Advanced)] + Stop = 0x0200, + [EditorBrowsable(EditorBrowsableState.Advanced)] + Suspend = 0x0400, + [EditorBrowsable(EditorBrowsableState.Advanced)] + Resume = 0x0800, + [EditorBrowsable(EditorBrowsableState.Advanced)] + Transfer = 0x1000, + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/ExternDll.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/ExternDll.cs new file mode 100644 index 0000000000..98303d48b8 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/ExternDll.cs @@ -0,0 +1,16 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +namespace System +{ + internal static class ExternDll + { + public const string Kernel32 = "kernel32.dll"; + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/Runtime/InteropServices/ExternalException.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/Runtime/InteropServices/ExternalException.cs new file mode 100644 index 0000000000..21d82b1de4 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/Runtime/InteropServices/ExternalException.cs @@ -0,0 +1,98 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== +/*============================================================================= +** +** Class: ExternalException +** +** +** Purpose: Exception base class for all errors from Interop or Structured +** Exception Handling code. +** +** +=============================================================================*/ + +#if !NET45 + +namespace System.Runtime.InteropServices +{ + using System; + using System.Globalization; + + // Base exception for COM Interop errors &; Structured Exception Handler + // exceptions. + // + internal class ExternalException : Exception + { + public ExternalException() + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message) + : base(message) + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message, Exception inner) + : base(message, inner) + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message, int errorCode) + : base(message) + { + SetErrorCode(errorCode); + } + + private void SetErrorCode(int errorCode) + { + HResult = ErrorCode; + } + + private static class __HResults + { + internal const int E_FAIL = unchecked((int)0x80004005); + } + + public virtual int ErrorCode + { + get + { + return HResult; + } + } + + public override String ToString() + { + String message = Message; + String s; + String _className = GetType().ToString(); + s = _className + " (0x" + HResult.ToString("X8", CultureInfo.InvariantCulture) + ")"; + + if (!(String.IsNullOrEmpty(message))) + { + s = s + ": " + message; + } + + Exception _innerException = InnerException; + + if (_innerException != null) + { + s = s + " ---> " + _innerException.ToString(); + } + + + if (StackTrace != null) + s += Environment.NewLine + StackTrace; + + return s; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/SafeNativeMethods.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/SafeNativeMethods.cs new file mode 100644 index 0000000000..969d273fa8 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/SafeNativeMethods.cs @@ -0,0 +1,28 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 +using System.Runtime.InteropServices; +using System.Text; + +namespace System +{ + internal static class SafeNativeMethods + { + public const int + FORMAT_MESSAGE_ALLOCATE_BUFFER = 0x00000100, + FORMAT_MESSAGE_IGNORE_INSERTS = 0x00000200, + FORMAT_MESSAGE_FROM_STRING = 0x00000400, + FORMAT_MESSAGE_FROM_SYSTEM = 0x00001000, + FORMAT_MESSAGE_ARGUMENT_ARRAY = 0x00002000; + + [DllImport(ExternDll.Kernel32, CharSet = System.Runtime.InteropServices.CharSet.Unicode, SetLastError = true, BestFitMapping = true)] + public static unsafe extern int FormatMessage(int dwFlags, IntPtr lpSource_mustBeNull, uint dwMessageId, + int dwLanguageId, StringBuilder lpBuffer, int nSize, IntPtr[] arguments); + + } +} +#endif diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/Security/Authentication/ExtendedProtection/ChannelBinding.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/Security/Authentication/ExtendedProtection/ChannelBinding.cs new file mode 100644 index 0000000000..3c31d34a1b --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/Security/Authentication/ExtendedProtection/ChannelBinding.cs @@ -0,0 +1,33 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System.Runtime.InteropServices; +using Microsoft.Win32.SafeHandles; + +#if !NET45 + +namespace System.Security.Authentication.ExtendedProtection +{ + internal abstract class ChannelBinding : SafeHandleZeroOrMinusOneIsInvalid + { + protected ChannelBinding() + : base(true) + { + } + + protected ChannelBinding(bool ownsHandle) + : base(ownsHandle) + { + } + + public abstract int Size + { + get; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/project.json b/src/Microsoft.AspNet.Server.WebListener/project.json new file mode 100644 index 0000000000..0ba47798dd --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/project.json @@ -0,0 +1,8 @@ +{ + "version": "0.1-alpha-*", + "compilationOptions" : { "allowUnsafe": true }, + "configurations": { + "net45" : { }, + "k10" : { } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/Constants.cs b/src/Microsoft.AspNet.WebSockets/Constants.cs new file mode 100644 index 0000000000..6e9cb51ad3 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Constants.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. + +namespace Microsoft.AspNet.WebSockets +{ + /// + /// Standard keys and values for use within the OWIN interfaces + /// + internal static class Constants + { + internal const string WebSocketAcceptKey = "websocket.Accept"; + internal const string WebSocketSubProtocolKey = "websocket.SubProtocol"; + internal const string WebSocketSendAsyncKey = "websocket.SendAsync"; + internal const string WebSocketReceiveAyncKey = "websocket.ReceiveAsync"; + internal const string WebSocketCloseAsyncKey = "websocket.CloseAsync"; + internal const string WebSocketCallCancelledKey = "websocket.CallCancelled"; + internal const string WebSocketVersionKey = "websocket.Version"; + internal const string WebSocketVersion = "1.0"; + internal const string WebSocketCloseStatusKey = "websocket.ClientCloseStatus"; + internal const string WebSocketCloseDescriptionKey = "websocket.ClientCloseDescription"; + } +} diff --git a/src/Microsoft.AspNet.WebSockets/HttpKnownHeaderNames.cs b/src/Microsoft.AspNet.WebSockets/HttpKnownHeaderNames.cs new file mode 100644 index 0000000000..f4887cc569 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/HttpKnownHeaderNames.cs @@ -0,0 +1,76 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.WebSockets +{ + // this class contains known header names + internal static class HttpKnownHeaderNames + { + public const string CacheControl = "Cache-Control"; + public const string Connection = "Connection"; + public const string Date = "Date"; + public const string KeepAlive = "Keep-Alive"; + public const string Pragma = "Pragma"; + public const string ProxyConnection = "Proxy-Connection"; + public const string Trailer = "Trailer"; + public const string TransferEncoding = "Transfer-Encoding"; + public const string Upgrade = "Upgrade"; + public const string Via = "Via"; + public const string Warning = "Warning"; + public const string ContentLength = "Content-Length"; + public const string ContentType = "Content-Type"; + public const string ContentDisposition = "Content-Disposition"; + public const string ContentEncoding = "Content-Encoding"; + public const string ContentLanguage = "Content-Language"; + public const string ContentLocation = "Content-Location"; + public const string ContentRange = "Content-Range"; + public const string Expires = "Expires"; + public const string LastModified = "Last-Modified"; + public const string Age = "Age"; + public const string Location = "Location"; + public const string ProxyAuthenticate = "Proxy-Authenticate"; + public const string RetryAfter = "Retry-After"; + public const string Server = "Server"; + public const string SetCookie = "Set-Cookie"; + public const string SetCookie2 = "Set-Cookie2"; + public const string Vary = "Vary"; + public const string WWWAuthenticate = "WWW-Authenticate"; + public const string Accept = "Accept"; + public const string AcceptCharset = "Accept-Charset"; + public const string AcceptEncoding = "Accept-Encoding"; + public const string AcceptLanguage = "Accept-Language"; + public const string Authorization = "Authorization"; + public const string Cookie = "Cookie"; + public const string Cookie2 = "Cookie2"; + public const string Expect = "Expect"; + public const string From = "From"; + public const string Host = "Host"; + public const string IfMatch = "If-Match"; + public const string IfModifiedSince = "If-Modified-Since"; + public const string IfNoneMatch = "If-None-Match"; + public const string IfRange = "If-Range"; + public const string IfUnmodifiedSince = "If-Unmodified-Since"; + public const string MaxForwards = "Max-Forwards"; + public const string ProxyAuthorization = "Proxy-Authorization"; + public const string Referer = "Referer"; + public const string Range = "Range"; + public const string UserAgent = "User-Agent"; + public const string ContentMD5 = "Content-MD5"; + public const string ETag = "ETag"; + public const string TE = "TE"; + public const string Allow = "Allow"; + public const string AcceptRanges = "Accept-Ranges"; + public const string P3P = "P3P"; + public const string XPoweredBy = "X-Powered-By"; + public const string XAspNetVersion = "X-AspNet-Version"; + public const string SecWebSocketKey = "Sec-WebSocket-Key"; + public const string SecWebSocketExtensions = "Sec-WebSocket-Extensions"; + public const string SecWebSocketAccept = "Sec-WebSocket-Accept"; + public const string Origin = "Origin"; + public const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; + public const string SecWebSocketVersion = "Sec-WebSocket-Version"; + } +} diff --git a/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerContext.cs b/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerContext.cs new file mode 100644 index 0000000000..506c830ea2 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerContext.cs @@ -0,0 +1,58 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ +/* +namespace Microsoft.Net +{ + using Microsoft.AspNet.WebSockets; + using System; + using System.ComponentModel; + using System.Threading.Tasks; + + public sealed unsafe class HttpListenerContext { + private HttpListenerRequest m_Request; + + public Task AcceptWebSocketAsync(string subProtocol) + { + return this.AcceptWebSocketAsync(subProtocol, + WebSocketHelpers.DefaultReceiveBufferSize, + WebSocket.DefaultKeepAliveInterval); + } + + public Task AcceptWebSocketAsync(string subProtocol, TimeSpan keepAliveInterval) + { + return this.AcceptWebSocketAsync(subProtocol, + WebSocketHelpers.DefaultReceiveBufferSize, + keepAliveInterval); + } + + public Task AcceptWebSocketAsync(string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval) + { + WebSocketHelpers.ValidateOptions(subProtocol, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, keepAliveInterval); + + ArraySegment internalBuffer = WebSocketBuffer.CreateInternalBufferArraySegment(receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + return this.AcceptWebSocketAsync(subProtocol, + receiveBufferSize, + keepAliveInterval, + internalBuffer); + } + + [EditorBrowsable(EditorBrowsableState.Never)] + public Task AcceptWebSocketAsync(string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + { + return WebSocketHelpers.AcceptWebSocketAsync(this, + subProtocol, + receiveBufferSize, + keepAliveInterval, + internalBuffer); + } + } +} +*/ \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerRequest.cs b/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerRequest.cs new file mode 100644 index 0000000000..d0c02c1842 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerRequest.cs @@ -0,0 +1,66 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +/* +namespace Microsoft.Net +{ + using System; + using System.Collections; + using System.Collections.Specialized; + using System.IO; + using System.Runtime.InteropServices; + using System.Globalization; + using System.Text; + using System.Security.Principal; + using System.Security.Cryptography.X509Certificates; + using System.Net; + using Microsoft.AspNet.WebSockets; + + public sealed unsafe class HttpListenerRequest { + + public bool IsWebSocketRequest + { + get + { + if (!WebSocketProtocolComponent.IsSupported) + { + return false; + } + + bool foundConnectionUpgradeHeader = false; + if (string.IsNullOrEmpty(this.Headers[HttpKnownHeaderNames.Connection]) || string.IsNullOrEmpty(this.Headers[HttpKnownHeaderNames.Upgrade])) + { + return false; + } + + foreach (string connection in this.Headers.GetValues(HttpKnownHeaderNames.Connection)) + { + if (string.Compare(connection, HttpKnownHeaderNames.Upgrade, StringComparison.OrdinalIgnoreCase) == 0) + { + foundConnectionUpgradeHeader = true; + break; + } + } + + if (!foundConnectionUpgradeHeader) + { + return false; + } + + foreach (string upgrade in this.Headers.GetValues(HttpKnownHeaderNames.Upgrade)) + { + if (string.Compare(upgrade, WebSocketHelpers.WebSocketUpgradeToken, StringComparison.OrdinalIgnoreCase) == 0) + { + return true; + } + } + + return false; + } + } + } +} +*/ \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/Legacy/SR.cs b/src/Microsoft.AspNet.WebSockets/Legacy/SR.cs new file mode 100644 index 0000000000..a8aeb2a30e --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Legacy/SR.cs @@ -0,0 +1,66 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All Rights Reserved. +// Information Contained Herein is Proprietary and Confidential. +// +//------------------------------------------------------------------------------ + +namespace System +{ + internal sealed class SR + { + internal const string net_servicePointAddressNotSupportedInHostMode = "net_servicePointAddressNotSupportedInHostMode"; + internal const string net_Websockets_AlreadyOneOutstandingOperation = "net_Websockets_AlreadyOneOutstandingOperation"; + internal const string net_Websockets_WebSocketBaseFaulted = "net_Websockets_WebSocketBaseFaulted"; + internal const string net_WebSockets_NativeSendResponseHeaders = "net_WebSockets_NativeSendResponseHeaders"; + internal const string net_WebSockets_Generic = "net_WebSockets_Generic"; + internal const string net_WebSockets_NotAWebSocket_Generic = "net_WebSockets_NotAWebSocket_Generic"; + internal const string net_WebSockets_UnsupportedWebSocketVersion_Generic = "net_WebSockets_UnsupportedWebSocketVersion_Generic"; + internal const string net_WebSockets_HeaderError_Generic = "net_WebSockets_HeaderError_Generic"; + internal const string net_WebSockets_UnsupportedProtocol_Generic = "net_WebSockets_UnsupportedProtocol_Generic"; + internal const string net_WebSockets_UnsupportedPlatform = "net_WebSockets_UnsupportedPlatform"; + internal const string net_WebSockets_AcceptNotAWebSocket = "net_WebSockets_AcceptNotAWebSocket"; + internal const string net_WebSockets_AcceptUnsupportedWebSocketVersion = "net_WebSockets_AcceptUnsupportedWebSocketVersion"; + internal const string net_WebSockets_AcceptHeaderNotFound = "net_WebSockets_AcceptHeaderNotFound"; + internal const string net_WebSockets_AcceptUnsupportedProtocol = "net_WebSockets_AcceptUnsupportedProtocol"; + internal const string net_WebSockets_ClientAcceptingNoProtocols = "net_WebSockets_ClientAcceptingNoProtocols"; + internal const string net_WebSockets_ClientSecWebSocketProtocolsBlank = "net_WebSockets_ClientSecWebSocketProtocolsBlank"; + internal const string net_WebSockets_ArgumentOutOfRange_TooSmall = "net_WebSockets_ArgumentOutOfRange_TooSmall"; + internal const string net_WebSockets_ArgumentOutOfRange_InternalBuffer = "net_WebSockets_ArgumentOutOfRange_InternalBuffer"; + internal const string net_WebSockets_ArgumentOutOfRange_TooBig = "net_WebSockets_ArgumentOutOfRange_TooBig"; + internal const string net_WebSockets_InvalidState_Generic = "net_WebSockets_InvalidState_Generic"; + internal const string net_WebSockets_InvalidState_ClosedOrAborted = "net_WebSockets_InvalidState_ClosedOrAborted"; + internal const string net_WebSockets_InvalidState = "net_WebSockets_InvalidState"; + internal const string net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync = "net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync"; + internal const string net_WebSockets_InvalidMessageType = "net_WebSockets_InvalidMessageType"; + internal const string net_WebSockets_InvalidBufferType = "net_WebSockets_InvalidBufferType"; + internal const string net_WebSockets_InvalidMessageType_Generic = "net_WebSockets_InvalidMessageType_Generic"; + internal const string net_WebSockets_Argument_InvalidMessageType = "net_WebSockets_Argument_InvalidMessageType"; + internal const string net_WebSockets_ConnectionClosedPrematurely_Generic = "net_WebSockets_ConnectionClosedPrematurely_Generic"; + internal const string net_WebSockets_InvalidCharInProtocolString = "net_WebSockets_InvalidCharInProtocolString"; + internal const string net_WebSockets_InvalidEmptySubProtocol = "net_WebSockets_InvalidEmptySubProtocol"; + internal const string net_WebSockets_ReasonNotNull = "net_WebSockets_ReasonNotNull"; + internal const string net_WebSockets_InvalidCloseStatusCode = "net_WebSockets_InvalidCloseStatusCode"; + internal const string net_WebSockets_InvalidCloseStatusDescription = "net_WebSockets_InvalidCloseStatusDescription"; + internal const string net_WebSockets_Scheme = "net_WebSockets_Scheme"; + internal const string net_WebSockets_AlreadyStarted = "net_WebSockets_AlreadyStarted"; + internal const string net_WebSockets_Connect101Expected = "net_WebSockets_Connect101Expected"; + internal const string net_WebSockets_InvalidResponseHeader = "net_WebSockets_InvalidResponseHeader"; + internal const string net_WebSockets_NotConnected = "net_WebSockets_NotConnected"; + internal const string net_WebSockets_InvalidRegistration = "net_WebSockets_InvalidRegistration"; + internal const string net_WebSockets_NoDuplicateProtocol = "net_WebSockets_NoDuplicateProtocol"; + + internal const string NotReadableStream = "NotReadableStream"; + internal const string NotWriteableStream = "NotWriteableStream"; + + public static string GetString(string name, params object[] args) + { + return name; + } + + public static string GetString(string name) + { + return name; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/Legacy/WebSocketHttpListenerDuplexStream.cs b/src/Microsoft.AspNet.WebSockets/Legacy/WebSocketHttpListenerDuplexStream.cs new file mode 100644 index 0000000000..54abff91bd --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Legacy/WebSocketHttpListenerDuplexStream.cs @@ -0,0 +1,1262 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ +/* +namespace Microsoft.AspNet.WebSockets +{ + using Microsoft.Net; + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.ComponentModel; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.Globalization; + using System.IO; + using System.Runtime.InteropServices; + using System.Security; + using System.Threading; + using System.Threading.Tasks; + + internal class WebSocketHttpListenerDuplexStream : Stream, WebSocketBase.IWebSocketStream + { + private static readonly EventHandler s_OnReadCompleted = + new EventHandler(OnReadCompleted); + private static readonly EventHandler s_OnWriteCompleted = + new EventHandler(OnWriteCompleted); + private static readonly Func s_CanHandleException = new Func(CanHandleException); + private static readonly Action s_OnCancel = new Action(OnCancel); + // private readonly HttpRequestStream m_InputStream; + // private readonly HttpResponseStream m_OutputStream; + private HttpListenerContext m_Context; + private bool m_InOpaqueMode; + private WebSocketBase m_WebSocket; + private HttpListenerAsyncEventArgs m_WriteEventArgs; + private HttpListenerAsyncEventArgs m_ReadEventArgs; + private TaskCompletionSource m_WriteTaskCompletionSource; + private TaskCompletionSource m_ReadTaskCompletionSource; + private int m_CleanedUp; + +#if DEBUG + private class OutstandingOperations + { + internal int m_Reads; + internal int m_Writes; + } + + private readonly OutstandingOperations m_OutstandingOperations = new OutstandingOperations(); +#endif //DEBUG + + public WebSocketHttpListenerDuplexStream( + // HttpRequestStream inputStream, + // HttpResponseStream outputStream, + HttpListenerContext context) + { + Contract.Assert(inputStream != null, "'inputStream' MUST NOT be NULL."); + Contract.Assert(outputStream != null, "'outputStream' MUST NOT be NULL."); + Contract.Assert(context != null, "'context' MUST NOT be NULL."); + Contract.Assert(inputStream.CanRead, "'inputStream' MUST support read operations."); + Contract.Assert(outputStream.CanWrite, "'outputStream' MUST support write operations."); + + m_InputStream = inputStream; + m_OutputStream = outputStream; + m_Context = context; + + if (WebSocketBase.LoggingEnabled) + { + Logging.Associate(Logging.WebSockets, inputStream, this); + Logging.Associate(Logging.WebSockets, outputStream, this); + } + } + + public override bool CanRead + { + get + { + return m_InputStream.CanRead; + } + } + + public override bool CanSeek + { + get + { + return false; + } + } + + public override bool CanTimeout + { + get + { + return m_InputStream.CanTimeout && m_OutputStream.CanTimeout; + } + } + + public override bool CanWrite + { + get + { + return m_OutputStream.CanWrite; + } + } + + public override long Length + { + get + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + set + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + return m_InputStream.Read(buffer, offset, count); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateBuffer(buffer, offset, count); + + return ReadAsyncCore(buffer, offset, count, cancellationToken); + } + + private async Task ReadAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.ReadAsyncCore, + WebSocketHelpers.GetTraceMsgForParameters(offset, count, cancellationToken)); + } + + CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration(); + + int bytesRead = 0; + try + { + if (cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false); + } + + if (!m_InOpaqueMode) + { + bytesRead = await m_InputStream.ReadAsync(buffer, offset, count, cancellationToken).SuppressContextFlow(); + } + else + { +#if DEBUG + // When using fast path only one outstanding read is permitted. By switching into opaque mode + // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition) + // caller takes responsibility for enforcing this constraint. + Contract.Assert(Interlocked.Increment(ref m_OutstandingOperations.m_Reads) == 1, + "Only one outstanding read allowed at any given time."); +#endif + m_ReadTaskCompletionSource = new TaskCompletionSource(); + m_ReadEventArgs.SetBuffer(buffer, offset, count); + if (!ReadAsyncFast(m_ReadEventArgs)) + { + if (m_ReadEventArgs.Exception != null) + { + throw m_ReadEventArgs.Exception; + } + + bytesRead = m_ReadEventArgs.BytesTransferred; + } + else + { + bytesRead = await m_ReadTaskCompletionSource.Task.SuppressContextFlow(); + } + } + } + catch (Exception error) + { + if (s_CanHandleException(error)) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + throw; + } + finally + { + cancellationTokenRegistration.Dispose(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.ReadAsyncCore, bytesRead); + } + } + + return bytesRead; + } + + // return value indicates sync vs async completion + // false: sync completion + // true: async completion + private unsafe bool ReadAsyncFast(HttpListenerAsyncEventArgs eventArgs) + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.ReadAsyncFast, string.Empty); + } + + eventArgs.StartOperationCommon(this); + eventArgs.StartOperationReceive(); + + uint statusCode = 0; + bool completedAsynchronously = false; + try + { + Contract.Assert(eventArgs.Buffer != null, "'BufferList' is not supported for read operations."); + if (eventArgs.Count == 0 || m_InputStream.Closed) + { + eventArgs.FinishOperationSuccess(0, true); + return false; + } + + uint dataRead = 0; + int offset = eventArgs.Offset; + int remainingCount = eventArgs.Count; + + if (m_InputStream.BufferedDataChunksAvailable) + { + dataRead = m_InputStream.GetChunks(eventArgs.Buffer, eventArgs.Offset, eventArgs.Count); + if (m_InputStream.BufferedDataChunksAvailable && dataRead == eventArgs.Count) + { + eventArgs.FinishOperationSuccess(eventArgs.Count, true); + return false; + } + } + + Contract.Assert(!m_InputStream.BufferedDataChunksAvailable, "'m_InputStream.BufferedDataChunksAvailable' MUST BE 'FALSE' at this point."); + Contract.Assert(dataRead <= eventArgs.Count, "'dataRead' MUST NOT be bigger than 'eventArgs.Count'."); + + if (dataRead != 0) + { + offset += (int)dataRead; + remainingCount -= (int)dataRead; + //the http.sys team recommends that we limit the size to 128kb + if (remainingCount > HttpRequestStream.MaxReadSize) + { + remainingCount = HttpRequestStream.MaxReadSize; + } + + eventArgs.SetBuffer(eventArgs.Buffer, offset, remainingCount); + } + else if (remainingCount > HttpRequestStream.MaxReadSize) + { + remainingCount = HttpRequestStream.MaxReadSize; + eventArgs.SetBuffer(eventArgs.Buffer, offset, remainingCount); + } + + // m_InputStream.InternalHttpContext.EnsureBoundHandle(); + uint flags = 0; + uint bytesReturned = 0; + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveRequestEntityBody2( + m_InputStream.InternalHttpContext.RequestQueueHandle, + m_InputStream.InternalHttpContext.RequestId, + flags, + (byte*)m_WebSocket.InternalBuffer.ToIntPtr(eventArgs.Offset), + (uint)eventArgs.Count, + out bytesReturned, + eventArgs.NativeOverlapped); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + throw new HttpListenerException((int)statusCode); + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + HttpListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously. No IO completion port callback is used because + // it was disabled in SwitchToOpaqueMode() + eventArgs.FinishOperationSuccess((int)bytesReturned, true); + completedAsynchronously = false; + } + else + { + completedAsynchronously = true; + } + } + catch (Exception e) + { + m_ReadEventArgs.FinishOperationFailure(e, true); + m_OutputStream.SetClosedFlag(); + m_OutputStream.InternalHttpContext.Abort(); + + throw; + } + finally + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.ReadAsyncFast, completedAsynchronously); + } + } + + return completedAsynchronously; + } + + public override int ReadByte() + { + return m_InputStream.ReadByte(); + } + + public bool SupportsMultipleWrite + { + get + { + return true; + } + } + + public override IAsyncResult BeginRead(byte[] buffer, + int offset, + int count, + AsyncCallback callback, + object state) + { + return m_InputStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return m_InputStream.EndRead(asyncResult); + } + + public Task MultipleWriteAsync(IList> sendBuffers, CancellationToken cancellationToken) + { + Contract.Assert(m_InOpaqueMode, "The stream MUST be in opaque mode at this point."); + Contract.Assert(sendBuffers != null, "'sendBuffers' MUST NOT be NULL."); + Contract.Assert(sendBuffers.Count == 1 || sendBuffers.Count == 2, + "'sendBuffers.Count' MUST be either '1' or '2'."); + + if (sendBuffers.Count == 1) + { + ArraySegment buffer = sendBuffers[0]; + return WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); + } + + return MultipleWriteAsyncCore(sendBuffers, cancellationToken); + } + + private async Task MultipleWriteAsyncCore(IList> sendBuffers, CancellationToken cancellationToken) + { + Contract.Assert(sendBuffers != null, "'sendBuffers' MUST NOT be NULL."); + Contract.Assert(sendBuffers.Count == 2, "'sendBuffers.Count' MUST be '2' at this point."); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.MultipleWriteAsyncCore, string.Empty); + } + + CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration(); + + try + { + if (cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false); + } +#if DEBUG + // When using fast path only one outstanding read is permitted. By switching into opaque mode + // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition) + // caller takes responsibility for enforcing this constraint. + Contract.Assert(Interlocked.Increment(ref m_OutstandingOperations.m_Writes) == 1, + "Only one outstanding write allowed at any given time."); +#endif + m_WriteTaskCompletionSource = new TaskCompletionSource(); + m_WriteEventArgs.SetBuffer(null, 0, 0); + m_WriteEventArgs.BufferList = sendBuffers; + if (WriteAsyncFast(m_WriteEventArgs)) + { + await m_WriteTaskCompletionSource.Task.SuppressContextFlow(); + } + } + catch (Exception error) + { + if (s_CanHandleException(error)) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + throw; + } + finally + { + cancellationTokenRegistration.Dispose(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.MultipleWriteAsyncCore, string.Empty); + } + } + } + + public override void Write(byte[] buffer, int offset, int count) + { + m_OutputStream.Write(buffer, offset, count); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateBuffer(buffer, offset, count); + + return WriteAsyncCore(buffer, offset, count, cancellationToken); + } + + private async Task WriteAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.WriteAsyncCore, + WebSocketHelpers.GetTraceMsgForParameters(offset, count, cancellationToken)); + } + + CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration(); + + try + { + if (cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false); + } + + if (!m_InOpaqueMode) + { + await m_OutputStream.WriteAsync(buffer, offset, count, cancellationToken).SuppressContextFlow(); + } + else + { +#if DEBUG + // When using fast path only one outstanding read is permitted. By switching into opaque mode + // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition) + // caller takes responsibility for enforcing this constraint. + Contract.Assert(Interlocked.Increment(ref m_OutstandingOperations.m_Writes) == 1, + "Only one outstanding write allowed at any given time."); +#endif + m_WriteTaskCompletionSource = new TaskCompletionSource(); + m_WriteEventArgs.BufferList = null; + m_WriteEventArgs.SetBuffer(buffer, offset, count); + if (WriteAsyncFast(m_WriteEventArgs)) + { + await m_WriteTaskCompletionSource.Task.SuppressContextFlow(); + } + } + } + catch (Exception error) + { + if (s_CanHandleException(error)) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + throw; + } + finally + { + cancellationTokenRegistration.Dispose(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.WriteAsyncCore, string.Empty); + } + } + } + + // return value indicates sync vs async completion + // false: sync completion + // true: async completion + private bool WriteAsyncFast(HttpListenerAsyncEventArgs eventArgs) + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.WriteAsyncFast, string.Empty); + } + + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + + eventArgs.StartOperationCommon(this); + eventArgs.StartOperationSend(); + + uint statusCode; + bool completedAsynchronously = false; + try + { + if (m_OutputStream.Closed || + (eventArgs.Buffer != null && eventArgs.Count == 0)) + { + eventArgs.FinishOperationSuccess(eventArgs.Count, true); + return false; + } + + if (eventArgs.ShouldCloseOutput) + { + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_DISCONNECT; + } + else + { + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + // When using HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA HTTP.SYS will copy the payload to + // kernel memory (Non-Paged Pool). Http.Sys will buffer up to + // Math.Min(16 MB, current TCP window size) + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; + } + + m_OutputStream.InternalHttpContext.EnsureBoundHandle(); + uint bytesSent; + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody2( + m_OutputStream.InternalHttpContext.RequestQueueHandle, + m_OutputStream.InternalHttpContext.RequestId, + (uint)flags, + eventArgs.EntityChunkCount, + eventArgs.EntityChunks, + out bytesSent, + SafeLocalFree.Zero, + 0, + eventArgs.NativeOverlapped, + IntPtr.Zero); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + throw new HttpListenerException((int)statusCode); + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + HttpListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + eventArgs.FinishOperationSuccess((int)bytesSent, true); + completedAsynchronously = false; + } + else + { + completedAsynchronously = true; + } + } + catch (Exception e) + { + m_WriteEventArgs.FinishOperationFailure(e, true); + m_OutputStream.SetClosedFlag(); + m_OutputStream.InternalHttpContext.Abort(); + + throw; + } + finally + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.WriteAsyncFast, completedAsynchronously); + } + } + + return completedAsynchronously; + } + + public override void WriteByte(byte value) + { + m_OutputStream.WriteByte(value); + } + + public override IAsyncResult BeginWrite(byte[] buffer, + int offset, + int count, + AsyncCallback callback, + object state) + { + return m_OutputStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + m_OutputStream.EndWrite(asyncResult); + } + + public override void Flush() + { + m_OutputStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return m_OutputStream.FlushAsync(cancellationToken); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + + public async Task CloseNetworkConnectionAsync(CancellationToken cancellationToken) + { + // need to yield here to make sure that we don't get any exception synchronously + await Task.Yield(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.CloseNetworkConnectionAsync, string.Empty); + } + + CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration(); + + try + { + if (cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false); + } +#if DEBUG + // When using fast path only one outstanding read is permitted. By switching into opaque mode + // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition) + // caller takes responsibility for enforcing this constraint. + Contract.Assert(Interlocked.Increment(ref m_OutstandingOperations.m_Writes) == 1, + "Only one outstanding write allowed at any given time."); +#endif + m_WriteTaskCompletionSource = new TaskCompletionSource(); + m_WriteEventArgs.SetShouldCloseOutput(); + if (WriteAsyncFast(m_WriteEventArgs)) + { + await m_WriteTaskCompletionSource.Task.SuppressContextFlow(); + } + } + catch (Exception error) + { + if (!s_CanHandleException(error)) + { + throw; + } + + // throw OperationCancelledException when canceled by the caller + // otherwise swallow the exception + cancellationToken.ThrowIfCancellationRequested(); + } + finally + { + cancellationTokenRegistration.Dispose(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.CloseNetworkConnectionAsync, string.Empty); + } + } + } + + protected override void Dispose(bool disposing) + { + if (disposing && Interlocked.Exchange(ref m_CleanedUp, 1) == 0) + { + if (m_ReadTaskCompletionSource != null) + { + m_ReadTaskCompletionSource.TrySetCanceled(); + } + + if (m_WriteTaskCompletionSource != null) + { + m_WriteTaskCompletionSource.TrySetCanceled(); + } + + if (m_ReadEventArgs != null) + { + m_ReadEventArgs.Dispose(); + } + + if (m_WriteEventArgs != null) + { + m_WriteEventArgs.Dispose(); + } + + try + { + m_InputStream.Close(); + } + finally + { + m_OutputStream.Close(); + } + } + } + + public void Abort() + { + OnCancel(this); + } + + private static bool CanHandleException(Exception error) + { + return error is HttpListenerException || + error is ObjectDisposedException || + error is IOException; + } + + private static void OnCancel(object state) + { + Contract.Assert(state != null, "'state' MUST NOT be NULL."); + WebSocketHttpListenerDuplexStream thisPtr = state as WebSocketHttpListenerDuplexStream; + Contract.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL."); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, state, Methods.OnCancel, string.Empty); + } + + try + { + thisPtr.m_OutputStream.SetClosedFlag(); + // thisPtr.m_Context.Abort(); + } + catch { } + + TaskCompletionSource readTaskCompletionSourceSnapshot = thisPtr.m_ReadTaskCompletionSource; + + if (readTaskCompletionSourceSnapshot != null) + { + readTaskCompletionSourceSnapshot.TrySetCanceled(); + } + + TaskCompletionSource writeTaskCompletionSourceSnapshot = thisPtr.m_WriteTaskCompletionSource; + + if (writeTaskCompletionSourceSnapshot != null) + { + writeTaskCompletionSourceSnapshot.TrySetCanceled(); + } + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, state, Methods.OnCancel, string.Empty); + } + } + + public void SwitchToOpaqueMode(WebSocketBase webSocket) + { + Contract.Assert(webSocket != null, "'webSocket' MUST NOT be NULL."); + Contract.Assert(m_OutputStream != null, "'m_OutputStream' MUST NOT be NULL."); + Contract.Assert(m_OutputStream.InternalHttpContext != null, + "'m_OutputStream.InternalHttpContext' MUST NOT be NULL."); + Contract.Assert(m_OutputStream.InternalHttpContext.Response != null, + "'m_OutputStream.InternalHttpContext.Response' MUST NOT be NULL."); + Contract.Assert(m_OutputStream.InternalHttpContext.Response.SentHeaders, + "Headers MUST have been sent at this point."); + Contract.Assert(!m_InOpaqueMode, "SwitchToOpaqueMode MUST NOT be called multiple times."); + + if (m_InOpaqueMode) + { + throw new InvalidOperationException(); + } + + m_WebSocket = webSocket; + m_InOpaqueMode = true; + m_ReadEventArgs = new HttpListenerAsyncEventArgs(webSocket, this); + m_ReadEventArgs.Completed += s_OnReadCompleted; + m_WriteEventArgs = new HttpListenerAsyncEventArgs(webSocket, this); + m_WriteEventArgs.Completed += s_OnWriteCompleted; + + if (WebSocketBase.LoggingEnabled) + { + Logging.Associate(Logging.WebSockets, this, webSocket); + } + } + + private static void OnWriteCompleted(object sender, HttpListenerAsyncEventArgs eventArgs) + { + Contract.Assert(eventArgs != null, "'eventArgs' MUST NOT be NULL."); + WebSocketHttpListenerDuplexStream thisPtr = eventArgs.CurrentStream; + Contract.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL."); +#if DEBUG + Contract.Assert(Interlocked.Decrement(ref thisPtr.m_OutstandingOperations.m_Writes) >= 0, + "'thisPtr.m_OutstandingOperations.m_Writes' MUST NOT be negative."); +#endif + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, thisPtr, Methods.OnWriteCompleted, string.Empty); + } + + if (eventArgs.Exception != null) + { + thisPtr.m_WriteTaskCompletionSource.TrySetException(eventArgs.Exception); + } + else + { + thisPtr.m_WriteTaskCompletionSource.TrySetResult(null); + } + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, thisPtr, Methods.OnWriteCompleted, string.Empty); + } + } + + private static void OnReadCompleted(object sender, HttpListenerAsyncEventArgs eventArgs) + { + Contract.Assert(eventArgs != null, "'eventArgs' MUST NOT be NULL."); + WebSocketHttpListenerDuplexStream thisPtr = eventArgs.CurrentStream; + Contract.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL."); +#if DEBUG + Contract.Assert(Interlocked.Decrement(ref thisPtr.m_OutstandingOperations.m_Reads) >= 0, + "'thisPtr.m_OutstandingOperations.m_Reads' MUST NOT be negative."); +#endif + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, thisPtr, Methods.OnReadCompleted, string.Empty); + } + + if (eventArgs.Exception != null) + { + thisPtr.m_ReadTaskCompletionSource.TrySetException(eventArgs.Exception); + } + else + { + thisPtr.m_ReadTaskCompletionSource.TrySetResult(eventArgs.BytesTransferred); + } + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, thisPtr, Methods.OnReadCompleted, string.Empty); + } + } + + internal class HttpListenerAsyncEventArgs : EventArgs, IDisposable + { + private const int Free = 0; + private const int InProgress = 1; + private const int Disposed = 2; + private int m_Operating; + + private bool m_DisposeCalled; + private SafeNativeOverlapped m_PtrNativeOverlapped; + private Overlapped m_Overlapped; + private event EventHandler m_Completed; + private byte[] m_Buffer; + private IList> m_BufferList; + private int m_Count; + private int m_Offset; + private int m_BytesTransferred; + private HttpListenerAsyncOperation m_CompletedOperation; + private UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[] m_DataChunks; + private GCHandle m_DataChunksGCHandle; + private ushort m_DataChunkCount; + private Exception m_Exception; + private bool m_ShouldCloseOutput; + private readonly WebSocketBase m_WebSocket; + private readonly WebSocketHttpListenerDuplexStream m_CurrentStream; + + public HttpListenerAsyncEventArgs(WebSocketBase webSocket, WebSocketHttpListenerDuplexStream stream) + : base() + { + m_WebSocket = webSocket; + m_CurrentStream = stream; + InitializeOverlapped(); + } + + public int BytesTransferred + { + get { return m_BytesTransferred; } + } + + public byte[] Buffer + { + get { return m_Buffer; } + } + + // BufferList property. + // Mutually exclusive with Buffer. + // Setting this property with an existing non-null Buffer will cause an assert. + public IList> BufferList + { + get { return m_BufferList; } + set + { + Contract.Assert(!m_ShouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'false' at this point."); + Contract.Assert(value == null || m_Buffer == null, + "Either 'm_Buffer' or 'm_BufferList' MUST be NULL."); + Contract.Assert(m_Operating == Free, + "This property can only be modified if no IO operation is outstanding."); + Contract.Assert(value == null || value.Count == 2, + "This list can only be 'NULL' or MUST have exactly '2' items."); + m_BufferList = value; + } + } + + public bool ShouldCloseOutput + { + get { return m_ShouldCloseOutput; } + } + + public int Offset + { + get { return m_Offset; } + } + + public int Count + { + get { return m_Count; } + } + + public Exception Exception + { + get { return m_Exception; } + } + + public ushort EntityChunkCount + { + get + { + if (m_DataChunks == null) + { + return 0; + } + + return m_DataChunkCount; + } + } + + public SafeNativeOverlapped NativeOverlapped + { + get { return m_PtrNativeOverlapped; } + } + + public IntPtr EntityChunks + { + get + { + if (m_DataChunks == null) + { + return IntPtr.Zero; + } + + return Marshal.UnsafeAddrOfPinnedArrayElement(m_DataChunks, 0); + } + } + + public WebSocketHttpListenerDuplexStream CurrentStream + { + get { return m_CurrentStream; } + } + + public event EventHandler Completed + { + add + { + m_Completed += value; + } + remove + { + m_Completed -= value; + } + } + + protected virtual void OnCompleted(HttpListenerAsyncEventArgs e) + { + EventHandler handler = m_Completed; + if (handler != null) + { + handler(e.m_CurrentStream, e); + } + } + + public void SetShouldCloseOutput() + { + m_BufferList = null; + m_Buffer = null; + m_ShouldCloseOutput = true; + } + + public void Dispose() + { + // Remember that Dispose was called. + m_DisposeCalled = true; + + // Check if this object is in-use for an async socket operation. + if (Interlocked.CompareExchange(ref m_Operating, Disposed, Free) != Free) + { + // Either already disposed or will be disposed when current operation completes. + return; + } + + // OK to dispose now. + // Free native overlapped data. + FreeOverlapped(false); + + // Don't bother finalizing later. + GC.SuppressFinalize(this); + } + + // Finalizer + ~HttpListenerAsyncEventArgs() + { + FreeOverlapped(true); + } + + private unsafe void InitializeOverlapped() + { + m_Overlapped = new Overlapped(); + m_PtrNativeOverlapped = new SafeNativeOverlapped(m_Overlapped.UnsafePack(CompletionPortCallback, null)); + } + + // Method to clean up any existing Overlapped object and related state variables. + private void FreeOverlapped(bool checkForShutdown) + { + if (!checkForShutdown || !NclUtilities.HasShutdownStarted) + { + // Free the overlapped object + if (m_PtrNativeOverlapped != null && !m_PtrNativeOverlapped.IsInvalid) + { + m_PtrNativeOverlapped.Dispose(); + } + + if (m_DataChunksGCHandle.IsAllocated) + { + m_DataChunksGCHandle.Free(); + } + } + } + + // Method called to prepare for a native async http.sys call. + // This method performs the tasks common to all http.sys operations. + internal void StartOperationCommon(WebSocketHttpListenerDuplexStream currentStream) + { + // Change status to "in-use". + if(Interlocked.CompareExchange(ref m_Operating, InProgress, Free) != Free) + { + // If it was already "in-use" check if Dispose was called. + if (m_DisposeCalled) + { + // Dispose was called - throw ObjectDisposed. + throw new ObjectDisposedException(GetType().FullName); + } + + Contract.Assert(false, "Only one outstanding async operation is allowed per HttpListenerAsyncEventArgs instance."); + // Only one at a time. + throw new InvalidOperationException(); + } + + // HttpSendResponseEntityBody can return ERROR_INVALID_PARAMETER if the InternalHigh field of the overlapped + // is not IntPtr.Zero, so we have to reset this field because we are reusing the Overlapped. + // When using the IAsyncResult based approach of HttpListenerResponseStream the Overlapped is reinitialized + // for each operation by the CLR when returned from the OverlappedDataCache. + NativeOverlapped.ReinitializeNativeOverlapped(); + m_Exception = null; + m_BytesTransferred = 0; + } + + internal void StartOperationReceive() + { + // Remember the operation type. + m_CompletedOperation = HttpListenerAsyncOperation.Receive; + } + + internal void StartOperationSend() + { + UpdateDataChunk(); + + // Remember the operation type. + m_CompletedOperation = HttpListenerAsyncOperation.Send; + } + + public void SetBuffer(byte[] buffer, int offset, int count) + { + Contract.Assert(!m_ShouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'false' at this point."); + Contract.Assert(buffer == null || m_BufferList == null, "Either 'm_Buffer' or 'm_BufferList' MUST be NULL."); + m_Buffer = buffer; + m_Offset = offset; + m_Count = count; + } + + private unsafe void UpdateDataChunk() + { + if (m_DataChunks == null) + { + m_DataChunks = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[2]; + m_DataChunksGCHandle = GCHandle.Alloc(m_DataChunks); + m_DataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + m_DataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + m_DataChunks[1] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + m_DataChunks[1].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + } + + Contract.Assert(m_Buffer == null || m_BufferList == null, "Either 'm_Buffer' or 'm_BufferList' MUST be NULL."); + Contract.Assert(m_ShouldCloseOutput || m_Buffer != null || m_BufferList != null, "Either 'm_Buffer' or 'm_BufferList' MUST NOT be NULL."); + + // The underlying byte[] m_Buffer or each m_BufferList[].Array are pinned already + if (m_Buffer != null) + { + UpdateDataChunk(0, m_Buffer, m_Offset, m_Count); + UpdateDataChunk(1, null, 0, 0); + m_DataChunkCount = 1; + } + else if (m_BufferList != null) + { + Contract.Assert(m_BufferList != null && m_BufferList.Count == 2, + "'m_BufferList' MUST NOT be NULL and have exactly '2' items at this point."); + UpdateDataChunk(0, m_BufferList[0].Array, m_BufferList[0].Offset, m_BufferList[0].Count); + UpdateDataChunk(1, m_BufferList[1].Array, m_BufferList[1].Offset, m_BufferList[1].Count); + m_DataChunkCount = 2; + } + else + { + Contract.Assert(m_ShouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'true' at this point."); + m_DataChunks = null; + } + } + + private unsafe void UpdateDataChunk(int index, byte[] buffer, int offset, int count) + { + if (buffer == null) + { + m_DataChunks[index].pBuffer = null; + m_DataChunks[index].BufferLength = 0; + return; + } + + if (m_WebSocket.InternalBuffer.IsInternalBuffer(buffer, offset, count)) + { + m_DataChunks[index].pBuffer = (byte*)(m_WebSocket.InternalBuffer.ToIntPtr(offset)); + } + else + { + m_DataChunks[index].pBuffer = + (byte*)m_WebSocket.InternalBuffer.ConvertPinnedSendPayloadToNative(buffer, offset, count); + } + + m_DataChunks[index].BufferLength = (uint)count; + } + + // Method to mark this object as no longer "in-use". + // Will also execute a Dispose deferred because I/O was in progress. + internal void Complete() + { + // Mark as not in-use + m_Operating = Free; + + // Check for deferred Dispose(). + // The deferred Dispose is not guaranteed if Dispose is called while an operation is in progress. + // The m_DisposeCalled variable is not managed in a thread-safe manner on purpose for performance. + if (m_DisposeCalled) + { + Dispose(); + } + } + + // Method to update internal state after sync or async completion. + private void SetResults(Exception exception, int bytesTransferred) + { + m_Exception = exception; + m_BytesTransferred = bytesTransferred; + } + + internal void FinishOperationFailure(Exception exception, bool syncCompletion) + { + SetResults(exception, 0); + + if (WebSocketBase.LoggingEnabled) + { + Logging.PrintError(Logging.WebSockets, m_CurrentStream, + m_CompletedOperation == HttpListenerAsyncOperation.Receive ? Methods.ReadAsyncFast : Methods.WriteAsyncFast, + exception.ToString()); + } + + Complete(); + OnCompleted(this); + } + + internal void FinishOperationSuccess(int bytesTransferred, bool syncCompletion) + { + SetResults(null, bytesTransferred); + + if (WebSocketBase.LoggingEnabled) + { + if (m_Buffer != null) + { + Logging.Dump(Logging.WebSockets, m_CurrentStream, + m_CompletedOperation == HttpListenerAsyncOperation.Receive ? Methods.ReadAsyncFast : Methods.WriteAsyncFast, + m_Buffer, m_Offset, bytesTransferred); + } + else if (m_BufferList != null) + { + Contract.Assert(m_CompletedOperation == HttpListenerAsyncOperation.Send, + "'BufferList' is only supported for send operations."); + + foreach (ArraySegment buffer in BufferList) + { + Logging.Dump(Logging.WebSockets, this, Methods.WriteAsyncFast, buffer.Array, buffer.Offset, buffer.Count); + } + } + else + { + Logging.PrintLine(Logging.WebSockets, TraceEventType.Verbose, 0, + string.Format(CultureInfo.InvariantCulture, "Output channel closed for {0}#{1}", + m_CurrentStream.GetType().Name, ValidationHelper.HashString(m_CurrentStream))); + } + } + + if (m_ShouldCloseOutput) + { + m_CurrentStream.m_OutputStream.SetClosedFlag(); + } + + // Complete the operation and raise completion event. + Complete(); + OnCompleted(this); + } + + private unsafe void CompletionPortCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS || + errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + FinishOperationSuccess((int)numBytes, false); + } + else + { + FinishOperationFailure(new HttpListenerException((int)errorCode), false); + } + } + + public enum HttpListenerAsyncOperation + { + None, + Receive, + Send + } + } + + private static class Methods + { + public const string CloseNetworkConnectionAsync = "CloseNetworkConnectionAsync"; + public const string OnCancel = "OnCancel"; + public const string OnReadCompleted = "OnReadCompleted"; + public const string OnWriteCompleted = "OnWriteCompleted"; + public const string ReadAsyncFast = "ReadAsyncFast"; + public const string ReadAsyncCore = "ReadAsyncCore"; + public const string WriteAsyncFast = "WriteAsyncFast"; + public const string WriteAsyncCore = "WriteAsyncCore"; + public const string MultipleWriteAsyncCore = "MultipleWriteAsyncCore"; + } + } +} +*/ \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeLoadLibrary.cs b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeLoadLibrary.cs new file mode 100644 index 0000000000..43276eefe5 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeLoadLibrary.cs @@ -0,0 +1,40 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.WebSockets +{ + internal sealed class SafeLoadLibrary : SafeHandleZeroOrMinusOneIsInvalid + { + private const string KERNEL32 = "kernel32.dll"; + + public static readonly SafeLoadLibrary Zero = new SafeLoadLibrary(false); + + private SafeLoadLibrary() : base(true) + { + } + + private SafeLoadLibrary(bool ownsHandle) : base(ownsHandle) + { + } + + public static unsafe SafeLoadLibrary LoadLibraryEx(string library) + { + SafeLoadLibrary result = UnsafeNativeMethods.SafeNetHandles.LoadLibraryExW(library, null, 0); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + } + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNativeMethods.SafeNetHandles.FreeLibrary(handle); + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeNativeOverlapped.cs b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeNativeOverlapped.cs new file mode 100644 index 0000000000..8682c89918 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeNativeOverlapped.cs @@ -0,0 +1,80 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNet.WebSockets +{ + internal class SafeNativeOverlapped : SafeHandle + { + internal static readonly SafeNativeOverlapped Zero = new SafeNativeOverlapped(); + + internal SafeNativeOverlapped() + : this(IntPtr.Zero) + { + } + + internal unsafe SafeNativeOverlapped(NativeOverlapped* handle) + : this((IntPtr)handle) + { + } + + internal SafeNativeOverlapped(IntPtr handle) + : base(IntPtr.Zero, true) + { + SetHandle(handle); + } + + public override bool IsInvalid + { + get { return handle == IntPtr.Zero; } + } + + public void ReinitializeNativeOverlapped() + { + IntPtr handleSnapshot = handle; + + if (handleSnapshot != IntPtr.Zero) + { + unsafe + { + ((NativeOverlapped*)handleSnapshot)->InternalHigh = IntPtr.Zero; + ((NativeOverlapped*)handleSnapshot)->InternalLow = IntPtr.Zero; + ((NativeOverlapped*)handleSnapshot)->EventHandle = IntPtr.Zero; + } + } + } + + protected override bool ReleaseHandle() + { + IntPtr oldHandle = Interlocked.Exchange(ref handle, IntPtr.Zero); + // Do not call free durring AppDomain shutdown, there may be an outstanding operation. + // Overlapped will take care calling free when the native callback completes. + if (oldHandle != IntPtr.Zero && !HasShutdownStarted) + { + unsafe + { + Overlapped.Free((NativeOverlapped*)oldHandle); + } + } + return true; + } + + internal static bool HasShutdownStarted + { + get + { + return Environment.HasShutdownStarted +#if NET45 + || AppDomain.CurrentDomain.IsFinalizingForUnload() +#endif + ; + } + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeWebSocketHandle.cs b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeWebSocketHandle.cs new file mode 100644 index 0000000000..60c8341561 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeWebSocketHandle.cs @@ -0,0 +1,32 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using Microsoft.AspNet.WebSockets; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.WebSockets +{ + // This class is a wrapper for a WSPC (WebSocket protocol component) session. WebSocketCreateClientHandle and WebSocketCreateServerHandle return a PVOID and not a real handle + // but we use a SafeHandle because it provides us the guarantee that WebSocketDeleteHandle will always get called. + internal sealed class SafeWebSocketHandle : SafeHandleZeroOrMinusOneIsInvalid + { + internal SafeWebSocketHandle() + : base(true) + { + } + + protected override bool ReleaseHandle() + { + if (this.IsInvalid) + { + return true; + } + + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketDeleteHandle(this.handle); + return true; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/NativeInterop/UnsafeNativeMethods.cs b/src/Microsoft.AspNet.WebSockets/NativeInterop/UnsafeNativeMethods.cs new file mode 100644 index 0000000000..86355cdee4 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/NativeInterop/UnsafeNativeMethods.cs @@ -0,0 +1,842 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.Contracts; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNet.WebSockets +{ + internal static class UnsafeNativeMethods + { + private const string KERNEL32 = "kernel32.dll"; + private const string WEBSOCKET = "websocket.dll"; + + internal static class SafeNetHandles + { + [DllImport(KERNEL32, ExactSpelling = true, CharSet=CharSet.Unicode, SetLastError = true)] + internal static extern unsafe SafeLoadLibrary LoadLibraryExW([In] string lpwLibFileName, [In] void* hFile, [In] uint dwFlags); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern unsafe bool FreeLibrary([In] IntPtr hModule); + } + + internal static class WebSocketProtocolComponent + { + private static readonly string DllFileName; + private static readonly string DummyWebsocketKeyBase64 = Convert.ToBase64String(new byte[16]); + private static readonly SafeLoadLibrary WebSocketDllHandle; + private static readonly string PrivateSupportedVersion; + + private static readonly HttpHeader[] InitialClientRequestHeaders = new HttpHeader[] + { + new HttpHeader() + { + Name = HttpKnownHeaderNames.Connection, + NameLength = (uint)HttpKnownHeaderNames.Connection.Length, + Value = HttpKnownHeaderNames.Upgrade, + ValueLength = (uint)HttpKnownHeaderNames.Upgrade.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.Upgrade, + NameLength = (uint)HttpKnownHeaderNames.Upgrade.Length, + Value = WebSocketHelpers.WebSocketUpgradeToken, + ValueLength = (uint)WebSocketHelpers.WebSocketUpgradeToken.Length + } + }; + + private static readonly HttpHeader[] ServerFakeRequestHeaders; + + internal static class Errors + { + internal const int E_INVALID_OPERATION = unchecked((int)0x80000050); + internal const int E_INVALID_PROTOCOL_OPERATION = unchecked((int)0x80000051); + internal const int E_INVALID_PROTOCOL_FORMAT = unchecked((int)0x80000052); + internal const int E_NUMERIC_OVERFLOW = unchecked((int)0x80000053); + internal const int E_FAIL = unchecked((int)0x80004005); + } + + internal enum Action + { + NoAction = 0, + SendToNetwork = 1, + IndicateSendComplete = 2, + ReceiveFromNetwork = 3, + IndicateReceiveComplete = 4, + } + + internal enum BufferType : uint + { + None = 0x00000000, + UTF8Message = 0x80000000, + UTF8Fragment = 0x80000001, + BinaryMessage = 0x80000002, + BinaryFragment = 0x80000003, + Close = 0x80000004, + PingPong = 0x80000005, + UnsolicitedPong = 0x80000006 + } + + internal enum PropertyType + { + ReceiveBufferSize = 0, + SendBufferSize = 1, + DisableMasking = 2, + AllocatedBuffer = 3, + DisableUtf8Verification = 4, + KeepAliveInterval = 5, + } + + internal enum ActionQueue + { + Send = 1, + Receive = 2, + } + + [StructLayout(LayoutKind.Sequential)] + internal struct Property + { + internal PropertyType Type; + internal IntPtr PropertyData; + internal uint PropertySize; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct Buffer + { + [FieldOffset(0)] + internal DataBuffer Data; + [FieldOffset(0)] + internal CloseBuffer CloseStatus; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct DataBuffer + { + internal IntPtr BufferData; + internal uint BufferLength; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct CloseBuffer + { + internal IntPtr ReasonData; + internal uint ReasonLength; + internal ushort CloseStatus; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HttpHeader + { + [MarshalAs(UnmanagedType.LPStr)] + internal string Name; + internal uint NameLength; + [MarshalAs(UnmanagedType.LPStr)] + internal string Value; + internal uint ValueLength; + } + + static WebSocketProtocolComponent() + { + DllFileName = Path.Combine(Environment.SystemDirectory, WEBSOCKET); + WebSocketDllHandle = SafeLoadLibrary.LoadLibraryEx(DllFileName); + + if (!WebSocketDllHandle.IsInvalid) + { + PrivateSupportedVersion = GetSupportedVersion(); + + ServerFakeRequestHeaders = new HttpHeader[] + { + new HttpHeader() + { + Name = HttpKnownHeaderNames.Connection, + NameLength = (uint)HttpKnownHeaderNames.Connection.Length, + Value = HttpKnownHeaderNames.Upgrade, + ValueLength = (uint)HttpKnownHeaderNames.Upgrade.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.Upgrade, + NameLength = (uint)HttpKnownHeaderNames.Upgrade.Length, + Value = WebSocketHelpers.WebSocketUpgradeToken, + ValueLength = (uint)WebSocketHelpers.WebSocketUpgradeToken.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.Host, + NameLength = (uint)HttpKnownHeaderNames.Host.Length, + Value = string.Empty, + ValueLength = 0 + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.SecWebSocketVersion, + NameLength = (uint)HttpKnownHeaderNames.SecWebSocketVersion.Length, + Value = SupportedVersion, + ValueLength = (uint)SupportedVersion.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.SecWebSocketKey, + NameLength = (uint)HttpKnownHeaderNames.SecWebSocketKey.Length, + Value = DummyWebsocketKeyBase64, + ValueLength = (uint)DummyWebsocketKeyBase64.Length + } + }; + } + } + + internal static string SupportedVersion + { + get + { + if (WebSocketDllHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + return PrivateSupportedVersion; + } + } + + internal static bool IsSupported + { + get + { + return !WebSocketDllHandle.IsInvalid; + } + } + + internal static string GetSupportedVersion() + { + if (WebSocketDllHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + SafeWebSocketHandle webSocketHandle = null; + try + { + int errorCode = WebSocketCreateClientHandle_Raw(null, 0, out webSocketHandle); + ThrowOnError(errorCode); + + if (webSocketHandle == null || + webSocketHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + IntPtr additionalHeadersPtr; + uint additionalHeaderCount; + + errorCode = WebSocketBeginClientHandshake_Raw(webSocketHandle, + IntPtr.Zero, + 0, + IntPtr.Zero, + 0, + InitialClientRequestHeaders, + (uint)InitialClientRequestHeaders.Length, + out additionalHeadersPtr, + out additionalHeaderCount); + ThrowOnError(errorCode); + + HttpHeader[] additionalHeaders = MarshalHttpHeaders(additionalHeadersPtr, (int)additionalHeaderCount); + + string version = null; + foreach (HttpHeader header in additionalHeaders) + { + if (string.Compare(header.Name, + HttpKnownHeaderNames.SecWebSocketVersion, + StringComparison.OrdinalIgnoreCase) == 0) + { + version = header.Value; + break; + } + } + Contract.Assert(version != null, "'version' MUST NOT be NULL."); + + return version; + } + finally + { + if (webSocketHandle != null) + { + webSocketHandle.Dispose(); + } + } + } + + internal static void WebSocketCreateClientHandle(Property[] properties, + out SafeWebSocketHandle webSocketHandle) + { + uint propertyCount = properties == null ? 0 : (uint)properties.Length; + + if (WebSocketDllHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + int errorCode = WebSocketCreateClientHandle_Raw(properties, propertyCount, out webSocketHandle); + ThrowOnError(errorCode); + + if (webSocketHandle == null || + webSocketHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + IntPtr additionalHeadersPtr; + uint additionalHeaderCount; + + // Currently the WSPC doesn't allow to initiate a data session + // without also being involved in the http handshake + // There is no information whatsoever, which is needed by the + // WSPC for parsing WebSocket frames from the HTTP handshake + // In the managed implementation the HTTP header handling + // will be done using the managed HTTP stack and we will + // just fake an HTTP handshake for the WSPC calling + // WebSocketBeginClientHandshake and WebSocketEndClientHandshake + // with statically defined dummy headers. + errorCode = WebSocketBeginClientHandshake_Raw(webSocketHandle, + IntPtr.Zero, + 0, + IntPtr.Zero, + 0, + InitialClientRequestHeaders, + (uint)InitialClientRequestHeaders.Length, + out additionalHeadersPtr, + out additionalHeaderCount); + + ThrowOnError(errorCode); + + HttpHeader[] additionalHeaders = MarshalHttpHeaders(additionalHeadersPtr, (int)additionalHeaderCount); + + string key = null; + foreach (HttpHeader header in additionalHeaders) + { + if (string.Compare(header.Name, + HttpKnownHeaderNames.SecWebSocketKey, + StringComparison.OrdinalIgnoreCase) == 0) + { + key = header.Value; + break; + } + } + Contract.Assert(key != null, "'key' MUST NOT be NULL."); + + string acceptValue = WebSocketHelpers.GetSecWebSocketAcceptString(key); + HttpHeader[] responseHeaders = new HttpHeader[] + { + new HttpHeader() + { + Name = HttpKnownHeaderNames.Connection, + NameLength = (uint)HttpKnownHeaderNames.Connection.Length, + Value = HttpKnownHeaderNames.Upgrade, + ValueLength = (uint)HttpKnownHeaderNames.Upgrade.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.Upgrade, + NameLength = (uint)HttpKnownHeaderNames.Upgrade.Length, + Value = WebSocketHelpers.WebSocketUpgradeToken, + ValueLength = (uint)WebSocketHelpers.WebSocketUpgradeToken.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.SecWebSocketAccept, + NameLength = (uint)HttpKnownHeaderNames.SecWebSocketAccept.Length, + Value = acceptValue, + ValueLength = (uint)acceptValue.Length + } + }; + + errorCode = WebSocketEndClientHandshake_Raw(webSocketHandle, + responseHeaders, + (uint)responseHeaders.Length, + IntPtr.Zero, + IntPtr.Zero, + IntPtr.Zero); + + ThrowOnError(errorCode); + + Contract.Assert(webSocketHandle != null, "'webSocketHandle' MUST NOT be NULL at this point."); + } + + internal static void WebSocketCreateServerHandle(Property[] properties, + int propertyCount, + out SafeWebSocketHandle webSocketHandle) + { + Contract.Assert(propertyCount >= 0, "'propertyCount' MUST NOT be negative."); + Contract.Assert((properties == null && propertyCount == 0) || + (properties != null && propertyCount == properties.Length), + "'propertyCount' MUST MATCH 'properties.Length'."); + + if (WebSocketDllHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + int errorCode = WebSocketCreateServerHandle_Raw(properties, (uint)propertyCount, out webSocketHandle); + ThrowOnError(errorCode); + + if (webSocketHandle == null || + webSocketHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + IntPtr responseHeadersPtr; + uint responseHeaderCount; + + // Currently the WSPC doesn't allow to initiate a data session + // without also being involved in the http handshake + // There is no information whatsoever, which is needed by the + // WSPC for parsing WebSocket frames from the HTTP handshake + // In the managed implementation the HTTP header handling + // will be done using the managed HTTP stack and we will + // just fake an HTTP handshake for the WSPC calling + // WebSocketBeginServerHandshake and WebSocketEndServerHandshake + // with statically defined dummy headers. + errorCode = WebSocketBeginServerHandshake_Raw(webSocketHandle, + IntPtr.Zero, + IntPtr.Zero, + 0, + ServerFakeRequestHeaders, + (uint)ServerFakeRequestHeaders.Length, + out responseHeadersPtr, + out responseHeaderCount); + + ThrowOnError(errorCode); + + HttpHeader[] responseHeaders = MarshalHttpHeaders(responseHeadersPtr, (int)responseHeaderCount); + errorCode = WebSocketEndServerHandshake_Raw(webSocketHandle); + + ThrowOnError(errorCode); + + Contract.Assert(webSocketHandle != null, "'webSocketHandle' MUST NOT be NULL at this point."); + } + + internal static void WebSocketAbortHandle(SafeHandle webSocketHandle) + { + Contract.Assert(webSocketHandle != null && !webSocketHandle.IsInvalid, + "'webSocketHandle' MUST NOT be NULL or INVALID."); + + WebSocketAbortHandle_Raw(webSocketHandle); + + DrainActionQueue(webSocketHandle, ActionQueue.Send); + DrainActionQueue(webSocketHandle, ActionQueue.Receive); + } + + internal static void WebSocketDeleteHandle(IntPtr webSocketPtr) + { + Contract.Assert(webSocketPtr != IntPtr.Zero, "'webSocketPtr' MUST NOT be IntPtr.Zero."); + WebSocketDeleteHandle_Raw(webSocketPtr); + } + + internal static void WebSocketSend(WebSocketBase webSocket, + BufferType bufferType, + Buffer buffer) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + + ThrowIfSessionHandleClosed(webSocket); + + int errorCode; + try + { + errorCode = WebSocketSend_Raw(webSocket.SessionHandle, bufferType, ref buffer, IntPtr.Zero); + } + catch (ObjectDisposedException innerException) + { + throw ConvertObjectDisposedException(webSocket, innerException); + } + + ThrowOnError(errorCode); + } + + internal static void WebSocketSendWithoutBody(WebSocketBase webSocket, + BufferType bufferType) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + + ThrowIfSessionHandleClosed(webSocket); + + int errorCode; + try + { + errorCode = WebSocketSendWithoutBody_Raw(webSocket.SessionHandle, bufferType, IntPtr.Zero, IntPtr.Zero); + } + catch (ObjectDisposedException innerException) + { + throw ConvertObjectDisposedException(webSocket, innerException); + } + + ThrowOnError(errorCode); + } + + internal static void WebSocketReceive(WebSocketBase webSocket) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + + ThrowIfSessionHandleClosed(webSocket); + + int errorCode; + try + { + errorCode = WebSocketReceive_Raw(webSocket.SessionHandle, IntPtr.Zero, IntPtr.Zero); + } + catch (ObjectDisposedException innerException) + { + throw ConvertObjectDisposedException(webSocket, innerException); + } + + ThrowOnError(errorCode); + } + + internal static void WebSocketGetAction(WebSocketBase webSocket, + ActionQueue actionQueue, + Buffer[] dataBuffers, + ref uint dataBufferCount, + out Action action, + out BufferType bufferType, + out IntPtr actionContext) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + Contract.Assert(dataBufferCount >= 0, "'dataBufferCount' MUST NOT be negative."); + Contract.Assert((dataBuffers == null && dataBufferCount == 0) || + (dataBuffers != null && dataBufferCount == dataBuffers.Length), + "'dataBufferCount' MUST MATCH 'dataBuffers.Length'."); + + action = Action.NoAction; + bufferType = BufferType.None; + actionContext = IntPtr.Zero; + + IntPtr dummy; + ThrowIfSessionHandleClosed(webSocket); + + int errorCode; + try + { + errorCode = WebSocketGetAction_Raw(webSocket.SessionHandle, + actionQueue, + dataBuffers, + ref dataBufferCount, + out action, + out bufferType, + out dummy, + out actionContext); + } + catch (ObjectDisposedException innerException) + { + throw ConvertObjectDisposedException(webSocket, innerException); + } + ThrowOnError(errorCode); + + webSocket.ValidateNativeBuffers(action, bufferType, dataBuffers, dataBufferCount); + + Contract.Assert(dataBufferCount >= 0); + Contract.Assert((dataBufferCount == 0 && dataBuffers == null) || + (dataBufferCount <= dataBuffers.Length)); + } + + internal static void WebSocketCompleteAction(WebSocketBase webSocket, + IntPtr actionContext, + int bytesTransferred) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + Contract.Assert(actionContext != IntPtr.Zero, "'actionContext' MUST NOT be IntPtr.Zero."); + Contract.Assert(bytesTransferred >= 0, "'bytesTransferred' MUST NOT be negative."); + + if (webSocket.SessionHandle.IsClosed) + { + return; + } + + try + { + WebSocketCompleteAction_Raw(webSocket.SessionHandle, actionContext, (uint)bytesTransferred); + } + catch (ObjectDisposedException) + { + } + } + + internal static TimeSpan WebSocketGetDefaultKeepAliveInterval() + { + uint result = 0; + uint size = sizeof(uint); + int errorCode = WebSocketGetGlobalProperty_Raw(PropertyType.KeepAliveInterval, ref result, ref size); + if (!Succeeded(errorCode)) + { + Contract.Assert(errorCode == 0, "errorCode: " + errorCode); + return Timeout.InfiniteTimeSpan; + } + return TimeSpan.FromMilliseconds(result); + } + + private static void DrainActionQueue(SafeHandle webSocketHandle, ActionQueue actionQueue) + { + Contract.Assert(webSocketHandle != null && !webSocketHandle.IsInvalid, + "'webSocketHandle' MUST NOT be NULL or INVALID."); + + IntPtr actionContext; + IntPtr dummy; + Action action; + BufferType bufferType; + + while (true) + { + Buffer[] dataBuffers = new Buffer[1]; + uint dataBufferCount = 1; + int errorCode = WebSocketGetAction_Raw(webSocketHandle, + actionQueue, + dataBuffers, + ref dataBufferCount, + out action, + out bufferType, + out dummy, + out actionContext); + + if (!Succeeded(errorCode)) + { + Contract.Assert(errorCode == 0, "'errorCode' MUST be 0."); + return; + } + + if (action == Action.NoAction) + { + return; + } + + WebSocketCompleteAction_Raw(webSocketHandle, actionContext, 0); + } + } + + private static void MarshalAndVerifyHttpHeader(IntPtr httpHeaderPtr, + ref HttpHeader httpHeader) + { + Contract.Assert(httpHeaderPtr != IntPtr.Zero, "'currentHttpHeaderPtr' MUST NOT be IntPtr.Zero."); + + IntPtr httpHeaderNamePtr = Marshal.ReadIntPtr(httpHeaderPtr); + IntPtr lengthPtr = IntPtr.Add(httpHeaderPtr, IntPtr.Size); + int length = Marshal.ReadInt32(lengthPtr); + Contract.Assert(length >= 0, "'length' MUST NOT be negative."); + + if (httpHeaderNamePtr != IntPtr.Zero) + { + httpHeader.Name = Marshal.PtrToStringAnsi(httpHeaderNamePtr, length); + } + + if ((httpHeader.Name == null && length != 0) || + (httpHeader.Name != null && length != httpHeader.Name.Length)) + { + Contract.Assert(false, "The length of 'httpHeader.Name' MUST MATCH 'length'."); + throw new AccessViolationException(); + } + + // structure of HttpHeader: + // Name = string* + // NameLength = uint* + // Value = string* + // ValueLength = uint* + // NOTE - All fields in the object are pointers to the actual value, hence the use of + // n * IntPtr.Size to get to the correct place in the object. + int valueOffset = 2 * IntPtr.Size; + int lengthOffset = 3 * IntPtr.Size; + + IntPtr httpHeaderValuePtr = + Marshal.ReadIntPtr(IntPtr.Add(httpHeaderPtr, valueOffset)); + lengthPtr = IntPtr.Add(httpHeaderPtr, lengthOffset); + length = Marshal.ReadInt32(lengthPtr); + httpHeader.Value = Marshal.PtrToStringAnsi(httpHeaderValuePtr, (int)length); + + if ((httpHeader.Value == null && length != 0) || + (httpHeader.Value != null && length != httpHeader.Value.Length)) + { + Contract.Assert(false, "The length of 'httpHeader.Value' MUST MATCH 'length'."); + throw new AccessViolationException(); + } + } + + private static HttpHeader[] MarshalHttpHeaders(IntPtr nativeHeadersPtr, + int nativeHeaderCount) + { + Contract.Assert(nativeHeaderCount >= 0, "'nativeHeaderCount' MUST NOT be negative."); + Contract.Assert(nativeHeadersPtr != IntPtr.Zero || nativeHeaderCount == 0, + "'nativeHeaderCount' MUST be 0."); + + HttpHeader[] httpHeaders = new HttpHeader[nativeHeaderCount]; + + // structure of HttpHeader: + // Name = string* + // NameLength = uint* + // Value = string* + // ValueLength = uint* + // NOTE - All fields in the object are pointers to the actual value, hence the use of + // 4 * IntPtr.Size to get to the next header. + int httpHeaderStructSize = 4 * IntPtr.Size; + + for (int i = 0; i < nativeHeaderCount; i++) + { + int offset = httpHeaderStructSize * i; + IntPtr currentHttpHeaderPtr = IntPtr.Add(nativeHeadersPtr, offset); + MarshalAndVerifyHttpHeader(currentHttpHeaderPtr, ref httpHeaders[i]); + } + + Contract.Assert(httpHeaders != null); + Contract.Assert(httpHeaders.Length == nativeHeaderCount); + + return httpHeaders; + } + + public static bool Succeeded(int hr) + { + return (hr >= 0); + } + + private static void ThrowOnError(int errorCode) + { + if (Succeeded(errorCode)) + { + return; + } + + throw new WebSocketException(errorCode); + } + + private static void ThrowIfSessionHandleClosed(WebSocketBase webSocket) + { + if (webSocket.SessionHandle.IsClosed) + { + throw new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, webSocket.GetType().FullName, webSocket.State)); + } + } + + private static WebSocketException ConvertObjectDisposedException(WebSocketBase webSocket, ObjectDisposedException innerException) + { + return new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, webSocket.GetType().FullName, webSocket.State), + innerException); + } + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketCreateClientHandle", ExactSpelling = true)] + private static extern int WebSocketCreateClientHandle_Raw( + [In]Property[] properties, + [In] uint propertyCount, + [Out] out SafeWebSocketHandle webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketBeginClientHandshake", ExactSpelling = true)] + private static extern int WebSocketBeginClientHandshake_Raw( + [In] SafeHandle webSocketHandle, + [In] IntPtr subProtocols, + [In] uint subProtocolCount, + [In] IntPtr extensions, + [In] uint extensionCount, + [In] HttpHeader[] initialHeaders, + [In] uint initialHeaderCount, + [Out] out IntPtr additionalHeadersPtr, + [Out] out uint additionalHeaderCount); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketEndClientHandshake", ExactSpelling = true)] + private static extern int WebSocketEndClientHandshake_Raw([In] SafeHandle webSocketHandle, + [In] HttpHeader[] responseHeaders, + [In] uint responseHeaderCount, + [In, Out] IntPtr selectedExtensions, + [In] IntPtr selectedExtensionCount, + [In] IntPtr selectedSubProtocol); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketBeginServerHandshake", ExactSpelling = true)] + private static extern int WebSocketBeginServerHandshake_Raw( + [In] SafeHandle webSocketHandle, + [In] IntPtr subProtocol, + [In] IntPtr extensions, + [In] uint extensionCount, + [In] HttpHeader[] requestHeaders, + [In] uint requestHeaderCount, + [Out] out IntPtr responseHeadersPtr, + [Out] out uint responseHeaderCount); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketEndServerHandshake", ExactSpelling = true)] + private static extern int WebSocketEndServerHandshake_Raw([In] SafeHandle webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketCreateServerHandle", ExactSpelling = true)] + private static extern int WebSocketCreateServerHandle_Raw( + [In]Property[] properties, + [In] uint propertyCount, + [Out] out SafeWebSocketHandle webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketAbortHandle", ExactSpelling = true)] + private static extern void WebSocketAbortHandle_Raw( + [In] SafeHandle webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketDeleteHandle", ExactSpelling = true)] + private static extern void WebSocketDeleteHandle_Raw( + [In] IntPtr webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketSend", ExactSpelling = true)] + private static extern int WebSocketSend_Raw( + [In] SafeHandle webSocketHandle, + [In] BufferType bufferType, + [In] ref Buffer buffer, + [In] IntPtr applicationContext); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketSend", ExactSpelling = true)] + private static extern int WebSocketSendWithoutBody_Raw( + [In] SafeHandle webSocketHandle, + [In] BufferType bufferType, + [In] IntPtr buffer, + [In] IntPtr applicationContext); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketReceive", ExactSpelling = true)] + private static extern int WebSocketReceive_Raw( + [In] SafeHandle webSocketHandle, + [In] IntPtr buffers, + [In] IntPtr applicationContext); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketGetAction", ExactSpelling = true)] + private static extern int WebSocketGetAction_Raw( + [In] SafeHandle webSocketHandle, + [In] ActionQueue actionQueue, + [In, Out] Buffer[] dataBuffers, + [In, Out] ref uint dataBufferCount, + [Out] out Action action, + [Out] out BufferType bufferType, + [Out] out IntPtr applicationContext, + [Out] out IntPtr actionContext); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketCompleteAction", ExactSpelling = true)] + private static extern void WebSocketCompleteAction_Raw( + [In] SafeHandle webSocketHandle, + [In] IntPtr actionContext, + [In] uint bytesTransferred); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketGetGlobalProperty", ExactSpelling = true)] + private static extern int WebSocketGetGlobalProperty_Raw( + [In] PropertyType property, + [In, Out] ref uint value, + [In, Out] ref uint size); + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/OwinWebSocketWrapper.cs b/src/Microsoft.AspNet.WebSockets/OwinWebSocketWrapper.cs new file mode 100644 index 0000000000..a11c34488d --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/OwinWebSocketWrapper.cs @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebSockets +{ + using WebSocketCloseAsync = + Func; + using WebSocketReceiveAsync = + Func /* data */, + CancellationToken /* cancel */, + Task>>; + using WebSocketReceiveTuple = + Tuple; + using WebSocketSendAsync = + Func /* data */, + int /* messageType */, + bool /* endOfMessage */, + CancellationToken /* cancel */, + Task>; + + internal class OwinWebSocketWrapper + { + private readonly WebSocket _webSocket; + private readonly IDictionary _environment; + private readonly CancellationToken _cancellationToken; + + internal OwinWebSocketWrapper(WebSocket webSocket, CancellationToken ct) + { + _webSocket = webSocket; + _cancellationToken = ct; + + _environment = new Dictionary(); + _environment[Constants.WebSocketSendAsyncKey] = new WebSocketSendAsync(SendAsync); + _environment[Constants.WebSocketReceiveAyncKey] = new WebSocketReceiveAsync(ReceiveAsync); + _environment[Constants.WebSocketCloseAsyncKey] = new WebSocketCloseAsync(CloseAsync); + _environment[Constants.WebSocketCallCancelledKey] = ct; + _environment[Constants.WebSocketVersionKey] = Constants.WebSocketVersion; + } + + internal IDictionary Environment + { + get { return _environment; } + } + + internal Task SendAsync(ArraySegment buffer, int messageType, bool endOfMessage, CancellationToken cancel) + { + // Remap close messages to CloseAsync. System.Net.WebSockets.WebSocket.SendAsync does not allow close messages. + if (messageType == 0x8) + { + return RedirectSendToCloseAsync(buffer, cancel); + } + else if (messageType == 0x9 || messageType == 0xA) + { + // Ping & Pong, not allowed by the underlying APIs, silently discard. + return Task.FromResult(0); + } + + return _webSocket.SendAsync(buffer, (WebSocketMessageType)messageType, endOfMessage, cancel); + } + + internal async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancel) + { + WebSocketReceiveResult nativeResult = await _webSocket.ReceiveAsync(buffer, cancel); + + if (nativeResult.MessageType == WebSocketMessageType.Close) + { + _environment[Constants.WebSocketCloseStatusKey] = (int)(nativeResult.CloseStatus ?? WebSocketCloseStatus.NormalClosure); + _environment[Constants.WebSocketCloseDescriptionKey] = nativeResult.CloseStatusDescription ?? string.Empty; + } + + return new WebSocketReceiveTuple( + (int)nativeResult.MessageType, + nativeResult.EndOfMessage, + nativeResult.Count); + } + + internal Task CloseAsync(int status, string description, CancellationToken cancel) + { + return _webSocket.CloseOutputAsync((WebSocketCloseStatus)status, description, cancel); + } + + private Task RedirectSendToCloseAsync(ArraySegment buffer, CancellationToken cancel) + { + if (buffer.Array == null || buffer.Count == 0) + { + return CloseAsync(1000, string.Empty, cancel); + } + else if (buffer.Count >= 2) + { + // Unpack the close message. + int statusCode = + (buffer.Array[buffer.Offset] << 8) + | buffer.Array[buffer.Offset + 1]; + string description = Encoding.UTF8.GetString(buffer.Array, buffer.Offset + 2, buffer.Count - 2); + + return CloseAsync(statusCode, description, cancel); + } + else + { + throw new ArgumentOutOfRangeException("buffer"); + } + } + + internal async Task CleanupAsync() + { + switch (_webSocket.State) + { + case WebSocketState.Closed: // Closed gracefully, no action needed. + case WebSocketState.Aborted: // Closed abortively, no action needed. + break; + case WebSocketState.CloseReceived: + // Echo what the client said, if anything. + await _webSocket.CloseAsync(_webSocket.CloseStatus ?? WebSocketCloseStatus.NormalClosure, + _webSocket.CloseStatusDescription ?? string.Empty, _cancellationToken); + break; + case WebSocketState.Open: + case WebSocketState.CloseSent: // No close received, abort so we don't have to drain the pipe. + _webSocket.Abort(); + break; + default: + throw new ArgumentOutOfRangeException("state", _webSocket.State, string.Empty); + } + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/Properties/AssemblyInfo.cs b/src/Microsoft.AspNet.WebSockets/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..1abade9a33 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Properties/AssemblyInfo.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.WebSockets")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.WebSockets")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("1f471909-581f-4060-a147-430891e9c3c1")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/src/Microsoft.AspNet.WebSockets/ServerWebSocket.cs b/src/Microsoft.AspNet.WebSockets/ServerWebSocket.cs new file mode 100644 index 0000000000..f23d13af54 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/ServerWebSocket.cs @@ -0,0 +1,62 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.IO; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.WebSockets +{ + internal sealed class ServerWebSocket : WebSocketBase + { + private readonly SafeHandle _sessionHandle; + private readonly UnsafeNativeMethods.WebSocketProtocolComponent.Property[] _properties; + + public ServerWebSocket(Stream innerStream, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + : base(innerStream, subProtocol, keepAliveInterval, + WebSocketBuffer.CreateServerBuffer(internalBuffer, receiveBufferSize)) + { + _properties = this.InternalBuffer.CreateProperties(false); + _sessionHandle = this.CreateWebSocketHandle(); + + if (_sessionHandle == null || _sessionHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + StartKeepAliveTimer(); + } + + internal override SafeHandle SessionHandle + { + get + { + Contract.Assert(_sessionHandle != null, "'m_SessionHandle MUST NOT be NULL."); + return _sessionHandle; + } + } + + [SuppressMessage("Microsoft.Security", "CA2122:DoNotIndirectlyExposeMethodsWithLinkDemands", + Justification = "No arbitrary data controlled by PT code is leaking into native code.")] + private SafeHandle CreateWebSocketHandle() + { + Contract.Assert(_properties != null, "'m_Properties' MUST NOT be NULL."); + SafeWebSocketHandle sessionHandle; + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCreateServerHandle(_properties, + _properties.Length, + out sessionHandle); + Contract.Assert(sessionHandle != null, "'sessionHandle MUST NOT be NULL."); + + return sessionHandle; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocket.cs b/src/Microsoft.AspNet.WebSockets/WebSocket.cs new file mode 100644 index 0000000000..f75e50a3c6 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocket.cs @@ -0,0 +1,124 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebSockets +{ + public abstract class WebSocket : IDisposable + { + private static TimeSpan? defaultKeepAliveInterval; + + public abstract WebSocketCloseStatus? CloseStatus { get; } + public abstract string CloseStatusDescription { get; } + public abstract string SubProtocol { get; } + public abstract WebSocketState State { get; } + + public static TimeSpan DefaultKeepAliveInterval + { + [SuppressMessage("Microsoft.Security", "CA2122:DoNotIndirectlyExposeMethodsWithLinkDemands", + Justification = "This is a harmless read-only operation")] + get + { + if (defaultKeepAliveInterval == null) + { + if (UnsafeNativeMethods.WebSocketProtocolComponent.IsSupported) + { + defaultKeepAliveInterval = UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketGetDefaultKeepAliveInterval(); + } + else + { + defaultKeepAliveInterval = Timeout.InfiniteTimeSpan; + } + } + return defaultKeepAliveInterval.Value; + } + } + + public static ArraySegment CreateClientBuffer(int receiveBufferSize, int sendBufferSize) + { + WebSocketHelpers.ValidateBufferSizes(receiveBufferSize, sendBufferSize); + + return WebSocketBuffer.CreateInternalBufferArraySegment(receiveBufferSize, sendBufferSize, false); + } + + public static ArraySegment CreateServerBuffer(int receiveBufferSize) + { + WebSocketHelpers.ValidateBufferSizes(receiveBufferSize, WebSocketBuffer.MinSendBufferSize); + + return WebSocketBuffer.CreateInternalBufferArraySegment(receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + } + + internal static WebSocket CreateServerWebSocket(Stream innerStream, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + { + if (!UnsafeNativeMethods.WebSocketProtocolComponent.IsSupported) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + WebSocketHelpers.ValidateInnerStream(innerStream); + WebSocketHelpers.ValidateOptions(subProtocol, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, keepAliveInterval); + WebSocketHelpers.ValidateArraySegment(internalBuffer, "internalBuffer"); + WebSocketBuffer.Validate(internalBuffer.Count, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + + return new ServerWebSocket(innerStream, + subProtocol, + receiveBufferSize, + keepAliveInterval, + internalBuffer); + } + + public abstract void Abort(); + public abstract Task CloseAsync(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken); + public abstract Task CloseOutputAsync(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken); + [SuppressMessage("Microsoft.Design", "CA1063:ImplementIDisposableCorrectly", Justification = "This rule is outdated")] + public abstract void Dispose(); + public abstract Task ReceiveAsync(ArraySegment buffer, + CancellationToken cancellationToken); + public abstract Task SendAsync(ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken); + + protected static void ThrowOnInvalidState(WebSocketState state, params WebSocketState[] validStates) + { + string validStatesText = string.Empty; + + if (validStates != null && validStates.Length > 0) + { + foreach (WebSocketState currentState in validStates) + { + if (state == currentState) + { + return; + } + } + + validStatesText = string.Join(", ", validStates); + } + + throw new WebSocketException(SR.GetString(SR.net_WebSockets_InvalidState, state, validStatesText)); + } + + protected static bool IsStateTerminal(WebSocketState state) + { + return state == WebSocketState.Closed || + state == WebSocketState.Aborted; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketBase.cs b/src/Microsoft.AspNet.WebSockets/WebSocketBase.cs new file mode 100644 index 0000000000..3e3b798396 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketBase.cs @@ -0,0 +1,2481 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebSockets +{ + internal abstract class WebSocketBase : WebSocket, IDisposable + { + // private static volatile bool s_LoggingEnabled; + + private readonly OutstandingOperationHelper _closeOutstandingOperationHelper; + private readonly OutstandingOperationHelper _closeOutputOutstandingOperationHelper; + private readonly OutstandingOperationHelper _receiveOutstandingOperationHelper; + private readonly OutstandingOperationHelper _sendOutstandingOperationHelper; + private readonly Stream _innerStream; + private readonly IWebSocketStream _innerStreamAsWebSocketStream; + private readonly string _subProtocol; + + // We are not calling Dispose method on this object in Cleanup method to avoid a race condition while one thread is calling disposing on + // this object and another one is still using WaitAsync. According to Dev11 358715, this should be fine as long as we are not accessing the + // AvailableWaitHandle on this SemaphoreSlim object. + private readonly SemaphoreSlim _sendFrameThrottle; + // locking m_ThisLock protects access to + // - State + // - m_CloseStack + // - m_CloseAsyncStartedReceive + // - m_CloseReceivedTaskCompletionSource + // - m_CloseNetworkConnectionTask + private readonly object _thisLock; + private readonly WebSocketBuffer _internalBuffer; + private readonly KeepAliveTracker _keepAliveTracker; + +#if DEBUG + private volatile string _closeStack; +#endif + + private volatile bool _cleanedUp; + private volatile TaskCompletionSource _closeReceivedTaskCompletionSource; + private volatile Task _closeOutputTask; + private volatile bool _isDisposed; + private volatile Task _closeNetworkConnectionTask; + private volatile bool _closeAsyncStartedReceive; + private volatile WebSocketState _state; + private volatile Task _keepAliveTask; + private volatile WebSocketOperation.ReceiveOperation _receiveOperation; + private volatile WebSocketOperation.SendOperation _sendOperation; + private volatile WebSocketOperation.SendOperation _keepAliveOperation; + private volatile WebSocketOperation.CloseOutputOperation _closeOutputOperation; + private WebSocketCloseStatus? _closeStatus; + private string _closeStatusDescription; + private int _receiveState; + private Exception _pendingException; + + protected WebSocketBase(Stream innerStream, + string subProtocol, + TimeSpan keepAliveInterval, + WebSocketBuffer internalBuffer) + { + Contract.Assert(internalBuffer != null, "'internalBuffer' MUST NOT be NULL."); + WebSocketHelpers.ValidateInnerStream(innerStream); + WebSocketHelpers.ValidateOptions(subProtocol, internalBuffer.ReceiveBufferSize, + internalBuffer.SendBufferSize, keepAliveInterval); + + // s_LoggingEnabled = Logging.On && Logging.WebSockets.Switch.ShouldTrace(TraceEventType.Critical); + string parameters = string.Empty; + /* + if (s_LoggingEnabled) + { + parameters = string.Format(CultureInfo.InvariantCulture, + "ReceiveBufferSize: {0}, SendBufferSize: {1}, Protocols: {2}, KeepAliveInterval: {3}, innerStream: {4}, internalBuffer: {5}", + internalBuffer.ReceiveBufferSize, + internalBuffer.SendBufferSize, + subProtocol, + keepAliveInterval, + Logging.GetObjectLogHash(innerStream), + Logging.GetObjectLogHash(internalBuffer)); + + Logging.Enter(Logging.WebSockets, this, Methods.Initialize, parameters); + } + */ + _thisLock = new object(); + + try + { + _innerStream = innerStream; + _internalBuffer = internalBuffer; + /*if (s_LoggingEnabled) + { + Logging.Associate(Logging.WebSockets, this, m_InnerStream); + Logging.Associate(Logging.WebSockets, this, m_InternalBuffer); + }*/ + + _closeOutstandingOperationHelper = new OutstandingOperationHelper(); + _closeOutputOutstandingOperationHelper = new OutstandingOperationHelper(); + _receiveOutstandingOperationHelper = new OutstandingOperationHelper(); + _sendOutstandingOperationHelper = new OutstandingOperationHelper(); + _state = WebSocketState.Open; + _subProtocol = subProtocol; + _sendFrameThrottle = new SemaphoreSlim(1, 1); + _closeStatus = null; + _closeStatusDescription = null; + _innerStreamAsWebSocketStream = innerStream as IWebSocketStream; + if (_innerStreamAsWebSocketStream != null) + { + _innerStreamAsWebSocketStream.SwitchToOpaqueMode(this); + } + _keepAliveTracker = KeepAliveTracker.Create(keepAliveInterval); + } + finally + { + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.Initialize, parameters); + }*/ + } + } + /* + internal static bool LoggingEnabled + { + get + { + return s_LoggingEnabled; + } + } + */ + public override WebSocketState State + { + get + { + Contract.Assert(_state != WebSocketState.None, "'m_State' MUST NOT be 'WebSocketState.None'."); + return _state; + } + } + + public override string SubProtocol + { + get + { + return _subProtocol; + } + } + + public override WebSocketCloseStatus? CloseStatus + { + get + { + return _closeStatus; + } + } + + public override string CloseStatusDescription + { + get + { + return _closeStatusDescription; + } + } + + internal WebSocketBuffer InternalBuffer + { + get + { + Contract.Assert(_internalBuffer != null, "'m_InternalBuffer' MUST NOT be NULL."); + return _internalBuffer; + } + } + + protected void StartKeepAliveTimer() + { + _keepAliveTracker.StartTimer(this); + } + + // locking SessionHandle protects access to + // - WSPC (WebSocketProtocolComponent) + // - m_KeepAliveTask + // - m_CloseOutputTask + // - m_LastSendActivity + internal abstract SafeHandle SessionHandle { get; } + + // MultiThreading: ThreadSafe; At most one outstanding call to ReceiveAsync is allowed + public override Task ReceiveAsync(ArraySegment buffer, + CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateArraySegment(buffer, "buffer"); + return ReceiveAsyncCore(buffer, cancellationToken); + } + + private async Task ReceiveAsyncCore(ArraySegment buffer, + CancellationToken cancellationToken) + { + Contract.Assert(buffer.Array != null); + /* + if (s_LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.ReceiveAsync, string.Empty); + } + */ + WebSocketReceiveResult receiveResult; + try + { + ThrowIfPendingException(); + ThrowIfDisposed(); + ThrowOnInvalidState(State, WebSocketState.Open, WebSocketState.CloseSent); + + bool ownsCancellationTokenSource = false; + CancellationToken linkedCancellationToken = CancellationToken.None; + try + { + ownsCancellationTokenSource = _receiveOutstandingOperationHelper.TryStartOperation(cancellationToken, + out linkedCancellationToken); + if (!ownsCancellationTokenSource) + { + lock (_thisLock) + { + if (_closeAsyncStartedReceive) + { + throw new InvalidOperationException( + SR.GetString(SR.net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync, Methods.CloseAsync, Methods.CloseOutputAsync)); + } + + throw new InvalidOperationException( + SR.GetString(SR.net_Websockets_AlreadyOneOutstandingOperation, Methods.ReceiveAsync)); + } + } + + EnsureReceiveOperation(); + receiveResult = await _receiveOperation.Process(buffer, linkedCancellationToken).SuppressContextFlow(); + /* + if (s_LoggingEnabled && receiveResult.Count > 0) + { + Logging.Dump(Logging.WebSockets, + this, + Methods.ReceiveAsync, + buffer.Array, + buffer.Offset, + receiveResult.Count); + }*/ + } + catch (Exception exception) + { + bool aborted = linkedCancellationToken.IsCancellationRequested; + Abort(); + ThrowIfConvertibleException(Methods.ReceiveAsync, exception, cancellationToken, aborted); + throw; + } + finally + { + _receiveOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + } + } + finally + {/* + if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.ReceiveAsync, string.Empty); + }*/ + } + + return receiveResult; + } + + // MultiThreading: ThreadSafe; At most one outstanding call to SendAsync is allowed + public override Task SendAsync(ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken) + { + if (messageType != WebSocketMessageType.Binary && + messageType != WebSocketMessageType.Text) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_Argument_InvalidMessageType, + messageType, + Methods.SendAsync, + WebSocketMessageType.Binary, + WebSocketMessageType.Text, + Methods.CloseOutputAsync), + "messageType"); + } + + WebSocketHelpers.ValidateArraySegment(buffer, "buffer"); + + return SendAsyncCore(buffer, messageType, endOfMessage, cancellationToken); + } + + private async Task SendAsyncCore(ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken) + { + Contract.Assert(messageType == WebSocketMessageType.Binary || messageType == WebSocketMessageType.Text, + "'messageType' MUST be either 'WebSocketMessageType.Binary' or 'WebSocketMessageType.Text'."); + Contract.Assert(buffer.Array != null); + + string inputParameter = string.Empty; + /*if (s_LoggingEnabled) + { + inputParameter = string.Format(CultureInfo.InvariantCulture, + "messageType: {0}, endOfMessage: {1}", + messageType, + endOfMessage); + Logging.Enter(Logging.WebSockets, this, Methods.SendAsync, inputParameter); + }*/ + + try + { + ThrowIfPendingException(); + ThrowIfDisposed(); + ThrowOnInvalidState(State, WebSocketState.Open, WebSocketState.CloseReceived); + bool ownsCancellationTokenSource = false; + CancellationToken linkedCancellationToken = CancellationToken.None; + + try + { + while (!(ownsCancellationTokenSource = _sendOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken))) + { + Task keepAliveTask; + + lock (SessionHandle) + { + keepAliveTask = _keepAliveTask; + + if (keepAliveTask == null) + { + // Check whether there is still another outstanding send operation + // Potentially the keepAlive operation has completed before this thread + // was able to enter the SessionHandle-lock. + _sendOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + if (ownsCancellationTokenSource = _sendOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken)) + { + break; + } + else + { + throw new InvalidOperationException( + SR.GetString(SR.net_Websockets_AlreadyOneOutstandingOperation, Methods.SendAsync)); + } + } + } + + await keepAliveTask.SuppressContextFlow(); + ThrowIfPendingException(); + + _sendOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + } + /* + if (s_LoggingEnabled && buffer.Count > 0) + { + Logging.Dump(Logging.WebSockets, + this, + Methods.SendAsync, + buffer.Array, + buffer.Offset, + buffer.Count); + }*/ + + int position = buffer.Offset; + + EnsureSendOperation(); + _sendOperation.BufferType = GetBufferType(messageType, endOfMessage); + await _sendOperation.Process(buffer, linkedCancellationToken).SuppressContextFlow(); + } + catch (Exception exception) + { + bool aborted = linkedCancellationToken.IsCancellationRequested; + Abort(); + ThrowIfConvertibleException(Methods.SendAsync, exception, cancellationToken, aborted); + throw; + } + finally + { + _sendOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + } + } + finally + { + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.SendAsync, inputParameter); + }*/ + } + } + + private async Task SendFrameAsync(IList> sendBuffers, CancellationToken cancellationToken) + { + bool sendFrameLockTaken = false; + try + { + await _sendFrameThrottle.WaitAsync(cancellationToken).SuppressContextFlow(); + sendFrameLockTaken = true; + + if (sendBuffers.Count > 1 && + _innerStreamAsWebSocketStream != null && + _innerStreamAsWebSocketStream.SupportsMultipleWrite) + { + await _innerStreamAsWebSocketStream.MultipleWriteAsync(sendBuffers, + cancellationToken).SuppressContextFlow(); + } + else + { + foreach (ArraySegment buffer in sendBuffers) + { + await _innerStream.WriteAsync(buffer.Array, + buffer.Offset, + buffer.Count, + cancellationToken).SuppressContextFlow(); + } + } + } + catch (ObjectDisposedException objectDisposedException) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, objectDisposedException); + } + catch (NotSupportedException notSupportedException) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, notSupportedException); + } + finally + { + if (sendFrameLockTaken) + { + _sendFrameThrottle.Release(); + } + } + } + + // MultiThreading: ThreadSafe; No-op if already in a terminal state + public override void Abort() + { + /*if (s_LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.Abort, string.Empty); + }*/ + + bool thisLockTaken = false; + bool sessionHandleLockTaken = false; + try + { + if (IsStateTerminal(State)) + { + return; + } + + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + if (IsStateTerminal(State)) + { + return; + } + + _state = WebSocketState.Aborted; + +#if DEBUG + string stackTrace = new StackTrace().ToString(); + if (_closeStack == null) + { + _closeStack = stackTrace; + } + /* + if (s_LoggingEnabled) + { + string message = string.Format(CultureInfo.InvariantCulture, "Stack: {0}", stackTrace); + Logging.PrintWarning(Logging.WebSockets, this, Methods.Abort, message); + }*/ +#endif + + // Abort any outstanding IO operations. + if (SessionHandle != null && !SessionHandle.IsClosed && !SessionHandle.IsInvalid) + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketAbortHandle(SessionHandle); + } + + _receiveOutstandingOperationHelper.CancelIO(); + _sendOutstandingOperationHelper.CancelIO(); + _closeOutputOutstandingOperationHelper.CancelIO(); + _closeOutstandingOperationHelper.CancelIO(); + if (_innerStreamAsWebSocketStream != null) + { + _innerStreamAsWebSocketStream.Abort(); + } + CleanUp(); + } + finally + { + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.Abort, string.Empty); + }*/ + } + } + + // MultiThreading: ThreadSafe; No-op if already in a terminal state + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateCloseStatus(closeStatus, statusDescription); + + return CloseOutputAsyncCore(closeStatus, statusDescription, cancellationToken); + } + + private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + string inputParameter = string.Empty; + /*if (s_LoggingEnabled) + { + inputParameter = string.Format(CultureInfo.InvariantCulture, + "closeStatus: {0}, statusDescription: {1}", + closeStatus, + statusDescription); + Logging.Enter(Logging.WebSockets, this, Methods.CloseOutputAsync, inputParameter); + }*/ + + try + { + ThrowIfPendingException(); + if (IsStateTerminal(State)) + { + return; + } + ThrowIfDisposed(); + + bool thisLockTaken = false; + bool sessionHandleLockTaken = false; + bool needToCompleteSendOperation = false; + bool ownsCloseOutputCancellationTokenSource = false; + bool ownsSendCancellationTokenSource = false; + CancellationToken linkedCancellationToken = CancellationToken.None; + try + { + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + ThrowIfPendingException(); + ThrowIfDisposed(); + + if (IsStateTerminal(State)) + { + return; + } + + ThrowOnInvalidState(State, WebSocketState.Open, WebSocketState.CloseReceived); + ownsCloseOutputCancellationTokenSource = _closeOutputOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken); + if (!ownsCloseOutputCancellationTokenSource) + { + Task closeOutputTask = _closeOutputTask; + + if (closeOutputTask != null) + { + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + await closeOutputTask.SuppressContextFlow(); + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + } + else + { + needToCompleteSendOperation = true; + while (!(ownsSendCancellationTokenSource = + _sendOutstandingOperationHelper.TryStartOperation(cancellationToken, + out linkedCancellationToken))) + { + if (_keepAliveTask != null) + { + Task keepAliveTask = _keepAliveTask; + + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + await keepAliveTask.SuppressContextFlow(); + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + + ThrowIfPendingException(); + } + else + { + throw new InvalidOperationException( + SR.GetString(SR.net_Websockets_AlreadyOneOutstandingOperation, Methods.SendAsync)); + } + + _sendOutstandingOperationHelper.CompleteOperation(ownsSendCancellationTokenSource); + } + + EnsureCloseOutputOperation(); + _closeOutputOperation.CloseStatus = closeStatus; + _closeOutputOperation.CloseReason = statusDescription; + _closeOutputTask = _closeOutputOperation.Process(null, linkedCancellationToken); + + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + await _closeOutputTask.SuppressContextFlow(); + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + + if (OnCloseOutputCompleted()) + { + bool callCompleteOnCloseCompleted = false; + + try + { + callCompleteOnCloseCompleted = await StartOnCloseCompleted( + thisLockTaken, sessionHandleLockTaken, linkedCancellationToken).SuppressContextFlow(); + } + catch (Exception) + { + // If an exception is thrown we know that the locks have been released, + // because we enforce IWebSocketStream.CloseNetworkConnectionAsync to yield + ResetFlagsAndTakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + throw; + } + + if (callCompleteOnCloseCompleted) + { + ResetFlagsAndTakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + FinishOnCloseCompleted(); + } + } + } + } + catch (Exception exception) + { + bool aborted = linkedCancellationToken.IsCancellationRequested; + Abort(); + ThrowIfConvertibleException(Methods.CloseOutputAsync, exception, cancellationToken, aborted); + throw; + } + finally + { + _closeOutputOutstandingOperationHelper.CompleteOperation(ownsCloseOutputCancellationTokenSource); + + if (needToCompleteSendOperation) + { + _sendOutstandingOperationHelper.CompleteOperation(ownsSendCancellationTokenSource); + } + + _closeOutputTask = null; + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + } + finally + { + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.CloseOutputAsync, inputParameter); + }*/ + } + } + + // returns TRUE if the caller should also call StartOnCloseCompleted + private bool OnCloseOutputCompleted() + { + if (IsStateTerminal(State)) + { + return false; + } + + switch (State) + { + case WebSocketState.Open: + _state = WebSocketState.CloseSent; + return false; + case WebSocketState.CloseReceived: + return true; + default: + return false; + } + } + + // MultiThreading: This method has to be called under a m_ThisLock-lock + // ReturnValue: This method returns true only if CompleteOnCloseCompleted needs to be called + // If this method returns true all locks were released before starting the IO operation + // and they have to be retaken by the caller before calling CompleteOnCloseCompleted + // Exception handling: If an exception is thrown from await StartOnCloseCompleted + // it always means the locks have been released already - so the caller has to retake the + // locks in the catch-block. + // This is ensured by enforcing a Task.Yield for IWebSocketStream.CloseNetowrkConnectionAsync + private async Task StartOnCloseCompleted(bool thisLockTakenSnapshot, + bool sessionHandleLockTakenSnapshot, + CancellationToken cancellationToken) + { + Contract.Assert(thisLockTakenSnapshot, "'thisLockTakenSnapshot' MUST be 'true' at this point."); + + if (IsStateTerminal(_state)) + { + return false; + } + + _state = WebSocketState.Closed; + +#if DEBUG + if (_closeStack == null) + { + _closeStack = new StackTrace().ToString(); + } +#endif + + if (_innerStreamAsWebSocketStream != null) + { + bool thisLockTaken = thisLockTakenSnapshot; + bool sessionHandleLockTaken = sessionHandleLockTakenSnapshot; + + try + { + if (_closeNetworkConnectionTask == null) + { + _closeNetworkConnectionTask = + _innerStreamAsWebSocketStream.CloseNetworkConnectionAsync(cancellationToken); + } + + if (thisLockTaken && sessionHandleLockTaken) + { + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + else if (thisLockTaken) + { + ReleaseLock(_thisLock, ref thisLockTaken); + } + + await _closeNetworkConnectionTask.SuppressContextFlow(); + } + catch (Exception closeNetworkConnectionTaskException) + { + if (!CanHandleExceptionDuringClose(closeNetworkConnectionTaskException)) + { + ThrowIfConvertibleException(Methods.StartOnCloseCompleted, + closeNetworkConnectionTaskException, + cancellationToken, + cancellationToken.IsCancellationRequested); + throw; + } + } + } + + return true; + } + + // MultiThreading: This method has to be called under a thisLock-lock + private void FinishOnCloseCompleted() + { + CleanUp(); + } + + // MultiThreading: ThreadSafe; No-op if already in a terminal state + public override Task CloseAsync(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateCloseStatus(closeStatus, statusDescription); + return CloseAsyncCore(closeStatus, statusDescription, cancellationToken); + } + + private async Task CloseAsyncCore(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + string inputParameter = string.Empty; + /*if (s_LoggingEnabled) + { + inputParameter = string.Format(CultureInfo.InvariantCulture, + "closeStatus: {0}, statusDescription: {1}", + closeStatus, + statusDescription); + Logging.Enter(Logging.WebSockets, this, Methods.CloseAsync, inputParameter); + }*/ + + try + { + ThrowIfPendingException(); + if (IsStateTerminal(State)) + { + return; + } + ThrowIfDisposed(); + + bool lockTaken = false; + Monitor.Enter(_thisLock, ref lockTaken); + bool ownsCloseCancellationTokenSource = false; + CancellationToken linkedCancellationToken = CancellationToken.None; + try + { + ThrowIfPendingException(); + if (IsStateTerminal(State)) + { + return; + } + ThrowIfDisposed(); + ThrowOnInvalidState(State, + WebSocketState.Open, WebSocketState.CloseReceived, WebSocketState.CloseSent); + + Task closeOutputTask; + ownsCloseCancellationTokenSource = _closeOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken); + if (ownsCloseCancellationTokenSource) + { + closeOutputTask = _closeOutputTask; + if (closeOutputTask == null && State != WebSocketState.CloseSent) + { + if (_closeReceivedTaskCompletionSource == null) + { + _closeReceivedTaskCompletionSource = new TaskCompletionSource(); + } + + closeOutputTask = CloseOutputAsync(closeStatus, + statusDescription, + linkedCancellationToken); + } + } + else + { + Contract.Assert(_closeReceivedTaskCompletionSource != null, + "'m_CloseReceivedTaskCompletionSource' MUST NOT be NULL."); + closeOutputTask = _closeReceivedTaskCompletionSource.Task; + } + + if (closeOutputTask != null) + { + ReleaseLock(_thisLock, ref lockTaken); + try + { + await closeOutputTask.SuppressContextFlow(); + } + catch (Exception closeOutputError) + { + Monitor.Enter(_thisLock, ref lockTaken); + + if (!CanHandleExceptionDuringClose(closeOutputError)) + { + ThrowIfConvertibleException(Methods.CloseOutputAsync, + closeOutputError, + cancellationToken, + linkedCancellationToken.IsCancellationRequested); + throw; + } + } + + // When closeOutputTask != null and an exception thrown from await closeOutputTask is handled, + // the lock will be taken in the catch-block. So the logic here avoids taking the lock twice. + if (!lockTaken) + { + Monitor.Enter(_thisLock, ref lockTaken); + } + } + + if (OnCloseOutputCompleted()) + { + bool callCompleteOnCloseCompleted = false; + + try + { + // linkedCancellationToken can be CancellationToken.None if ownsCloseCancellationTokenSource==false + // This is still ok because OnCloseOutputCompleted won't start any IO operation in this case + callCompleteOnCloseCompleted = await StartOnCloseCompleted( + lockTaken, false, linkedCancellationToken).SuppressContextFlow(); + } + catch (Exception) + { + // If an exception is thrown we know that the locks have been released, + // because we enforce IWebSocketStream.CloseNetworkConnectionAsync to yield + ResetFlagAndTakeLock(_thisLock, ref lockTaken); + throw; + } + + if (callCompleteOnCloseCompleted) + { + ResetFlagAndTakeLock(_thisLock, ref lockTaken); + FinishOnCloseCompleted(); + } + } + + if (IsStateTerminal(State)) + { + return; + } + + linkedCancellationToken = CancellationToken.None; + + bool ownsReceiveCancellationTokenSource = _receiveOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken); + if (ownsReceiveCancellationTokenSource) + { + _closeAsyncStartedReceive = true; + ArraySegment closeMessageBuffer = + new ArraySegment(new byte[WebSocketBuffer.MinReceiveBufferSize]); + EnsureReceiveOperation(); + Task receiveAsyncTask = _receiveOperation.Process(closeMessageBuffer, + linkedCancellationToken); + ReleaseLock(_thisLock, ref lockTaken); + + WebSocketReceiveResult receiveResult = null; + try + { + receiveResult = await receiveAsyncTask.SuppressContextFlow(); + } + catch (Exception receiveException) + { + Monitor.Enter(_thisLock, ref lockTaken); + + if (!CanHandleExceptionDuringClose(receiveException)) + { + ThrowIfConvertibleException(Methods.CloseAsync, + receiveException, + cancellationToken, + linkedCancellationToken.IsCancellationRequested); + throw; + } + } + + // receiveResult is NEVER NULL if WebSocketBase.ReceiveOperation.Process completes successfully + // - but in the close code path we handle some exception if another thread was able to tranistion + // the state into Closed successfully. In this case receiveResult can be NULL and it is safe to + // skip the statements in the if-block. + if (receiveResult != null) + { + /*if (s_LoggingEnabled && receiveResult.Count > 0) + { + Logging.Dump(Logging.WebSockets, + this, + Methods.ReceiveAsync, + closeMessageBuffer.Array, + closeMessageBuffer.Offset, + receiveResult.Count); + }*/ + + if (receiveResult.MessageType != WebSocketMessageType.Close) + { + throw new WebSocketException(WebSocketError.InvalidMessageType, + SR.GetString(SR.net_WebSockets_InvalidMessageType, + typeof(WebSocket).Name + "." + Methods.CloseAsync, + typeof(WebSocket).Name + "." + Methods.CloseOutputAsync, + receiveResult.MessageType)); + } + } + } + else + { + _receiveOutstandingOperationHelper.CompleteOperation(ownsReceiveCancellationTokenSource); + ReleaseLock(_thisLock, ref lockTaken); + await _closeReceivedTaskCompletionSource.Task.SuppressContextFlow(); + } + + // When ownsReceiveCancellationTokenSource is true and an exception is thrown, the lock will be taken. + // So this logic here is to avoid taking the lock twice. + if (!lockTaken) + { + Monitor.Enter(_thisLock, ref lockTaken); + } + + if (!IsStateTerminal(State)) + { + bool ownsSendCancellationSource = false; + try + { + // We know that the CloseFrame has been sent at this point. So no Send-operation is allowed anymore and we + // can hijack the m_SendOutstandingOperationHelper to create a linkedCancellationToken + ownsSendCancellationSource = _sendOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken); + Contract.Assert(ownsSendCancellationSource, "'ownsSendCancellationSource' MUST be 'true' at this point."); + + bool callCompleteOnCloseCompleted = false; + + try + { + // linkedCancellationToken can be CancellationToken.None if ownsCloseCancellationTokenSource==false + // This is still ok because OnCloseOutputCompleted won't start any IO operation in this case + callCompleteOnCloseCompleted = await StartOnCloseCompleted( + lockTaken, false, linkedCancellationToken).SuppressContextFlow(); + } + catch (Exception) + { + // If an exception is thrown we know that the locks have been released, + // because we enforce IWebSocketStream.CloseNetworkConnectionAsync to yield + ResetFlagAndTakeLock(_thisLock, ref lockTaken); + throw; + } + + if (callCompleteOnCloseCompleted) + { + ResetFlagAndTakeLock(_thisLock, ref lockTaken); + FinishOnCloseCompleted(); + } + } + finally + { + _sendOutstandingOperationHelper.CompleteOperation(ownsSendCancellationSource); + } + } + } + catch (Exception exception) + { + bool aborted = linkedCancellationToken.IsCancellationRequested; + Abort(); + ThrowIfConvertibleException(Methods.CloseAsync, exception, cancellationToken, aborted); + throw; + } + finally + { + _closeOutstandingOperationHelper.CompleteOperation(ownsCloseCancellationTokenSource); + ReleaseLock(_thisLock, ref lockTaken); + } + } + finally + { + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.CloseAsync, inputParameter); + }*/ + } + } + + // MultiThreading: ThreadSafe; No-op if already in a terminal state + [SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_sendFrameThrottle", + Justification = "SemaphoreSlim.Dispose is not threadsafe and can cause NullRef exceptions on other threads." + + "Also according to the CLR Dev11#358715) there is no need to dispose SemaphoreSlim if the ManualResetEvent " + + "is not used.")] + public override void Dispose() + { + if (_isDisposed) + { + return; + } + + bool thisLockTaken = false; + bool sessionHandleLockTaken = false; + + try + { + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + + if (_isDisposed) + { + return; + } + + if (!IsStateTerminal(State)) + { + Abort(); + } + else + { + CleanUp(); + } + + _isDisposed = true; + } + finally + { + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + } + + private void ResetFlagAndTakeLock(object lockObject, ref bool thisLockTaken) + { + Contract.Assert(lockObject != null, "'lockObject' MUST NOT be NULL."); + thisLockTaken = false; + Monitor.Enter(lockObject, ref thisLockTaken); + } + + private void ResetFlagsAndTakeLocks(ref bool thisLockTaken, ref bool sessionHandleLockTaken) + { + thisLockTaken = false; + sessionHandleLockTaken = false; + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + + private void TakeLocks(ref bool thisLockTaken, ref bool sessionHandleLockTaken) + { + Contract.Assert(_thisLock != null, "'m_ThisLock' MUST NOT be NULL."); + Contract.Assert(SessionHandle != null, "'SessionHandle' MUST NOT be NULL."); + + Monitor.Enter(SessionHandle, ref sessionHandleLockTaken); + Monitor.Enter(_thisLock, ref thisLockTaken); + } + + private void ReleaseLocks(ref bool thisLockTaken, ref bool sessionHandleLockTaken) + { + Contract.Assert(_thisLock != null, "'m_ThisLock' MUST NOT be NULL."); + Contract.Assert(SessionHandle != null, "'SessionHandle' MUST NOT be NULL."); + + if (thisLockTaken || sessionHandleLockTaken) + { +#if NET45 + RuntimeHelpers.PrepareConstrainedRegions(); +#endif + try + { + } + finally + { + if (thisLockTaken) + { + Monitor.Exit(_thisLock); + thisLockTaken = false; + } + + if (sessionHandleLockTaken) + { + Monitor.Exit(SessionHandle); + sessionHandleLockTaken = false; + } + } + } + } + + private void EnsureReceiveOperation() + { + if (_receiveOperation == null) + { + lock (_thisLock) + { + if (_receiveOperation == null) + { + _receiveOperation = new WebSocketOperation.ReceiveOperation(this); + } + } + } + } + + private void EnsureSendOperation() + { + if (_sendOperation == null) + { + lock (_thisLock) + { + if (_sendOperation == null) + { + _sendOperation = new WebSocketOperation.SendOperation(this); + } + } + } + } + + private void EnsureKeepAliveOperation() + { + if (_keepAliveOperation == null) + { + lock (_thisLock) + { + if (_keepAliveOperation == null) + { + WebSocketOperation.SendOperation keepAliveOperation = new WebSocketOperation.SendOperation(this); + keepAliveOperation.BufferType = UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UnsolicitedPong; + _keepAliveOperation = keepAliveOperation; + } + } + } + } + + private void EnsureCloseOutputOperation() + { + if (_closeOutputOperation == null) + { + lock (_thisLock) + { + if (_closeOutputOperation == null) + { + _closeOutputOperation = new WebSocketOperation.CloseOutputOperation(this); + } + } + } + } + + private static void ReleaseLock(object lockObject, ref bool lockTaken) + { + Contract.Assert(lockObject != null, "'lockObject' MUST NOT be NULL."); + if (lockTaken) + { +#if NET45 + RuntimeHelpers.PrepareConstrainedRegions(); +#endif + try + { + } + finally + { + Monitor.Exit(lockObject); + lockTaken = false; + } + } + } + + private static UnsafeNativeMethods.WebSocketProtocolComponent.BufferType GetBufferType(WebSocketMessageType messageType, + bool endOfMessage) + { + Contract.Assert(messageType == WebSocketMessageType.Binary || messageType == WebSocketMessageType.Text, + string.Format(CultureInfo.InvariantCulture, + "The value of 'messageType' ({0}) is invalid. Valid message types: '{1}, {2}'", + messageType, + WebSocketMessageType.Binary, + WebSocketMessageType.Text)); + + if (messageType == WebSocketMessageType.Text) + { + if (endOfMessage) + { + return UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message; + } + + return UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment; + } + else + { + if (endOfMessage) + { + return UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage; + } + + return UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment; + } + } + + private static WebSocketMessageType GetMessageType(UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType) + { + switch (bufferType) + { + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close: + return WebSocketMessageType.Close; + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage: + return WebSocketMessageType.Binary; + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message: + return WebSocketMessageType.Text; + default: + // This indicates a contract violation of the websocket protocol component, + // because we currently don't support any WebSocket extensions and would + // not accept a Websocket handshake requesting extensions + Contract.Assert(false, + string.Format(CultureInfo.InvariantCulture, + "The value of 'bufferType' ({0}) is invalid. Valid buffer types: {1}, {2}, {3}, {4}, {5}.", + bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message)); + + throw new WebSocketException(WebSocketError.NativeError, + SR.GetString(SR.net_WebSockets_InvalidBufferType, + bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message)); + } + } + + internal void ValidateNativeBuffers(UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers, + uint dataBufferCount) + { + _internalBuffer.ValidateNativeBuffers(action, bufferType, dataBuffers, dataBufferCount); + } + + internal void ThrowIfClosedOrAborted() + { + if (State == WebSocketState.Closed || State == WebSocketState.Aborted) + { + throw new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, GetType().FullName, State)); + } + } + + private void ThrowIfAborted(bool aborted, Exception innerException) + { + if (aborted) + { + throw new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, GetType().FullName, WebSocketState.Aborted), + innerException); + } + } + + private bool CanHandleExceptionDuringClose(Exception error) + { + Contract.Assert(error != null, "'error' MUST NOT be NULL."); + + if (State != WebSocketState.Closed) + { + return false; + } + + return error is OperationCanceledException || + error is WebSocketException || + // error is SocketException || + // error is HttpListenerException || + error is IOException; + } + + // We only want to throw an OperationCanceledException if the CancellationToken passed + // down from the caller is canceled - not when Abort is called on another thread and + // the linkedCancellationToken is canceled. + private void ThrowIfConvertibleException(string methodName, + Exception exception, + CancellationToken cancellationToken, + bool aborted) + { + Contract.Assert(exception != null, "'exception' MUST NOT be NULL."); + /* + if (s_LoggingEnabled && !string.IsNullOrEmpty(methodName)) + { + Logging.Exception(Logging.WebSockets, this, methodName, exception); + }*/ + + OperationCanceledException operationCanceledException = exception as OperationCanceledException; + if (operationCanceledException != null) + { + if (cancellationToken.IsCancellationRequested || + !aborted) + { + return; + } + ThrowIfAborted(aborted, exception); + } + + WebSocketException convertedException = exception as WebSocketException; + if (convertedException != null) + { + cancellationToken.ThrowIfCancellationRequested(); + ThrowIfAborted(aborted, convertedException); + return; + } + /* + SocketException socketException = exception as SocketException; + if (socketException != null) + { + convertedException = new WebSocketException(socketException.NativeErrorCode, socketException); + } + HttpListenerException httpListenerException = exception as HttpListenerException; + if (httpListenerException != null) + { + convertedException = new WebSocketException(httpListenerException.ErrorCode, httpListenerException); + } + + IOException ioException = exception as IOException; + if (ioException != null) + { + socketException = exception.InnerException as SocketException; + if (socketException != null) + { + convertedException = new WebSocketException(socketException.NativeErrorCode, ioException); + } + } +*/ + if (convertedException != null) + { + cancellationToken.ThrowIfCancellationRequested(); + ThrowIfAborted(aborted, convertedException); + throw convertedException; + } + + AggregateException aggregateException = exception as AggregateException; + if (aggregateException != null) + { + // Collapse possibly nested graph into a flat list. + // Empty inner exception list is unlikely but possible via public api. + ReadOnlyCollection unwrappedExceptions = aggregateException.Flatten().InnerExceptions; + if (unwrappedExceptions.Count == 0) + { + return; + } + + foreach (Exception unwrappedException in unwrappedExceptions) + { + ThrowIfConvertibleException(null, unwrappedException, cancellationToken, aborted); + } + } + } + + private void CleanUp() + { + // Multithreading: This method is always called under the m_ThisLock lock + if (_cleanedUp) + { + return; + } + + _cleanedUp = true; + + if (SessionHandle != null) + { + SessionHandle.Dispose(); + } + + if (_internalBuffer != null) + { + _internalBuffer.Dispose(this.State); + } + + if (_receiveOutstandingOperationHelper != null) + { + _receiveOutstandingOperationHelper.Dispose(); + } + + if (_sendOutstandingOperationHelper != null) + { + _sendOutstandingOperationHelper.Dispose(); + } + + if (_closeOutputOutstandingOperationHelper != null) + { + _closeOutputOutstandingOperationHelper.Dispose(); + } + + if (_closeOutstandingOperationHelper != null) + { + _closeOutstandingOperationHelper.Dispose(); + } + + if (_innerStream != null) + { + try + { + _innerStream.Dispose(); + } + catch (ObjectDisposedException) + { + } + catch (IOException) + { + } + /*catch (SocketException) + { + }*/ + catch (Exception) + { + } + } + + _keepAliveTracker.Dispose(); + } + + private void OnBackgroundTaskException(Exception exception) + { + if (Interlocked.CompareExchange(ref _pendingException, exception, null) == null) + { + /*if (s_LoggingEnabled) + { + Logging.Exception(Logging.WebSockets, this, Methods.Fault, exception); + }*/ + Abort(); + } + } + + private void ThrowIfPendingException() + { + Exception pendingException = Interlocked.Exchange(ref _pendingException, null); + if (pendingException != null) + { + throw new WebSocketException(WebSocketError.Faulted, pendingException); + } + } + + private void ThrowIfDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + } + + private void UpdateReceiveState(int newReceiveState, int expectedReceiveState) + { + int receiveState; + if ((receiveState = Interlocked.Exchange(ref _receiveState, newReceiveState)) != expectedReceiveState) + { + Contract.Assert(false, + string.Format(CultureInfo.InvariantCulture, + "'m_ReceiveState' had an invalid value '{0}'. The expected value was '{1}'.", + receiveState, + expectedReceiveState)); + } + } + + private bool StartOnCloseReceived(ref bool thisLockTaken) + { + ThrowIfDisposed(); + + if (IsStateTerminal(State) || State == WebSocketState.CloseReceived) + { + return false; + } + + Monitor.Enter(_thisLock, ref thisLockTaken); + if (IsStateTerminal(State) || State == WebSocketState.CloseReceived) + { + return false; + } + + if (State == WebSocketState.Open) + { + _state = WebSocketState.CloseReceived; + + if (_closeReceivedTaskCompletionSource == null) + { + _closeReceivedTaskCompletionSource = new TaskCompletionSource(); + } + + return false; + } + + return true; + } + + private void FinishOnCloseReceived(WebSocketCloseStatus closeStatus, + string closeStatusDescription) + { + if (_closeReceivedTaskCompletionSource != null) + { + _closeReceivedTaskCompletionSource.TrySetResult(null); + } + + _closeStatus = closeStatus; + _closeStatusDescription = closeStatusDescription; + /* + if (s_LoggingEnabled) + { + string parameters = string.Format(CultureInfo.InvariantCulture, + "closeStatus: {0}, closeStatusDescription: {1}, m_State: {2}", + closeStatus, closeStatusDescription, m_State); + + Logging.PrintInfo(Logging.WebSockets, this, Methods.FinishOnCloseReceived, parameters); + }*/ + } + + private async static void OnKeepAlive(object sender) + { + Contract.Assert(sender != null, "'sender' MUST NOT be NULL."); + Contract.Assert((sender as WebSocketBase) != null, "'sender as WebSocketBase' MUST NOT be NULL."); + + WebSocketBase thisPtr = sender as WebSocketBase; + bool lockTaken = false; + /* + if (s_LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, thisPtr, Methods.OnKeepAlive, string.Empty); + }*/ + + CancellationToken linkedCancellationToken = CancellationToken.None; + try + { + Monitor.Enter(thisPtr.SessionHandle, ref lockTaken); + + if (thisPtr._isDisposed || + thisPtr._state != WebSocketState.Open || + thisPtr._closeOutputTask != null) + { + return; + } + + if (thisPtr._keepAliveTracker.ShouldSendKeepAlive()) + { + bool ownsCancellationTokenSource = false; + try + { + ownsCancellationTokenSource = thisPtr._sendOutstandingOperationHelper.TryStartOperation(CancellationToken.None, out linkedCancellationToken); + if (ownsCancellationTokenSource) + { + thisPtr.EnsureKeepAliveOperation(); + thisPtr._keepAliveTask = thisPtr._keepAliveOperation.Process(null, linkedCancellationToken); + ReleaseLock(thisPtr.SessionHandle, ref lockTaken); + await thisPtr._keepAliveTask.SuppressContextFlow(); + } + } + finally + { + if (!lockTaken) + { + Monitor.Enter(thisPtr.SessionHandle, ref lockTaken); + } + thisPtr._sendOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + thisPtr._keepAliveTask = null; + } + + thisPtr._keepAliveTracker.ResetTimer(); + } + } + catch (Exception exception) + { + try + { + thisPtr.ThrowIfConvertibleException(Methods.OnKeepAlive, + exception, + CancellationToken.None, + linkedCancellationToken.IsCancellationRequested); + throw; + } + catch (Exception backgroundException) + { + thisPtr.OnBackgroundTaskException(backgroundException); + } + } + finally + { + ReleaseLock(thisPtr.SessionHandle, ref lockTaken); + /* + if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, thisPtr, Methods.OnKeepAlive, string.Empty); + }*/ + } + } + + private abstract class WebSocketOperation + { + private readonly WebSocketBase _webSocket; + + internal WebSocketOperation(WebSocketBase webSocket) + { + Contract.Assert(webSocket != null, "'webSocket' MUST NOT be NULL."); + _webSocket = webSocket; + } + + public WebSocketReceiveResult ReceiveResult { get; protected set; } + protected abstract int BufferCount { get; } + protected abstract UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue ActionQueue { get; } + protected abstract void Initialize(ArraySegment? buffer, CancellationToken cancellationToken); + protected abstract bool ShouldContinue(CancellationToken cancellationToken); + + // Multi-Threading: This method has to be called under a SessionHandle-lock. It returns true if a + // close frame was received. Handling the received close frame might involve IO - to make the locking + // strategy easier and reduce one level in the await-hierarchy the IO is kicked off by the caller. + protected abstract bool ProcessAction_NoAction(); + + protected virtual void ProcessAction_IndicateReceiveComplete( + ArraySegment? buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers, + uint dataBufferCount, + IntPtr actionContext) + { + throw new NotImplementedException(); + } + + protected abstract void Cleanup(); + + internal async Task Process(ArraySegment? buffer, + CancellationToken cancellationToken) + { + Contract.Assert(BufferCount >= 1 && BufferCount <= 2, "'bufferCount' MUST ONLY BE '1' or '2'."); + + bool sessionHandleLockTaken = false; + ReceiveResult = null; + try + { + Monitor.Enter(_webSocket.SessionHandle, ref sessionHandleLockTaken); + _webSocket.ThrowIfPendingException(); + Initialize(buffer, cancellationToken); + + while (ShouldContinue(cancellationToken)) + { + UnsafeNativeMethods.WebSocketProtocolComponent.Action action; + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType; + + bool completed = false; + while (!completed) + { + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers = + new UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[BufferCount]; + uint dataBufferCount = (uint)BufferCount; + IntPtr actionContext; + + _webSocket.ThrowIfDisposed(); + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketGetAction(_webSocket, + ActionQueue, + dataBuffers, + ref dataBufferCount, + out action, + out bufferType, + out actionContext); + + switch (action) + { + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.NoAction: + if (ProcessAction_NoAction()) + { + // A close frame was received + + Contract.Assert(ReceiveResult.Count == 0, "'receiveResult.Count' MUST be 0."); + Contract.Assert(ReceiveResult.CloseStatus != null, "'receiveResult.CloseStatus' MUST NOT be NULL for message type 'Close'."); + bool thisLockTaken = false; + try + { + if (_webSocket.StartOnCloseReceived(ref thisLockTaken)) + { + // If StartOnCloseReceived returns true the WebSocket close handshake has been completed + // so there is no need to retake the SessionHandle-lock. + // m_ThisLock lock is guaranteed to be taken by StartOnCloseReceived when returning true + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + bool callCompleteOnCloseCompleted = false; + + try + { + callCompleteOnCloseCompleted = await _webSocket.StartOnCloseCompleted( + thisLockTaken, sessionHandleLockTaken, cancellationToken).SuppressContextFlow(); + } + catch (Exception) + { + // If an exception is thrown we know that the locks have been released, + // because we enforce IWebSocketStream.CloseNetworkConnectionAsync to yield + _webSocket.ResetFlagAndTakeLock(_webSocket._thisLock, ref thisLockTaken); + throw; + } + + if (callCompleteOnCloseCompleted) + { + _webSocket.ResetFlagAndTakeLock(_webSocket._thisLock, ref thisLockTaken); + _webSocket.FinishOnCloseCompleted(); + } + } + _webSocket.FinishOnCloseReceived(ReceiveResult.CloseStatus.Value, ReceiveResult.CloseStatusDescription); + } + finally + { + if (thisLockTaken) + { + ReleaseLock(_webSocket._thisLock, ref thisLockTaken); + } + } + } + completed = true; + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateReceiveComplete: + ProcessAction_IndicateReceiveComplete(buffer, + bufferType, + action, + dataBuffers, + dataBufferCount, + actionContext); + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.ReceiveFromNetwork: + int count = 0; + try + { + ArraySegment payload = _webSocket._internalBuffer.ConvertNativeBuffer(action, dataBuffers[0], bufferType); + + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + WebSocketHelpers.ThrowIfConnectionAborted(_webSocket._innerStream, true); + try + { + Task readTask = _webSocket._innerStream.ReadAsync(payload.Array, + payload.Offset, + payload.Count, + cancellationToken); + count = await readTask.SuppressContextFlow(); + _webSocket._keepAliveTracker.OnDataReceived(); + } + catch (ObjectDisposedException objectDisposedException) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, objectDisposedException); + } + catch (NotSupportedException notSupportedException) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, notSupportedException); + } + Monitor.Enter(_webSocket.SessionHandle, ref sessionHandleLockTaken); + _webSocket.ThrowIfPendingException(); + // If the client unexpectedly closed the socket we throw an exception as we didn't get any close message + if (count == 0) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); + } + } + finally + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, + actionContext, + count); + } + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateSendComplete: + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, actionContext, 0); + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + await _webSocket._innerStream.FlushAsync().SuppressContextFlow(); + Monitor.Enter(_webSocket.SessionHandle, ref sessionHandleLockTaken); + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.SendToNetwork: + int bytesSent = 0; + try + { + if (_webSocket.State != WebSocketState.CloseSent || + (bufferType != UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.PingPong && + bufferType != UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UnsolicitedPong)) + { + if (dataBufferCount == 0) + { + break; + } + + List> sendBuffers = new List>((int)dataBufferCount); + int sendBufferSize = 0; + ArraySegment framingBuffer = _webSocket._internalBuffer.ConvertNativeBuffer(action, dataBuffers[0], bufferType); + sendBuffers.Add(framingBuffer); + sendBufferSize += framingBuffer.Count; + + // There can be at most 2 dataBuffers + // - one for the framing header and one for the payload + if (dataBufferCount == 2) + { + ArraySegment payload = _webSocket._internalBuffer.ConvertPinnedSendPayloadFromNative(dataBuffers[1], bufferType); + sendBuffers.Add(payload); + sendBufferSize += payload.Count; + } + + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + WebSocketHelpers.ThrowIfConnectionAborted(_webSocket._innerStream, false); + await _webSocket.SendFrameAsync(sendBuffers, cancellationToken).SuppressContextFlow(); + Monitor.Enter(_webSocket.SessionHandle, ref sessionHandleLockTaken); + _webSocket.ThrowIfPendingException(); + bytesSent += sendBufferSize; + _webSocket._keepAliveTracker.OnDataSent(); + } + } + finally + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, + actionContext, + bytesSent); + } + + break; + default: + string assertMessage = string.Format(CultureInfo.InvariantCulture, + "Invalid action '{0}' returned from WebSocketGetAction.", + action); + Contract.Assert(false, assertMessage); + throw new InvalidOperationException(); + } + } + } + } + finally + { + Cleanup(); + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + } + + return ReceiveResult; + } + + public class ReceiveOperation : WebSocketOperation + { + private int _receiveState; + private bool _pongReceived; + private bool _receiveCompleted; + + public ReceiveOperation(WebSocketBase webSocket) + : base(webSocket) + { + } + + protected override UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue ActionQueue + { + get { return UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue.Receive; } + } + + protected override int BufferCount + { + get { return 1; } + } + + protected override void Initialize(ArraySegment? buffer, CancellationToken cancellationToken) + { + Contract.Assert(buffer != null, "'buffer' MUST NOT be NULL."); + _pongReceived = false; + _receiveCompleted = false; + _webSocket.ThrowIfDisposed(); + + int originalReceiveState = Interlocked.CompareExchange(ref _webSocket._receiveState, + ReceiveState.Application, ReceiveState.Idle); + + switch (originalReceiveState) + { + case ReceiveState.Idle: + _receiveState = ReceiveState.Application; + break; + case ReceiveState.Application: + Contract.Assert(false, "'originalReceiveState' MUST NEVER be ReceiveState.Application at this point."); + break; + case ReceiveState.PayloadAvailable: + WebSocketReceiveResult receiveResult; + if (!_webSocket._internalBuffer.ReceiveFromBufferedPayload(buffer.Value, out receiveResult)) + { + _webSocket.UpdateReceiveState(ReceiveState.Idle, ReceiveState.PayloadAvailable); + } + ReceiveResult = receiveResult; + _receiveCompleted = true; + break; + default: + Contract.Assert(false, + string.Format(CultureInfo.InvariantCulture, "Invalid ReceiveState '{0}'.", originalReceiveState)); + break; + } + } + + protected override void Cleanup() + { + } + + protected override bool ShouldContinue(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (_receiveCompleted) + { + return false; + } + + _webSocket.ThrowIfDisposed(); + _webSocket.ThrowIfPendingException(); + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketReceive(_webSocket); + + return true; + } + + protected override bool ProcessAction_NoAction() + { + if (_pongReceived) + { + _receiveCompleted = false; + _pongReceived = false; + return false; + } + + Contract.Assert(ReceiveResult != null, + "'ReceiveResult' MUST NOT be NULL."); + _receiveCompleted = true; + + if (ReceiveResult.MessageType == WebSocketMessageType.Close) + { + return true; + } + + return false; + } + + protected override void ProcessAction_IndicateReceiveComplete( + ArraySegment? buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers, + uint dataBufferCount, + IntPtr actionContext) + { + Contract.Assert(buffer != null, "'buffer MUST NOT be NULL."); + + int bytesTransferred = 0; + _pongReceived = false; + + if (bufferType == UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.PingPong) + { + // ignoring received pong frame + _pongReceived = true; + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, + actionContext, + bytesTransferred); + return; + } + + WebSocketReceiveResult receiveResult; + try + { + ArraySegment payload; + WebSocketMessageType messageType = GetMessageType(bufferType); + int newReceiveState = ReceiveState.Idle; + + if (bufferType == UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close) + { + payload = WebSocketHelpers.EmptyPayload; + string reason; + WebSocketCloseStatus closeStatus; + _webSocket._internalBuffer.ConvertCloseBuffer(action, dataBuffers[0], out closeStatus, out reason); + + receiveResult = new WebSocketReceiveResult(bytesTransferred, + messageType, true, closeStatus, reason); + } + else + { + payload = _webSocket._internalBuffer.ConvertNativeBuffer(action, dataBuffers[0], bufferType); + + bool endOfMessage = bufferType == + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage || + bufferType == UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message || + bufferType == UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close; + + if (payload.Count > buffer.Value.Count) + { + _webSocket._internalBuffer.BufferPayload(payload, buffer.Value.Count, messageType, endOfMessage); + newReceiveState = ReceiveState.PayloadAvailable; + endOfMessage = false; + } + + bytesTransferred = Math.Min(payload.Count, (int)buffer.Value.Count); + if (bytesTransferred > 0) + { + Buffer.BlockCopy(payload.Array, + payload.Offset, + buffer.Value.Array, + buffer.Value.Offset, + bytesTransferred); + } + + receiveResult = new WebSocketReceiveResult(bytesTransferred, messageType, endOfMessage); + } + + _webSocket.UpdateReceiveState(newReceiveState, _receiveState); + } + finally + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, + actionContext, + bytesTransferred); + } + + ReceiveResult = receiveResult; + } + } + + public class SendOperation : WebSocketOperation + { + private bool _completed; + protected bool _bufferHasBeenPinned; + + public SendOperation(WebSocketBase webSocket) + : base(webSocket) + { + } + + protected override UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue ActionQueue + { + get { return UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue.Send; } + } + + protected override int BufferCount + { + get { return 2; } + } + + protected virtual UnsafeNativeMethods.WebSocketProtocolComponent.Buffer? CreateBuffer(ArraySegment? buffer) + { + if (buffer == null) + { + return null; + } + + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer payloadBuffer; + payloadBuffer = new UnsafeNativeMethods.WebSocketProtocolComponent.Buffer(); + _webSocket._internalBuffer.PinSendBuffer(buffer.Value, out _bufferHasBeenPinned); + payloadBuffer.Data.BufferData = _webSocket._internalBuffer.ConvertPinnedSendPayloadToNative(buffer.Value); + payloadBuffer.Data.BufferLength = (uint)buffer.Value.Count; + return payloadBuffer; + } + + protected override bool ProcessAction_NoAction() + { + _completed = true; + return false; + } + + protected override void Cleanup() + { + if (_bufferHasBeenPinned) + { + _bufferHasBeenPinned = false; + _webSocket._internalBuffer.ReleasePinnedSendBuffer(); + } + } + + internal UnsafeNativeMethods.WebSocketProtocolComponent.BufferType BufferType { get; set; } + + protected override void Initialize(ArraySegment? buffer, + CancellationToken cancellationToken) + { + Contract.Assert(!_bufferHasBeenPinned, "'m_BufferHasBeenPinned' MUST NOT be pinned at this point."); + _webSocket.ThrowIfDisposed(); + _webSocket.ThrowIfPendingException(); + _completed = false; + + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer? payloadBuffer = CreateBuffer(buffer); + if (payloadBuffer != null) + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketSend(_webSocket, BufferType, payloadBuffer.Value); + } + else + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketSendWithoutBody(_webSocket, BufferType); + } + } + + protected override bool ShouldContinue(CancellationToken cancellationToken) + { + Contract.Assert(ReceiveResult == null, "'ReceiveResult' MUST be NULL."); + if (_completed) + { + return false; + } + + cancellationToken.ThrowIfCancellationRequested(); + return true; + } + } + + public class CloseOutputOperation : SendOperation + { + public CloseOutputOperation(WebSocketBase webSocket) + : base(webSocket) + { + BufferType = UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close; + } + + internal WebSocketCloseStatus CloseStatus { get; set; } + internal string CloseReason { get; set; } + + protected override UnsafeNativeMethods.WebSocketProtocolComponent.Buffer? CreateBuffer(ArraySegment? buffer) + { + Contract.Assert(buffer == null, "'buffer' MUST BE NULL."); + _webSocket.ThrowIfDisposed(); + _webSocket.ThrowIfPendingException(); + + if (CloseStatus == WebSocketCloseStatus.Empty) + { + return null; + } + + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer payloadBuffer = new UnsafeNativeMethods.WebSocketProtocolComponent.Buffer(); + if (CloseReason != null) + { + byte[] blob = UTF8Encoding.UTF8.GetBytes(CloseReason); + Contract.Assert(blob.Length <= WebSocketHelpers.MaxControlFramePayloadLength, + "The close reason is too long."); + ArraySegment closeBuffer = new ArraySegment(blob, 0, Math.Min(WebSocketHelpers.MaxControlFramePayloadLength, blob.Length)); + _webSocket._internalBuffer.PinSendBuffer(closeBuffer, out _bufferHasBeenPinned); + payloadBuffer.CloseStatus.ReasonData = _webSocket._internalBuffer.ConvertPinnedSendPayloadToNative(closeBuffer); + payloadBuffer.CloseStatus.ReasonLength = (uint)closeBuffer.Count; + } + + payloadBuffer.CloseStatus.CloseStatus = (ushort)CloseStatus; + return payloadBuffer; + } + } + } + + private abstract class KeepAliveTracker : IDisposable + { + // Multi-Threading: only one thread at a time is allowed to call OnDataReceived or OnDataSent + // - but both methods can be called from different threads at the same time. + public abstract void OnDataReceived(); + public abstract void OnDataSent(); + public abstract void Dispose(); + public abstract void StartTimer(WebSocketBase webSocket); + public abstract void ResetTimer(); + public abstract bool ShouldSendKeepAlive(); + + public static KeepAliveTracker Create(TimeSpan keepAliveInterval) + { + if ((int)keepAliveInterval.TotalMilliseconds > 0) + { + return new DefaultKeepAliveTracker(keepAliveInterval); + } + + return new DisabledKeepAliveTracker(); + } + + private class DisabledKeepAliveTracker : KeepAliveTracker + { + public override void OnDataReceived() + { + } + + public override void OnDataSent() + { + } + + public override void ResetTimer() + { + } + + public override void StartTimer(WebSocketBase webSocket) + { + } + + public override bool ShouldSendKeepAlive() + { + return false; + } + + public override void Dispose() + { + } + } + + private class DefaultKeepAliveTracker : KeepAliveTracker + { + private static readonly TimerCallback _keepAliveTimerElapsedCallback = new TimerCallback(OnKeepAlive); + private readonly TimeSpan _keepAliveInterval; + private readonly Stopwatch _lastSendActivity; + private readonly Stopwatch _lastReceiveActivity; + private Timer _keepAliveTimer; + + public DefaultKeepAliveTracker(TimeSpan keepAliveInterval) + { + _keepAliveInterval = keepAliveInterval; + _lastSendActivity = new Stopwatch(); + _lastReceiveActivity = new Stopwatch(); + } + + public override void OnDataReceived() + { + _lastReceiveActivity.Restart(); + } + + public override void OnDataSent() + { + _lastSendActivity.Restart(); + } + + public override void ResetTimer() + { + ResetTimer((int)_keepAliveInterval.TotalMilliseconds); + } + + public override void StartTimer(WebSocketBase webSocket) + { + Contract.Assert(webSocket != null, "'webSocket' MUST NOT be NULL."); + Contract.Assert(webSocket._keepAliveTracker != null, + "'webSocket.m_KeepAliveTracker' MUST NOT be NULL at this point."); + int keepAliveIntervalMilliseconds = (int)_keepAliveInterval.TotalMilliseconds; + Contract.Assert(keepAliveIntervalMilliseconds > 0, "'keepAliveIntervalMilliseconds' MUST be POSITIVE."); +#if NET45 + if (ExecutionContext.IsFlowSuppressed()) + { + _keepAliveTimer = new Timer(_keepAliveTimerElapsedCallback, webSocket, keepAliveIntervalMilliseconds, Timeout.Infinite); + } + else + { + using (ExecutionContext.SuppressFlow()) + { + _keepAliveTimer = new Timer(_keepAliveTimerElapsedCallback, webSocket, keepAliveIntervalMilliseconds, Timeout.Infinite); + } + } +#else + _keepAliveTimer = new Timer(_keepAliveTimerElapsedCallback, webSocket, keepAliveIntervalMilliseconds, Timeout.Infinite); +#endif + } + + public override bool ShouldSendKeepAlive() + { + TimeSpan idleTime = GetIdleTime(); + if (idleTime >= _keepAliveInterval) + { + return true; + } + + ResetTimer((int)(_keepAliveInterval - idleTime).TotalMilliseconds); + return false; + } + + public override void Dispose() + { + _keepAliveTimer.Dispose(); + } + + private void ResetTimer(int dueInMilliseconds) + { + _keepAliveTimer.Change(dueInMilliseconds, Timeout.Infinite); + } + + private TimeSpan GetIdleTime() + { + TimeSpan sinceLastSendActivity = GetTimeElapsed(_lastSendActivity); + TimeSpan sinceLastReceiveActivity = GetTimeElapsed(_lastReceiveActivity); + + if (sinceLastReceiveActivity < sinceLastSendActivity) + { + return sinceLastReceiveActivity; + } + + return sinceLastSendActivity; + } + + private TimeSpan GetTimeElapsed(Stopwatch watch) + { + if (watch.IsRunning) + { + return watch.Elapsed; + } + + return _keepAliveInterval; + } + } + } + + private class OutstandingOperationHelper : IDisposable + { + private volatile int _operationsOutstanding; + private volatile CancellationTokenSource _cancellationTokenSource; + private volatile bool _isDisposed; + private readonly object _thisLock = new object(); + + public bool TryStartOperation(CancellationToken userCancellationToken, out CancellationToken linkedCancellationToken) + { + linkedCancellationToken = CancellationToken.None; + ThrowIfDisposed(); + + lock (_thisLock) + { + int operationsOutstanding = ++_operationsOutstanding; + + if (operationsOutstanding == 1) + { + linkedCancellationToken = CreateLinkedCancellationToken(userCancellationToken); + return true; + } + + Contract.Assert(operationsOutstanding >= 1, "'operationsOutstanding' must never be smaller than 1."); + return false; + } + } + + public void CompleteOperation(bool ownsCancellationTokenSource) + { + if (_isDisposed) + { + // no-op if the WebSocket is already aborted + return; + } + + CancellationTokenSource snapshot = null; + + lock (_thisLock) + { + --_operationsOutstanding; + Contract.Assert(_operationsOutstanding >= 0, "'m_OperationsOutstanding' must never be smaller than 0."); + + if (ownsCancellationTokenSource) + { + snapshot = _cancellationTokenSource; + _cancellationTokenSource = null; + } + } + + if (snapshot != null) + { + snapshot.Dispose(); + } + } + + // Has to be called under m_ThisLock lock + private CancellationToken CreateLinkedCancellationToken(CancellationToken cancellationToken) + { + CancellationTokenSource linkedCancellationTokenSource; + + if (cancellationToken == CancellationToken.None) + { + linkedCancellationTokenSource = new CancellationTokenSource(); + } + else + { + linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, + new CancellationTokenSource().Token); + } + + Contract.Assert(_cancellationTokenSource == null, "'m_CancellationTokenSource' MUST be NULL."); + _cancellationTokenSource = linkedCancellationTokenSource; + + return linkedCancellationTokenSource.Token; + } + + public void CancelIO() + { + CancellationTokenSource cancellationTokenSourceSnapshot = null; + + lock (_thisLock) + { + if (_operationsOutstanding == 0) + { + return; + } + + cancellationTokenSourceSnapshot = _cancellationTokenSource; + } + + if (cancellationTokenSourceSnapshot != null) + { + try + { + cancellationTokenSourceSnapshot.Cancel(); + } + catch (ObjectDisposedException) + { + // Simply ignore this exception - There is apparently a rare race condition + // where the cancellationTokensource is disposed before the Cancel method call completed. + } + } + } + + public void Dispose() + { + if (_isDisposed) + { + return; + } + + CancellationTokenSource snapshot = null; + lock (_thisLock) + { + if (_isDisposed) + { + return; + } + + _isDisposed = true; + snapshot = _cancellationTokenSource; + _cancellationTokenSource = null; + } + + if (snapshot != null) + { + snapshot.Dispose(); + } + } + + private void ThrowIfDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + } + } + + internal interface IWebSocketStream + { + // Switching to opaque mode will change the behavior to use the knowledge that the WebSocketBase class + // is pinning all payloads already and that we will have at most one outstanding send and receive at any + // given time. This allows us to avoid creation of OverlappedData and pinning for each operation. + + void SwitchToOpaqueMode(WebSocketBase webSocket); + void Abort(); + bool SupportsMultipleWrite { get; } + Task MultipleWriteAsync(IList> buffers, CancellationToken cancellationToken); + + // Any implementation has to guarantee that no exception is thrown synchronously + // for example by enforcing a Task.Yield at the beginning of the method + // This is necessary to enforce an API contract (for WebSocketBase.StartOnCloseCompleted) that ensures + // that all locks have been released whenever an exception is thrown from it. + Task CloseNetworkConnectionAsync(CancellationToken cancellationToken); + } + + private static class ReceiveState + { + internal const int SendOperation = -1; + internal const int Idle = 0; + internal const int Application = 1; + internal const int PayloadAvailable = 2; + } + + internal static class Methods + { + internal const string ReceiveAsync = "ReceiveAsync"; + internal const string SendAsync = "SendAsync"; + internal const string CloseAsync = "CloseAsync"; + internal const string CloseOutputAsync = "CloseOutputAsync"; + internal const string Abort = "Abort"; + internal const string Initialize = "Initialize"; + internal const string Fault = "Fault"; + internal const string StartOnCloseCompleted = "StartOnCloseCompleted"; + internal const string FinishOnCloseReceived = "FinishOnCloseReceived"; + internal const string OnKeepAlive = "OnKeepAlive"; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketBuffer.cs b/src/Microsoft.AspNet.WebSockets/WebSocketBuffer.cs new file mode 100644 index 0000000000..56b24a4fed --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketBuffer.cs @@ -0,0 +1,698 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; + +namespace Microsoft.AspNet.WebSockets +{ + // This class helps to abstract the internal WebSocket buffer, which is used to interact with the native WebSocket + // protocol component (WSPC). It helps to shield the details of the layout and the involved pointer arithmetic. + // The internal WebSocket buffer also contains a segment, which is used by the WebSocketBase class to buffer + // payload (parsed by WSPC already) for the application, if the application requested fewer bytes than the + // WSPC returned. The internal buffer is pinned for the whole lifetime if this class. + // LAYOUT: + // | Native buffer | PayloadReceiveBuffer | PropertyBuffer | + // | RBS + SBS + 144 | RBS | PBS | + // | Only WSPC may modify | Only WebSocketBase may modify | + // + // *RBS = ReceiveBufferSize, *SBS = SendBufferSize + // *PBS = PropertyBufferSize (32-bit: 16, 64 bit: 20 bytes) + internal class WebSocketBuffer : IDisposable + { + private const int NativeOverheadBufferSize = 144; + internal const int MinSendBufferSize = 16; + internal const int MinReceiveBufferSize = 256; + internal const int MaxBufferSize = 64 * 1024; +#if NET45 + private static readonly int SizeOfUInt = Marshal.SizeOf(typeof(uint)); + private static readonly int SizeOfBool = Marshal.SizeOf(typeof(bool)); +#else + private static readonly int SizeOfUInt = Marshal.SizeOf(); + private static readonly int SizeOfBool = Marshal.SizeOf(); +#endif + private static readonly int PropertyBufferSize = (2 * SizeOfUInt) + SizeOfBool + IntPtr.Size; + + private readonly int _ReceiveBufferSize; + + // Indicates the range of the pinned byte[] that can be used by the WSPC (nativeBuffer + pinnedSendBuffer) + private readonly long _StartAddress; + private readonly long _EndAddress; + private readonly GCHandle _GCHandle; + private readonly ArraySegment _InternalBuffer; + private readonly ArraySegment _NativeBuffer; + private readonly ArraySegment _PayloadBuffer; + private readonly ArraySegment _PropertyBuffer; + private readonly int _SendBufferSize; + private volatile int _PayloadOffset; + private volatile WebSocketReceiveResult _BufferedPayloadReceiveResult; + private long _PinnedSendBufferStartAddress; + private long _PinnedSendBufferEndAddress; + private ArraySegment _PinnedSendBuffer; + private GCHandle _PinnedSendBufferHandle; + private int _StateWhenDisposing = int.MinValue; + private int _SendBufferState; + + private WebSocketBuffer(ArraySegment internalBuffer, int receiveBufferSize, int sendBufferSize) + { + Contract.Assert(internalBuffer.Array != null, "'internalBuffer' MUST NOT be NULL."); + Contract.Assert(receiveBufferSize >= MinReceiveBufferSize, + "'receiveBufferSize' MUST be at least " + MinReceiveBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize >= MinSendBufferSize, + "'sendBufferSize' MUST be at least " + MinSendBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(receiveBufferSize <= MaxBufferSize, + "'receiveBufferSize' MUST NOT exceed " + MaxBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize <= MaxBufferSize, + "'sendBufferSize' MUST NOT exceed " + MaxBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + _ReceiveBufferSize = receiveBufferSize; + _SendBufferSize = sendBufferSize; + _InternalBuffer = internalBuffer; + _GCHandle = GCHandle.Alloc(internalBuffer.Array, GCHandleType.Pinned); + // Size of the internal buffer owned exclusively by the WSPC. + int nativeBufferSize = _ReceiveBufferSize + _SendBufferSize + NativeOverheadBufferSize; + _StartAddress = Marshal.UnsafeAddrOfPinnedArrayElement(internalBuffer.Array, internalBuffer.Offset).ToInt64(); + _EndAddress = _StartAddress + nativeBufferSize; + _NativeBuffer = new ArraySegment(internalBuffer.Array, internalBuffer.Offset, nativeBufferSize); + _PayloadBuffer = new ArraySegment(internalBuffer.Array, + _NativeBuffer.Offset + _NativeBuffer.Count, + _ReceiveBufferSize); + _PropertyBuffer = new ArraySegment(internalBuffer.Array, + _PayloadBuffer.Offset + _PayloadBuffer.Count, + PropertyBufferSize); + _SendBufferState = SendBufferState.None; + } + + public int ReceiveBufferSize + { + get { return _ReceiveBufferSize; } + } + + public int SendBufferSize + { + get { return _SendBufferSize; } + } + + internal static WebSocketBuffer CreateClientBuffer(ArraySegment internalBuffer, int receiveBufferSize, int sendBufferSize) + { + Contract.Assert(internalBuffer.Count >= GetInternalBufferSize(receiveBufferSize, sendBufferSize, false), + "Array 'internalBuffer' is TOO SMALL. Call Validate before instantiating WebSocketBuffer."); + + return new WebSocketBuffer(internalBuffer, receiveBufferSize, GetNativeSendBufferSize(sendBufferSize, false)); + } + + internal static WebSocketBuffer CreateServerBuffer(ArraySegment internalBuffer, int receiveBufferSize) + { + int sendBufferSize = GetNativeSendBufferSize(MinSendBufferSize, true); + Contract.Assert(internalBuffer.Count >= GetInternalBufferSize(receiveBufferSize, sendBufferSize, true), + "Array 'internalBuffer' is TOO SMALL. Call Validate before instantiating WebSocketBuffer."); + + return new WebSocketBuffer(internalBuffer, receiveBufferSize, sendBufferSize); + } + + public void Dispose(WebSocketState webSocketState) + { + if (Interlocked.CompareExchange(ref _StateWhenDisposing, (int)webSocketState, int.MinValue) != int.MinValue) + { + return; + } + + this.CleanUp(); + } + + public void Dispose() + { + this.Dispose(WebSocketState.None); + } + + internal UnsafeNativeMethods.WebSocketProtocolComponent.Property[] CreateProperties(bool useZeroMaskingKey) + { + ThrowIfDisposed(); + // serialize marshaled property values in the property segment of the internal buffer + IntPtr internalBufferPtr = _GCHandle.AddrOfPinnedObject(); + int offset = _PropertyBuffer.Offset; + Marshal.WriteInt32(internalBufferPtr, offset, _ReceiveBufferSize); + offset += SizeOfUInt; + Marshal.WriteInt32(internalBufferPtr, offset, _SendBufferSize); + offset += SizeOfUInt; + Marshal.WriteIntPtr(internalBufferPtr, offset, internalBufferPtr); + offset += IntPtr.Size; + Marshal.WriteInt32(internalBufferPtr, offset, useZeroMaskingKey ? (int)1 : (int)0); + + int propertyCount = useZeroMaskingKey ? 4 : 3; + UnsafeNativeMethods.WebSocketProtocolComponent.Property[] properties = + new UnsafeNativeMethods.WebSocketProtocolComponent.Property[propertyCount]; + + // Calculate the pointers to the positions of the properties within the internal buffer + offset = _PropertyBuffer.Offset; + properties[0] = new UnsafeNativeMethods.WebSocketProtocolComponent.Property() + { + Type = UnsafeNativeMethods.WebSocketProtocolComponent.PropertyType.ReceiveBufferSize, + PropertySize = (uint)SizeOfUInt, + PropertyData = IntPtr.Add(internalBufferPtr, offset) + }; + offset += SizeOfUInt; + + properties[1] = new UnsafeNativeMethods.WebSocketProtocolComponent.Property() + { + Type = UnsafeNativeMethods.WebSocketProtocolComponent.PropertyType.SendBufferSize, + PropertySize = (uint)SizeOfUInt, + PropertyData = IntPtr.Add(internalBufferPtr, offset) + }; + offset += SizeOfUInt; + + properties[2] = new UnsafeNativeMethods.WebSocketProtocolComponent.Property() + { + Type = UnsafeNativeMethods.WebSocketProtocolComponent.PropertyType.AllocatedBuffer, + PropertySize = (uint)_NativeBuffer.Count, + PropertyData = IntPtr.Add(internalBufferPtr, offset) + }; + offset += IntPtr.Size; + + if (useZeroMaskingKey) + { + properties[3] = new UnsafeNativeMethods.WebSocketProtocolComponent.Property() + { + Type = UnsafeNativeMethods.WebSocketProtocolComponent.PropertyType.DisableMasking, + PropertySize = (uint)SizeOfBool, + PropertyData = IntPtr.Add(internalBufferPtr, offset) + }; + } + + return properties; + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + internal void PinSendBuffer(ArraySegment payload, out bool bufferHasBeenPinned) + { + bufferHasBeenPinned = false; + WebSocketHelpers.ValidateBuffer(payload.Array, payload.Offset, payload.Count); + int previousState = Interlocked.Exchange(ref _SendBufferState, SendBufferState.SendPayloadSpecified); + + if (previousState != SendBufferState.None) + { + Contract.Assert(false, "'m_SendBufferState' MUST BE 'None' at this point."); + // Indicates a violation in the API contract that could indicate + // memory corruption because the pinned sendbuffer is shared between managed and native code + throw new AccessViolationException(); + } + _PinnedSendBuffer = payload; + _PinnedSendBufferHandle = GCHandle.Alloc(_PinnedSendBuffer.Array, GCHandleType.Pinned); + bufferHasBeenPinned = true; + _PinnedSendBufferStartAddress = + Marshal.UnsafeAddrOfPinnedArrayElement(_PinnedSendBuffer.Array, _PinnedSendBuffer.Offset).ToInt64(); + _PinnedSendBufferEndAddress = _PinnedSendBufferStartAddress + _PinnedSendBuffer.Count; + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + internal IntPtr ConvertPinnedSendPayloadToNative(ArraySegment payload) + { + return ConvertPinnedSendPayloadToNative(payload.Array, payload.Offset, payload.Count); + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + internal IntPtr ConvertPinnedSendPayloadToNative(byte[] buffer, int offset, int count) + { + if (!IsPinnedSendPayloadBuffer(buffer, offset, count)) + { + // Indicates a violation in the API contract that could indicate + // memory corruption because the pinned sendbuffer is shared between managed and native code + throw new AccessViolationException(); + } + + Contract.Assert(Marshal.UnsafeAddrOfPinnedArrayElement(_PinnedSendBuffer.Array, + _PinnedSendBuffer.Offset).ToInt64() == _PinnedSendBufferStartAddress, + "'m_PinnedSendBuffer.Array' MUST be pinned during the entire send operation."); + + return new IntPtr(_PinnedSendBufferStartAddress + offset - _PinnedSendBuffer.Offset); + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + internal ArraySegment ConvertPinnedSendPayloadFromNative(UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType) + { + if (!IsPinnedSendPayloadBuffer(buffer, bufferType)) + { + // Indicates a violation in the API contract that could indicate + // memory corruption because the pinned sendbuffer is shared between managed and native code + throw new AccessViolationException(); + } + + Contract.Assert(Marshal.UnsafeAddrOfPinnedArrayElement(_PinnedSendBuffer.Array, + _PinnedSendBuffer.Offset).ToInt64() == _PinnedSendBufferStartAddress, + "'m_PinnedSendBuffer.Array' MUST be pinned during the entire send operation."); + + IntPtr bufferData; + uint bufferSize; + + UnwrapWebSocketBuffer(buffer, bufferType, out bufferData, out bufferSize); + + int internalOffset = (int)(bufferData.ToInt64() - _PinnedSendBufferStartAddress); + + return new ArraySegment(_PinnedSendBuffer.Array, _PinnedSendBuffer.Offset + internalOffset, (int)bufferSize); + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + private bool IsPinnedSendPayloadBuffer(byte[] buffer, int offset, int count) + { + if (_SendBufferState != SendBufferState.SendPayloadSpecified) + { + return false; + } + + return object.ReferenceEquals(buffer, _PinnedSendBuffer.Array) && + offset >= _PinnedSendBuffer.Offset && + offset + count <= _PinnedSendBuffer.Offset + _PinnedSendBuffer.Count; + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + private bool IsPinnedSendPayloadBuffer(UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType) + { + if (_SendBufferState != SendBufferState.SendPayloadSpecified) + { + return false; + } + + IntPtr bufferData; + uint bufferSize; + + UnwrapWebSocketBuffer(buffer, bufferType, out bufferData, out bufferSize); + + long nativeBufferStartAddress = bufferData.ToInt64(); + long nativeBufferEndAddress = nativeBufferStartAddress + bufferSize; + + return nativeBufferStartAddress >= _PinnedSendBufferStartAddress && + nativeBufferEndAddress >= _PinnedSendBufferStartAddress && + nativeBufferStartAddress <= _PinnedSendBufferEndAddress && + nativeBufferEndAddress <= _PinnedSendBufferEndAddress; + } + + // This method is only thread safe for races between Abort and at most 1 uncompleted send operation + internal void ReleasePinnedSendBuffer() + { + int previousState = Interlocked.Exchange(ref _SendBufferState, SendBufferState.None); + + if (previousState != SendBufferState.SendPayloadSpecified) + { + return; + } + + if (_PinnedSendBufferHandle.IsAllocated) + { + _PinnedSendBufferHandle.Free(); + } + + _PinnedSendBuffer = WebSocketHelpers.EmptyPayload; + } + + internal void BufferPayload(ArraySegment payload, + int unconsumedDataOffset, + WebSocketMessageType messageType, + bool endOfMessage) + { + ThrowIfDisposed(); + int bytesBuffered = payload.Count - unconsumedDataOffset; + + Contract.Assert(_PayloadOffset == 0, + "'m_PayloadOffset' MUST be '0' at this point."); + Contract.Assert(_BufferedPayloadReceiveResult == null || _BufferedPayloadReceiveResult.Count == 0, + "'m_BufferedPayloadReceiveResult.Count' MUST be '0' at this point."); + + Buffer.BlockCopy(payload.Array, + payload.Offset + unconsumedDataOffset, + _PayloadBuffer.Array, + _PayloadBuffer.Offset, + bytesBuffered); + + _BufferedPayloadReceiveResult = + new WebSocketReceiveResult(bytesBuffered, messageType, endOfMessage); + + this.ValidateBufferedPayload(); + } + + internal bool ReceiveFromBufferedPayload(ArraySegment buffer, out WebSocketReceiveResult receiveResult) + { + ThrowIfDisposed(); + ValidateBufferedPayload(); + + int bytesTransferred = Math.Min(buffer.Count, _BufferedPayloadReceiveResult.Count); + receiveResult = _BufferedPayloadReceiveResult.Copy(bytesTransferred); + + Buffer.BlockCopy(_PayloadBuffer.Array, + _PayloadBuffer.Offset + _PayloadOffset, + buffer.Array, + buffer.Offset, + bytesTransferred); + + bool morePayloadBuffered; + if (_BufferedPayloadReceiveResult.Count == 0) + { + _PayloadOffset = 0; + _BufferedPayloadReceiveResult = null; + morePayloadBuffered = false; + } + else + { + _PayloadOffset += bytesTransferred; + morePayloadBuffered = true; + this.ValidateBufferedPayload(); + } + + return morePayloadBuffered; + } + + internal ArraySegment ConvertNativeBuffer(UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType) + { + ThrowIfDisposed(); + + IntPtr bufferData; + uint bufferLength; + + UnwrapWebSocketBuffer(buffer, bufferType, out bufferData, out bufferLength); + + if (bufferData == IntPtr.Zero) + { + return WebSocketHelpers.EmptyPayload; + } + + if (this.IsNativeBuffer(bufferData, bufferLength)) + { + return new ArraySegment(_InternalBuffer.Array, + this.GetOffset(bufferData), + (int)bufferLength); + } + + Contract.Assert(false, "'buffer' MUST reference a memory segment within the pinned InternalBuffer."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + + internal void ConvertCloseBuffer(UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + out WebSocketCloseStatus closeStatus, + out string reason) + { + ThrowIfDisposed(); + IntPtr bufferData; + uint bufferLength; + closeStatus = (WebSocketCloseStatus)buffer.CloseStatus.CloseStatus; + + UnwrapWebSocketBuffer(buffer, UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close, out bufferData, out bufferLength); + + if (bufferData == IntPtr.Zero) + { + reason = null; + } + else + { + ArraySegment reasonBlob; + if (this.IsNativeBuffer(bufferData, bufferLength)) + { + reasonBlob = new ArraySegment(_InternalBuffer.Array, + this.GetOffset(bufferData), + (int)bufferLength); + } + else + { + Contract.Assert(false, "'buffer' MUST reference a memory segment within the pinned InternalBuffer."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + + // No need to wrap DecoderFallbackException for invalid UTF8 chacters, because + // Encoding.UTF8 will not throw but replace invalid characters instead. + reason = Encoding.UTF8.GetString(reasonBlob.Array, reasonBlob.Offset, reasonBlob.Count); + } + } + + internal void ValidateNativeBuffers(UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers, + uint dataBufferCount) + { + Contract.Assert(dataBufferCount <= (uint)int.MaxValue, + "'dataBufferCount' MUST NOT be bigger than Int32.MaxValue."); + Contract.Assert(dataBuffers != null, "'dataBuffers' MUST NOT be NULL."); + + ThrowIfDisposed(); + if (dataBufferCount > dataBuffers.Length) + { + Contract.Assert(false, "'dataBufferCount' MUST NOT be bigger than 'dataBuffers.Length'."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + + int count = dataBuffers.Length; + bool isSendActivity = action == UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateSendComplete || + action == UnsafeNativeMethods.WebSocketProtocolComponent.Action.SendToNetwork; + + if (isSendActivity) + { + count = (int)dataBufferCount; + } + + bool nonZeroBufferFound = false; + for (int i = 0; i < count; i++) + { + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer dataBuffer = dataBuffers[i]; + + IntPtr bufferData; + uint bufferLength; + UnwrapWebSocketBuffer(dataBuffer, bufferType, out bufferData, out bufferLength); + + if (bufferData == IntPtr.Zero) + { + continue; + } + + nonZeroBufferFound = true; + + bool isPinnedSendPayloadBuffer = IsPinnedSendPayloadBuffer(dataBuffer, bufferType); + + if (bufferLength > GetMaxBufferSize()) + { + if (!isSendActivity || !isPinnedSendPayloadBuffer) + { + Contract.Assert(false, + "'dataBuffer.BufferLength' MUST NOT be bigger than 'm_ReceiveBufferSize' and 'm_SendBufferSize'."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + } + + if (!isPinnedSendPayloadBuffer && !IsNativeBuffer(bufferData, bufferLength)) + { + Contract.Assert(false, + "WebSocketGetAction MUST return a pointer within the pinned internal buffer."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + } + + if (!nonZeroBufferFound && + action != UnsafeNativeMethods.WebSocketProtocolComponent.Action.NoAction && + action != UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateReceiveComplete && + action != UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateSendComplete) + { + Contract.Assert(false, "At least one 'dataBuffer.Buffer' MUST NOT be NULL."); + } + } + + private static int GetNativeSendBufferSize(int sendBufferSize, bool isServerBuffer) + { + return isServerBuffer ? MinSendBufferSize : sendBufferSize; + } + + internal static void UnwrapWebSocketBuffer(UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + out IntPtr bufferData, + out uint bufferLength) + { + bufferData = IntPtr.Zero; + bufferLength = 0; + + switch (bufferType) + { + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close: + bufferData = buffer.CloseStatus.ReasonData; + bufferLength = buffer.CloseStatus.ReasonLength; + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.None: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.PingPong: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UnsolicitedPong: + bufferData = buffer.Data.BufferData; + bufferLength = buffer.Data.BufferLength; + break; + default: + Contract.Assert(false, + string.Format(CultureInfo.InvariantCulture, + "BufferType '{0}' is invalid/unknown.", + bufferType)); + break; + } + } + + private void ThrowIfDisposed() + { + switch (_StateWhenDisposing) + { + case int.MinValue: + return; + case (int)WebSocketState.Closed: + case (int)WebSocketState.Aborted: + throw new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, typeof(WebSocketBase), _StateWhenDisposing)); + default: + throw new ObjectDisposedException(GetType().FullName); + } + } + + [Conditional("DEBUG"), Conditional("CONTRACTS_FULL")] + private void ValidateBufferedPayload() + { + Contract.Assert(_BufferedPayloadReceiveResult != null, + "'m_BufferedPayloadReceiveResult' MUST NOT be NULL."); + Contract.Assert(_BufferedPayloadReceiveResult.Count >= 0, + "'m_BufferedPayloadReceiveResult.Count' MUST NOT be negative."); + Contract.Assert(_PayloadOffset >= 0, "'m_PayloadOffset' MUST NOT be smaller than 0."); + Contract.Assert(_PayloadOffset <= _PayloadBuffer.Count, + "'m_PayloadOffset' MUST NOT be bigger than 'm_PayloadBuffer.Count'."); + Contract.Assert(_PayloadOffset + _BufferedPayloadReceiveResult.Count <= _PayloadBuffer.Count, + "'m_PayloadOffset + m_PayloadBytesBuffered' MUST NOT be bigger than 'm_PayloadBuffer.Count'."); + } + + private int GetOffset(IntPtr pBuffer) + { + Contract.Assert(pBuffer != IntPtr.Zero, "'pBuffer' MUST NOT be IntPtr.Zero."); + int offset = (int)(pBuffer.ToInt64() - _StartAddress + _InternalBuffer.Offset); + + Contract.Assert(offset >= 0, "'offset' MUST NOT be negative."); + return offset; + } + + [Pure] + private int GetMaxBufferSize() + { + return Math.Max(_ReceiveBufferSize, _SendBufferSize); + } + + internal bool IsInternalBuffer(byte[] buffer, int offset, int count) + { + Contract.Assert(buffer != null, "'buffer' MUST NOT be NULL."); + Contract.Assert(_InternalBuffer.Array != null, "'m_InternalBuffer.Array' MUST NOT be NULL."); + Contract.Assert(offset >= 0, "'offset' MUST NOT be negative."); + Contract.Assert(count >= 0, "'count' MUST NOT be negative."); + Contract.Assert(offset + count <= buffer.Length, "'offset + count' MUST NOT exceed 'buffer.Length'."); + + return object.ReferenceEquals(buffer, _InternalBuffer.Array); + } + + internal IntPtr ToIntPtr(int offset) + { + Contract.Assert(offset >= 0, "'offset' MUST NOT be negative."); + Contract.Assert(_StartAddress + offset <= _EndAddress, "'offset' is TOO BIG."); + return new IntPtr(_StartAddress + offset); + } + + private bool IsNativeBuffer(IntPtr pBuffer, uint bufferSize) + { + Contract.Assert(pBuffer != IntPtr.Zero, "'pBuffer' MUST NOT be NULL."); + Contract.Assert(bufferSize <= GetMaxBufferSize(), + "'bufferSize' MUST NOT be bigger than 'm_ReceiveBufferSize' and 'm_SendBufferSize'."); + + long nativeBufferStartAddress = pBuffer.ToInt64(); + long nativeBufferEndAddress = bufferSize + nativeBufferStartAddress; + + Contract.Assert(Marshal.UnsafeAddrOfPinnedArrayElement(_InternalBuffer.Array, _InternalBuffer.Offset).ToInt64() == _StartAddress, + "'m_InternalBuffer.Array' MUST be pinned for the whole lifetime of a WebSocket."); + + if (nativeBufferStartAddress >= _StartAddress && + nativeBufferStartAddress <= _EndAddress && + nativeBufferEndAddress >= _StartAddress && + nativeBufferEndAddress <= _EndAddress) + { + return true; + } + + return false; + } + + private void CleanUp() + { + if (_GCHandle.IsAllocated) + { + _GCHandle.Free(); + } + + ReleasePinnedSendBuffer(); + } + + internal static ArraySegment CreateInternalBufferArraySegment(int receiveBufferSize, int sendBufferSize, bool isServerBuffer) + { + Contract.Assert(receiveBufferSize >= MinReceiveBufferSize, + "'receiveBufferSize' MUST be at least " + MinReceiveBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize >= MinSendBufferSize, + "'sendBufferSize' MUST be at least " + MinSendBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + int internalBufferSize = GetInternalBufferSize(receiveBufferSize, sendBufferSize, isServerBuffer); + return new ArraySegment(new byte[internalBufferSize]); + } + + internal static void Validate(int count, int receiveBufferSize, int sendBufferSize, bool isServerBuffer) + { + Contract.Assert(receiveBufferSize >= MinReceiveBufferSize, + "'receiveBufferSize' MUST be at least " + MinReceiveBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize >= MinSendBufferSize, + "'sendBufferSize' MUST be at least " + MinSendBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + int minBufferSize = GetInternalBufferSize(receiveBufferSize, sendBufferSize, isServerBuffer); + if (count < minBufferSize) + { + throw new ArgumentOutOfRangeException("internalBuffer", + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_InternalBuffer, minBufferSize)); + } + } + + private static int GetInternalBufferSize(int receiveBufferSize, int sendBufferSize, bool isServerBuffer) + { + Contract.Assert(receiveBufferSize >= MinReceiveBufferSize, + "'receiveBufferSize' MUST be at least " + MinReceiveBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize >= MinSendBufferSize, + "'sendBufferSize' MUST be at least " + MinSendBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + Contract.Assert(receiveBufferSize <= MaxBufferSize, + "'receiveBufferSize' MUST be less than or equal to " + MaxBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize <= MaxBufferSize, + "'sendBufferSize' MUST be at less than or equal to " + MaxBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + int nativeSendBufferSize = GetNativeSendBufferSize(sendBufferSize, isServerBuffer); + return (2 * receiveBufferSize) + nativeSendBufferSize + NativeOverheadBufferSize + PropertyBufferSize; + } + + private static class SendBufferState + { + public const int None = 0; + public const int SendPayloadSpecified = 1; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketCloseStatus.cs b/src/Microsoft.AspNet.WebSockets/WebSocketCloseStatus.cs new file mode 100644 index 0000000000..a773179e21 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketCloseStatus.cs @@ -0,0 +1,40 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.AspNet.WebSockets +{ + [SuppressMessage("Microsoft.Design", + "CA1008:EnumsShouldHaveZeroValue", + Justification = "This enum is reflecting the IETF's WebSocket specification. " + + "'0' is a disallowed value for the close status code")] + public enum WebSocketCloseStatus + { + NormalClosure = 1000, + EndpointUnavailable = 1001, + ProtocolError = 1002, + InvalidMessageType = 1003, + Empty = 1005, + // AbnormalClosure = 1006, // 1006 is reserved and should never be used by user + InvalidPayloadData = 1007, + PolicyViolation = 1008, + MessageTooBig = 1009, + MandatoryExtension = 1010, + InternalServerError = 1011 + // TLSHandshakeFailed = 1015, // 1015 is reserved and should never be used by user + + // 0 - 999 Status codes in the range 0-999 are not used. + // 1000 - 1999 Status codes in the range 1000-1999 are reserved for definition by this protocol. + // 2000 - 2999 Status codes in the range 2000-2999 are reserved for use by extensions. + // 3000 - 3999 Status codes in the range 3000-3999 MAY be used by libraries and frameworks. The + // interpretation of these codes is undefined by this protocol. End applications MUST + // NOT use status codes in this range. + // 4000 - 4999 Status codes in the range 4000-4999 MAY be used by application code. The interpretaion + // of these codes is undefined by this protocol. + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketError.cs b/src/Microsoft.AspNet.WebSockets/WebSocketError.cs new file mode 100644 index 0000000000..473bb56349 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketError.cs @@ -0,0 +1,22 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.WebSockets +{ + public enum WebSocketError + { + Success = 0, + InvalidMessageType = 1, + Faulted = 2, + NativeError = 3, + NotAWebSocket = 4, + UnsupportedVersion = 5, + UnsupportedProtocol = 6, + HeaderError = 7, + ConnectionClosedPrematurely = 8, + InvalidState = 9 + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketException.cs b/src/Microsoft.AspNet.WebSockets/WebSocketException.cs new file mode 100644 index 0000000000..64edc3d0cb --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketException.cs @@ -0,0 +1,155 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.WebSockets +{ +#if NET45 + [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2237:MarkISerializableTypesWithSerializable")] +#endif + internal sealed class WebSocketException : Win32Exception + { + private WebSocketError _WebSocketErrorCode; + + public WebSocketException() + : this(Marshal.GetLastWin32Error()) + { + } + + public WebSocketException(WebSocketError error) + : this(error, GetErrorMessage(error)) + { + } + + public WebSocketException(WebSocketError error, string message) : base(message) + { + _WebSocketErrorCode = error; + } + + public WebSocketException(WebSocketError error, Exception innerException) + : this(error, GetErrorMessage(error), innerException) + { + } + + public WebSocketException(WebSocketError error, string message, Exception innerException) + : base(message, innerException) + { + _WebSocketErrorCode = error; + } + + public WebSocketException(int nativeError) + : base(nativeError) + { + _WebSocketErrorCode = !UnsafeNativeMethods.WebSocketProtocolComponent.Succeeded(nativeError) ? WebSocketError.NativeError : WebSocketError.Success; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(int nativeError, string message) + : base(nativeError, message) + { + _WebSocketErrorCode = !UnsafeNativeMethods.WebSocketProtocolComponent.Succeeded(nativeError) ? WebSocketError.NativeError : WebSocketError.Success; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(int nativeError, Exception innerException) + : base(SR.GetString(SR.net_WebSockets_Generic), innerException) + { + _WebSocketErrorCode = !UnsafeNativeMethods.WebSocketProtocolComponent.Succeeded(nativeError) ? WebSocketError.NativeError : WebSocketError.Success; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(WebSocketError error, int nativeError) + : this(error, nativeError, GetErrorMessage(error)) + { + } + + public WebSocketException(WebSocketError error, int nativeError, string message) + : base(message) + { + _WebSocketErrorCode = error; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(WebSocketError error, int nativeError, Exception innerException) + : this(error, nativeError, GetErrorMessage(error), innerException) + { + } + + public WebSocketException(WebSocketError error, int nativeError, string message, Exception innerException) + : base(message, innerException) + { + _WebSocketErrorCode = error; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(string message) + : base(message) + { + } + + public WebSocketException(string message, Exception innerException) + : base(message, innerException) + { + } + + public override int ErrorCode + { + get + { + return base.NativeErrorCode; + } + } + + public WebSocketError WebSocketErrorCode + { + get + { + return _WebSocketErrorCode; + } + } + + private static string GetErrorMessage(WebSocketError error) + { + // provide a canned message for the error type + switch (error) + { + case WebSocketError.InvalidMessageType: + return SR.GetString(SR.net_WebSockets_InvalidMessageType_Generic, + typeof(WebSocket).Name + WebSocketBase.Methods.CloseAsync, + typeof(WebSocket).Name + WebSocketBase.Methods.CloseOutputAsync); + case WebSocketError.Faulted: + return SR.GetString(SR.net_Websockets_WebSocketBaseFaulted); + case WebSocketError.NotAWebSocket: + return SR.GetString(SR.net_WebSockets_NotAWebSocket_Generic); + case WebSocketError.UnsupportedVersion: + return SR.GetString(SR.net_WebSockets_UnsupportedWebSocketVersion_Generic); + case WebSocketError.UnsupportedProtocol: + return SR.GetString(SR.net_WebSockets_UnsupportedProtocol_Generic); + case WebSocketError.HeaderError: + return SR.GetString(SR.net_WebSockets_HeaderError_Generic); + case WebSocketError.ConnectionClosedPrematurely: + return SR.GetString(SR.net_WebSockets_ConnectionClosedPrematurely_Generic); + case WebSocketError.InvalidState: + return SR.GetString(SR.net_WebSockets_InvalidState_Generic); + default: + return SR.GetString(SR.net_WebSockets_Generic); + } + } + + // Set the error code only if there is an error (i.e. nativeError >= 0). Otherwise the code blows up on deserialization + // as the Exception..ctor() throws on setting HResult to 0. The default for HResult is -2147467259. + private void SetErrorCodeOnError(int nativeError) + { + if (!UnsafeNativeMethods.WebSocketProtocolComponent.Succeeded(nativeError)) + { + this.HResult = nativeError; + } + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketExtensions.cs b/src/Microsoft.AspNet.WebSockets/WebSocketExtensions.cs new file mode 100644 index 0000000000..abcbd12865 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketExtensions.cs @@ -0,0 +1,14 @@ +#if NET45 +using Microsoft.AspNet.WebSockets; + +namespace Owin +{ + public static class WebSocketExtensions + { + public static IAppBuilder UseWebSockets(this IAppBuilder app) + { + return app.Use(); + } + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketHelpers.cs b/src/Microsoft.AspNet.WebSockets/WebSocketHelpers.cs new file mode 100644 index 0000000000..3479e137fa --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketHelpers.cs @@ -0,0 +1,522 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +#if NET45 +using System.Security.Cryptography; +#endif +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebSockets +{ + internal static class WebSocketHelpers + { + internal const string SecWebSocketKeyGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + internal const string WebSocketUpgradeToken = "websocket"; + internal const int DefaultReceiveBufferSize = 16 * 1024; + internal const int DefaultClientSendBufferSize = 16 * 1024; + internal const int MaxControlFramePayloadLength = 123; + + // RFC 6455 requests WebSocket clients to let the server initiate the TCP close to avoid that client sockets + // end up in TIME_WAIT-state + // + // After both sending and receiving a Close message, an endpoint considers the WebSocket connection closed and + // MUST close the underlying TCP connection. The server MUST close the underlying TCP connection immediately; + // the client SHOULD wait for the server to close the connection but MAY close the connection at any time after + // sending and receiving a Close message, e.g., if it has not received a TCP Close from the server in a + // reasonable time period. + internal const int ClientTcpCloseTimeout = 1000; // 1s + + private const int CloseStatusCodeAbort = 1006; + private const int CloseStatusCodeFailedTLSHandshake = 1015; + private const int InvalidCloseStatusCodesFrom = 0; + private const int InvalidCloseStatusCodesTo = 999; + private const string Separators = "()<>@,;:\\\"/[]?={} "; + + internal static readonly ArraySegment EmptyPayload = new ArraySegment(new byte[] { }, 0, 0); + private static readonly Random KeyGenerator = new Random(); + +/* + internal static Task AcceptWebSocketAsync(HttpListenerContext context, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + { + WebSocketHelpers.ValidateOptions(subProtocol, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, keepAliveInterval); + WebSocketHelpers.ValidateArraySegment(internalBuffer, "internalBuffer"); + WebSocketBuffer.Validate(internalBuffer.Count, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + + return AcceptWebSocketAsyncCore(context, subProtocol, receiveBufferSize, keepAliveInterval, internalBuffer); + } + + private static async Task AcceptWebSocketAsyncCore(HttpListenerContext context, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + { + HttpListenerWebSocketContext webSocketContext = null; + /*if (Logging.On) + { + Logging.Enter(Logging.WebSockets, context, "AcceptWebSocketAsync", ""); + }* / + + try + { + // get property will create a new response if one doesn't exist. + HttpListenerResponse response = context.Response; + HttpListenerRequest request = context.Request; + ValidateWebSocketHeaders(context); + + string secWebSocketVersion = request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; + + // Optional for non-browser client + string origin = request.Headers[HttpKnownHeaderNames.Origin]; + + List secWebSocketProtocols = new List(); + string outgoingSecWebSocketProtocolString; + bool shouldSendSecWebSocketProtocolHeader = + WebSocketHelpers.ProcessWebSocketProtocolHeader( + request.Headers[HttpKnownHeaderNames.SecWebSocketProtocol], + subProtocol, + out outgoingSecWebSocketProtocolString); + + if (shouldSendSecWebSocketProtocolHeader) + { + secWebSocketProtocols.Add(outgoingSecWebSocketProtocolString); + response.Headers.Add(HttpKnownHeaderNames.SecWebSocketProtocol, + outgoingSecWebSocketProtocolString); + } + + // negotiate the websocket key return value + string secWebSocketKey = request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; + string secWebSocketAccept = WebSocketHelpers.GetSecWebSocketAcceptString(secWebSocketKey); + + response.Headers.Add(HttpKnownHeaderNames.Connection, HttpKnownHeaderNames.Upgrade); + response.Headers.Add(HttpKnownHeaderNames.Upgrade, WebSocketHelpers.WebSocketUpgradeToken); + response.Headers.Add(HttpKnownHeaderNames.SecWebSocketAccept, secWebSocketAccept); + + response.StatusCode = (int)HttpStatusCode.SwitchingProtocols; // HTTP 101 + response.ComputeCoreHeaders(); + ulong hresult = SendWebSocketHeaders(response); + if (hresult != 0) + { + throw new WebSocketException((int)hresult, + SR.GetString(SR.net_WebSockets_NativeSendResponseHeaders, + WebSocketHelpers.MethodNames.AcceptWebSocketAsync, + hresult)); + } + + await response.OutputStream.FlushAsync().SuppressContextFlow(); // TODO:??? FlushAsync was never implemented + + HttpResponseStream responseStream = response.OutputStream as HttpResponseStream; + Contract.Assert(responseStream != null, "'responseStream' MUST be castable to System.Net.HttpResponseStream."); + ((HttpResponseStream)response.OutputStream).SwitchToOpaqueMode(); + HttpRequestStream requestStream = new HttpRequestStream(context); + requestStream.SwitchToOpaqueMode(); + WebSocketHttpListenerDuplexStream webSocketStream = + new WebSocketHttpListenerDuplexStream(requestStream, responseStream, context); + WebSocket webSocket = WebSocket.CreateServerWebSocket(webSocketStream, + subProtocol, + receiveBufferSize, + keepAliveInterval, + internalBuffer); + + webSocketContext = new HttpListenerWebSocketContext( + request.Url, + request.Headers, + request.Cookies, + context.User, + request.IsAuthenticated, + request.IsLocal, + request.IsSecureConnection, + origin, + secWebSocketProtocols.AsReadOnly(), + secWebSocketVersion, + secWebSocketKey, + webSocket); + + if (Logging.On) + { + Logging.Associate(Logging.WebSockets, context, webSocketContext); + Logging.Associate(Logging.WebSockets, webSocketContext, webSocket); + } + } + catch (Exception ex) + { + if (Logging.On) + { + Logging.Exception(Logging.WebSockets, context, "AcceptWebSocketAsync", ex); + } + throw; + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.WebSockets, context, "AcceptWebSocketAsync", ""); + } + } + return webSocketContext; + } + +*/ + [SuppressMessage("Microsoft.Cryptographic.Standard", "CA5354:SHA1CannotBeUsed", + Justification = "SHA1 used only for hashing purposes, not for crypto.")] + internal static string GetSecWebSocketAcceptString(string secWebSocketKey) + { + string retVal; +#if NET45 + // SHA1 used only for hashing purposes, not for crypto. Check here for FIPS compat. + using (SHA1 sha1 = SHA1.Create()) + { + string acceptString = string.Concat(secWebSocketKey, WebSocketHelpers.SecWebSocketKeyGuid); + byte[] toHash = Encoding.UTF8.GetBytes(acceptString); + retVal = Convert.ToBase64String(sha1.ComputeHash(toHash)); + } +#endif + return retVal; + } + + internal static string GetTraceMsgForParameters(int offset, int count, CancellationToken cancellationToken) + { + return string.Format(CultureInfo.InvariantCulture, + "offset: {0}, count: {1}, cancellationToken.CanBeCanceled: {2}", + offset, + count, + cancellationToken.CanBeCanceled); + } + + // return value here signifies if a Sec-WebSocket-Protocol header should be returned by the server. + internal static bool ProcessWebSocketProtocolHeader(string clientSecWebSocketProtocol, + string subProtocol, + out string acceptProtocol) + { + acceptProtocol = string.Empty; + if (string.IsNullOrEmpty(clientSecWebSocketProtocol)) + { + // client hasn't specified any Sec-WebSocket-Protocol header + if (subProtocol != null) + { + // If the server specified _anything_ this isn't valid. + throw new WebSocketException(WebSocketError.UnsupportedProtocol, + SR.GetString(SR.net_WebSockets_ClientAcceptingNoProtocols, subProtocol)); + } + // Treat empty and null from the server as the same thing here, server should not send headers. + return false; + } + + // here, we know the client specified something and it's non-empty. + + if (subProtocol == null) + { + // client specified some protocols, server specified 'null'. So server should send headers. + return true; + } + + // here, we know that the client has specified something, it's not empty + // and the server has specified exactly one protocol + + string[] requestProtocols = clientSecWebSocketProtocol.Split(new char[] { ',' }, + StringSplitOptions.RemoveEmptyEntries); + acceptProtocol = subProtocol; + + // client specified protocols, serverOptions has exactly 1 non-empty entry. Check that + // this exists in the list the client specified. + for (int i = 0; i < requestProtocols.Length; i++) + { + string currentRequestProtocol = requestProtocols[i].Trim(); + if (string.Compare(acceptProtocol, currentRequestProtocol, StringComparison.OrdinalIgnoreCase) == 0) + { + return true; + } + } + + throw new WebSocketException(WebSocketError.UnsupportedProtocol, + SR.GetString(SR.net_WebSockets_AcceptUnsupportedProtocol, + clientSecWebSocketProtocol, + subProtocol)); + } + + internal static ConfiguredTaskAwaitable SuppressContextFlow(this Task task) + { + // We don't flow the synchronization context within WebSocket.xxxAsync - but the calling application + // can decide whether the completion callback for the task returned from WebSocket.xxxAsync runs + // under the caller's synchronization context. + return task.ConfigureAwait(false); + } + + internal static ConfiguredTaskAwaitable SuppressContextFlow(this Task task) + { + // We don't flow the synchronization context within WebSocket.xxxAsync - but the calling application + // can decide whether the completion callback for the task returned from WebSocket.xxxAsync runs + // under the caller's synchronization context. + return task.ConfigureAwait(false); + } + + internal static void ValidateBuffer(byte[] buffer, int offset, int count) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + + if (count < 0 || count > (buffer.Length - offset)) + { + throw new ArgumentOutOfRangeException("count"); + } + } + /* + private static unsafe ulong SendWebSocketHeaders(HttpListenerResponse response) + { + return response.SendHeaders(null, null, + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_OPAQUE | + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA | + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA, + true); + } + private static void ValidateWebSocketHeaders(HttpListenerContext context) + { + EnsureHttpSysSupportsWebSockets(); + + if (!context.Request.IsWebSocketRequest) + { + throw new WebSocketException(WebSocketError.NotAWebSocket, + SR.GetString(SR.net_WebSockets_AcceptNotAWebSocket, + WebSocketHelpers.MethodNames.ValidateWebSocketHeaders, + HttpKnownHeaderNames.Connection, + HttpKnownHeaderNames.Upgrade, + WebSocketHelpers.WebSocketUpgradeToken, + context.Request.Headers[HttpKnownHeaderNames.Upgrade])); + } + + string secWebSocketVersion = context.Request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; + if (string.IsNullOrEmpty(secWebSocketVersion)) + { + throw new WebSocketException(WebSocketError.HeaderError, + SR.GetString(SR.net_WebSockets_AcceptHeaderNotFound, + WebSocketHelpers.MethodNames.ValidateWebSocketHeaders, + HttpKnownHeaderNames.SecWebSocketVersion)); + } + + if (string.Compare(secWebSocketVersion, WebSocketProtocolComponent.SupportedVersion, StringComparison.OrdinalIgnoreCase) != 0) + { + throw new WebSocketException(WebSocketError.UnsupportedVersion, + SR.GetString(SR.net_WebSockets_AcceptUnsupportedWebSocketVersion, + WebSocketHelpers.MethodNames.ValidateWebSocketHeaders, + secWebSocketVersion, + WebSocketProtocolComponent.SupportedVersion)); + } + + if (string.IsNullOrWhiteSpace(context.Request.Headers[HttpKnownHeaderNames.SecWebSocketKey])) + { + throw new WebSocketException(WebSocketError.HeaderError, + SR.GetString(SR.net_WebSockets_AcceptHeaderNotFound, + WebSocketHelpers.MethodNames.ValidateWebSocketHeaders, + HttpKnownHeaderNames.SecWebSocketKey)); + } + } + */ + + internal static void ValidateSubprotocol(string subProtocol) + { + if (string.IsNullOrWhiteSpace(subProtocol)) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_InvalidEmptySubProtocol), "subProtocol"); + } + + char[] chars = subProtocol.ToCharArray(); + string invalidChar = null; + int i = 0; + while (i < chars.Length) + { + char ch = chars[i]; + if (ch < 0x21 || ch > 0x7e) + { + invalidChar = string.Format(CultureInfo.InvariantCulture, "[{0}]", (int)ch); + break; + } + + if (!char.IsLetterOrDigit(ch) && + Separators.IndexOf(ch) >= 0) + { + invalidChar = ch.ToString(); + break; + } + + i++; + } + + if (invalidChar != null) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_InvalidCharInProtocolString, subProtocol, invalidChar), + "subProtocol"); + } + } + + internal static void ValidateCloseStatus(WebSocketCloseStatus closeStatus, string statusDescription) + { + if (closeStatus == WebSocketCloseStatus.Empty && !string.IsNullOrEmpty(statusDescription)) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_ReasonNotNull, + statusDescription, + WebSocketCloseStatus.Empty), + "statusDescription"); + } + + int closeStatusCode = (int)closeStatus; + + if ((closeStatusCode >= InvalidCloseStatusCodesFrom && + closeStatusCode <= InvalidCloseStatusCodesTo) || + closeStatusCode == CloseStatusCodeAbort || + closeStatusCode == CloseStatusCodeFailedTLSHandshake) + { + // CloseStatus 1006 means Aborted - this will never appear on the wire and is reflected by calling WebSocket.Abort + throw new ArgumentException(SR.GetString(SR.net_WebSockets_InvalidCloseStatusCode, + closeStatusCode), + "closeStatus"); + } + + int length = 0; + if (!string.IsNullOrEmpty(statusDescription)) + { + length = UTF8Encoding.UTF8.GetByteCount(statusDescription); + } + + if (length > WebSocketHelpers.MaxControlFramePayloadLength) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_InvalidCloseStatusDescription, + statusDescription, + WebSocketHelpers.MaxControlFramePayloadLength), + "statusDescription"); + } + } + + internal static void ValidateOptions(string subProtocol, + int receiveBufferSize, + int sendBufferSize, + TimeSpan keepAliveInterval) + { + // We allow the subProtocol to be null. Validate if it is not null. + if (subProtocol != null) + { + ValidateSubprotocol(subProtocol); + } + + ValidateBufferSizes(receiveBufferSize, sendBufferSize); + + // -1 + if (keepAliveInterval < Timeout.InfiniteTimeSpan) + { + throw new ArgumentOutOfRangeException("keepAliveInterval", keepAliveInterval, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, Timeout.InfiniteTimeSpan.ToString())); + } + } + + internal static void ValidateBufferSizes(int receiveBufferSize, int sendBufferSize) + { + if (receiveBufferSize < WebSocketBuffer.MinReceiveBufferSize) + { + throw new ArgumentOutOfRangeException("receiveBufferSize", receiveBufferSize, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, WebSocketBuffer.MinReceiveBufferSize)); + } + + if (sendBufferSize < WebSocketBuffer.MinSendBufferSize) + { + throw new ArgumentOutOfRangeException("sendBufferSize", sendBufferSize, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, WebSocketBuffer.MinSendBufferSize)); + } + + if (receiveBufferSize > WebSocketBuffer.MaxBufferSize) + { + throw new ArgumentOutOfRangeException("receiveBufferSize", receiveBufferSize, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooBig, + "receiveBufferSize", + receiveBufferSize, + WebSocketBuffer.MaxBufferSize)); + } + + if (sendBufferSize > WebSocketBuffer.MaxBufferSize) + { + throw new ArgumentOutOfRangeException("sendBufferSize", sendBufferSize, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooBig, + "sendBufferSize", + sendBufferSize, + WebSocketBuffer.MaxBufferSize)); + } + } + + internal static void ValidateInnerStream(Stream innerStream) + { + if (innerStream == null) + { + throw new ArgumentNullException("innerStream"); + } + + if (!innerStream.CanRead) + { + throw new ArgumentException(SR.GetString(SR.NotReadableStream), "innerStream"); + } + + if (!innerStream.CanWrite) + { + throw new ArgumentException(SR.GetString(SR.NotWriteableStream), "innerStream"); + } + } + + internal static void ThrowIfConnectionAborted(Stream connection, bool read) + { + if ((!read && !connection.CanWrite) || + (read && !connection.CanRead)) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); + } + } + + internal static void ThrowPlatformNotSupportedException_WSPC() + { + throw new PlatformNotSupportedException(SR.GetString(SR.net_WebSockets_UnsupportedPlatform)); + } + + internal static void ValidateArraySegment(ArraySegment arraySegment, string parameterName) + { + Contract.Requires(!string.IsNullOrEmpty(parameterName), "'parameterName' MUST NOT be NULL or string.Empty"); + + if (arraySegment.Array == null) + { + throw new ArgumentNullException(parameterName + ".Array"); + } + + if (arraySegment.Offset < 0 || arraySegment.Offset > arraySegment.Array.Length) + { + throw new ArgumentOutOfRangeException(parameterName + ".Offset"); + } + if (arraySegment.Count < 0 || arraySegment.Count > (arraySegment.Array.Length - arraySegment.Offset)) + { + throw new ArgumentOutOfRangeException(parameterName + ".Count"); + } + } + + internal static class MethodNames + { + internal const string AcceptWebSocketAsync = "AcceptWebSocketAsync"; + internal const string ValidateWebSocketHeaders = "ValidateWebSocketHeaders"; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketMessageType.cs b/src/Microsoft.AspNet.WebSockets/WebSocketMessageType.cs new file mode 100644 index 0000000000..1661721a14 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketMessageType.cs @@ -0,0 +1,15 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.WebSockets +{ + public enum WebSocketMessageType + { + Text = 0x1, + Binary = 0x2, + Close = 0x8, + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketMiddleware.cs b/src/Microsoft.AspNet.WebSockets/WebSocketMiddleware.cs new file mode 100644 index 0000000000..7a77a6a4f6 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketMiddleware.cs @@ -0,0 +1,167 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Owin; + +namespace Microsoft.AspNet.WebSockets +{ + using AppFunc = Func, Task>; + using OpaqueUpgrade = + Action + < + IDictionary, // Parameters + Func // OpaqueFunc callback + < + IDictionary, // Opaque environment + Task // Complete + > + >; + using WebSocketAccept = + Action + < + IDictionary, // WebSocket Accept parameters + Func // WebSocketFunc callback + < + IDictionary, // WebSocket environment + Task // Complete + > + >; + using WebSocketFunc = + Func + < + IDictionary, // WebSocket Environment + Task // Complete + >; + + public class WebSocketMiddleware + { + private AppFunc _next; + + public WebSocketMiddleware(AppFunc next) + { + _next = next; + } + + public Task Invoke(IDictionary environment) + { + IOwinContext context = new OwinContext(environment); + // Detect if an opaque upgrade is available, and if websocket upgrade headers are present. + // If so, add a websocket upgrade. + OpaqueUpgrade upgrade = context.Get("opaque.Upgrade"); + if (upgrade != null) + { + // Headers and values: + // Connection: Upgrade + // Upgrade: WebSocket + // Sec-WebSocket-Version: (WebSocketProtocolComponent.SupportedVersion) + // Sec-WebSocket-Key: (hash, see WebSocketHelpers.GetSecWebSocketAcceptString) + // Sec-WebSocket-Protocol: (optional, list) + IList connectionHeaders = context.Request.Headers.GetCommaSeparatedValues(HttpKnownHeaderNames.Connection); // "Upgrade, KeepAlive" + string upgradeHeader = context.Request.Headers[HttpKnownHeaderNames.Upgrade]; + string versionHeader = context.Request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; + string keyHeader = context.Request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; + + if (connectionHeaders != null && connectionHeaders.Count > 0 + && connectionHeaders.Contains(HttpKnownHeaderNames.Upgrade, StringComparer.OrdinalIgnoreCase) + && string.Equals(upgradeHeader, WebSocketHelpers.WebSocketUpgradeToken, StringComparison.OrdinalIgnoreCase) + && string.Equals(versionHeader, UnsafeNativeMethods.WebSocketProtocolComponent.SupportedVersion, StringComparison.OrdinalIgnoreCase) + && !string.IsNullOrWhiteSpace(keyHeader)) + { + environment["websocket.Accept"] = new WebSocketAccept(new UpgradeHandshake(context, upgrade).AcceptWebSocket); + } + } + + return _next(environment); + } + + private class UpgradeHandshake + { + private IOwinContext _context; + private OpaqueUpgrade _upgrade; + private WebSocketFunc _webSocketFunc; + + private string _subProtocol; + private int _receiveBufferSize = WebSocketHelpers.DefaultReceiveBufferSize; + private TimeSpan _keepAliveInterval = WebSocket.DefaultKeepAliveInterval; + private ArraySegment _internalBuffer; + + internal UpgradeHandshake(IOwinContext context, OpaqueUpgrade upgrade) + { + _context = context; + _upgrade = upgrade; + } + + internal void AcceptWebSocket(IDictionary options, WebSocketFunc webSocketFunc) + { + _webSocketFunc = webSocketFunc; + + // Get options + object temp; + if (options != null && options.TryGetValue("websocket.SubProtocol", out temp)) + { + _subProtocol = temp as string; + } + if (options != null && options.TryGetValue("websocket.ReceiveBufferSize", out temp)) + { + _receiveBufferSize = (int)temp; + } + if (options != null && options.TryGetValue("websocket.KeepAliveInterval", out temp)) + { + _keepAliveInterval = (TimeSpan)temp; + } + if (options != null && options.TryGetValue("websocket.Buffer", out temp)) + { + _internalBuffer = (ArraySegment)temp; + } + else + { + _internalBuffer = WebSocketBuffer.CreateInternalBufferArraySegment(_receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + } + + // Set WebSocket upgrade response headers + + string outgoingSecWebSocketProtocolString; + bool shouldSendSecWebSocketProtocolHeader = + WebSocketHelpers.ProcessWebSocketProtocolHeader( + _context.Request.Headers[HttpKnownHeaderNames.SecWebSocketProtocol], + _subProtocol, + out outgoingSecWebSocketProtocolString); + + if (shouldSendSecWebSocketProtocolHeader) + { + _context.Response.Headers[HttpKnownHeaderNames.SecWebSocketProtocol] = outgoingSecWebSocketProtocolString; + } + + string secWebSocketKey = _context.Request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; + string secWebSocketAccept = WebSocketHelpers.GetSecWebSocketAcceptString(secWebSocketKey); + + _context.Response.Headers[HttpKnownHeaderNames.Connection] = HttpKnownHeaderNames.Upgrade; + _context.Response.Headers[HttpKnownHeaderNames.Upgrade] = WebSocketHelpers.WebSocketUpgradeToken; + _context.Response.Headers[HttpKnownHeaderNames.SecWebSocketAccept] = secWebSocketAccept; + + _context.Response.StatusCode = 101; // Switching Protocols; + + _upgrade(options, OpaqueCallback); + } + + internal async Task OpaqueCallback(IDictionary opaqueEnv) + { + // Create WebSocket wrapper around the opaque env + WebSocket webSocket = CreateWebSocket(opaqueEnv); + OwinWebSocketWrapper wrapper = new OwinWebSocketWrapper(webSocket, (CancellationToken)opaqueEnv["opaque.CallCancelled"]); + await _webSocketFunc(wrapper.Environment); + // Close down the WebSocekt, gracefully if possible + await wrapper.CleanupAsync(); + } + + private WebSocket CreateWebSocket(IDictionary opaqueEnv) + { + Stream stream = (Stream)opaqueEnv["opaque.Stream"]; + return new ServerWebSocket(stream, _subProtocol, _receiveBufferSize, _keepAliveInterval, _internalBuffer); + } + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketReceiveResult.cs b/src/Microsoft.AspNet.WebSockets/WebSocketReceiveResult.cs new file mode 100644 index 0000000000..a48be62a2d --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketReceiveResult.cs @@ -0,0 +1,55 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.Contracts; + +namespace Microsoft.AspNet.WebSockets +{ + public class WebSocketReceiveResult + { + public WebSocketReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage) + : this(count, messageType, endOfMessage, null, null) + { + } + + public WebSocketReceiveResult(int count, + WebSocketMessageType messageType, + bool endOfMessage, + WebSocketCloseStatus? closeStatus, + string closeStatusDescription) + { + if (count < 0) + { + throw new ArgumentOutOfRangeException("count"); + } + + this.Count = count; + this.EndOfMessage = endOfMessage; + this.MessageType = messageType; + this.CloseStatus = closeStatus; + this.CloseStatusDescription = closeStatusDescription; + } + + public int Count { get; private set; } + public bool EndOfMessage { get; private set; } + public WebSocketMessageType MessageType { get; private set; } + public WebSocketCloseStatus? CloseStatus { get; private set; } + public string CloseStatusDescription { get; private set; } + + internal WebSocketReceiveResult Copy(int count) + { + Contract.Assert(count >= 0, "'count' MUST NOT be negative."); + Contract.Assert(count <= this.Count, "'count' MUST NOT be bigger than 'this.Count'."); + this.Count -= count; + return new WebSocketReceiveResult(count, + this.MessageType, + this.Count == 0 && this.EndOfMessage, + this.CloseStatus, + this.CloseStatusDescription); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketState.cs b/src/Microsoft.AspNet.WebSockets/WebSocketState.cs new file mode 100644 index 0000000000..3c58d00a20 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketState.cs @@ -0,0 +1,19 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.WebSockets +{ + public enum WebSocketState + { + None = 0, + Connecting = 1, + Open = 2, + CloseSent = 3, // WebSocket close handshake started form local endpoint + CloseReceived = 4, // WebSocket close message received from remote endpoint. Waiting for app to call close + Closed = 5, + Aborted = 6, + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/build.cmd b/src/Microsoft.AspNet.WebSockets/build.cmd new file mode 100644 index 0000000000..110694d575 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/build.cmd @@ -0,0 +1,3 @@ +rem set TARGET_FRAMEWORK=k10 +@call ..\..\packages\ProjectK.0.0.1-pre-30121-096\tools\k build +pause \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs b/src/Microsoft.AspNet.WebSockets/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs new file mode 100644 index 0000000000..861c9e31ab --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs @@ -0,0 +1,31 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== + +#if !NET45 + +namespace Microsoft.Win32.SafeHandles +{ + using System; + using System.Runtime.InteropServices; + using System.Runtime.CompilerServices; + + // Class of safe handle which uses 0 or -1 as an invalid handle. + [System.Security.SecurityCritical] // auto-generated_required + internal abstract class SafeHandleZeroOrMinusOneIsInvalid : SafeHandle + { + protected SafeHandleZeroOrMinusOneIsInvalid(bool ownsHandle) + : base(IntPtr.Zero, ownsHandle) + { + } + + public override bool IsInvalid + { + [System.Security.SecurityCritical] + get { return handle == new IntPtr(0) || handle == new IntPtr(-1); } + } + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/AccessViolationException.cs b/src/Microsoft.AspNet.WebSockets/fx/System/AccessViolationException.cs new file mode 100644 index 0000000000..c32cafb4e1 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/AccessViolationException.cs @@ -0,0 +1,16 @@ +#if !NET45 + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace System +{ + internal class AccessViolationException : SystemException + { + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/ComponentModel/Win32Exception.cs b/src/Microsoft.AspNet.WebSockets/fx/System/ComponentModel/Win32Exception.cs new file mode 100644 index 0000000000..108303ee2c --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/ComponentModel/Win32Exception.cs @@ -0,0 +1,112 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +using System.Runtime.InteropServices; +using System.Text; + +namespace System.ComponentModel +{ + internal class Win32Exception : ExternalException + { + /// + /// Represents the Win32 error code associated with this exception. This + /// field is read-only. + /// + private readonly int nativeErrorCode; + + /// + /// Initializes a new instance of the class with the last Win32 error + /// that occured. + /// + public Win32Exception() + : this(Marshal.GetLastWin32Error()) + { + } + /// + /// Initializes a new instance of the class with the specified error. + /// + public Win32Exception(int error) + : this(error, GetErrorMessage(error)) + { + } + /// + /// Initializes a new instance of the class with the specified error and the + /// specified detailed description. + /// + public Win32Exception(int error, string message) + : base(message) + { + nativeErrorCode = error; + } + + /// + /// Initializes a new instance of the Exception class with a specified error message. + /// FxCop CA1032: Multiple constructors are required to correctly implement a custom exception. + /// + public Win32Exception(string message) + : this(Marshal.GetLastWin32Error(), message) + { + } + + /// + /// Initializes a new instance of the Exception class with a specified error message and a + /// reference to the inner exception that is the cause of this exception. + /// FxCop CA1032: Multiple constructors are required to correctly implement a custom exception. + /// + public Win32Exception(string message, Exception innerException) + : base(message, innerException) + { + nativeErrorCode = Marshal.GetLastWin32Error(); + } + + + /// + /// Represents the Win32 error code associated with this exception. This + /// field is read-only. + /// + public int NativeErrorCode + { + get + { + return nativeErrorCode; + } + } + + private static string GetErrorMessage(int error) + { + //get the system error message... + string errorMsg = ""; + StringBuilder sb = new StringBuilder(256); + int result = SafeNativeMethods.FormatMessage( + SafeNativeMethods.FORMAT_MESSAGE_IGNORE_INSERTS | + SafeNativeMethods.FORMAT_MESSAGE_FROM_SYSTEM | + SafeNativeMethods.FORMAT_MESSAGE_ARGUMENT_ARRAY, + IntPtr.Zero, (uint)error, 0, sb, sb.Capacity + 1, + null); + if (result != 0) + { + int i = sb.Length; + while (i > 0) + { + char ch = sb[i - 1]; + if (ch > 32 && ch != '.') break; + i--; + } + errorMsg = sb.ToString(0, i); + } + else + { + errorMsg = "Unknown error (0x" + Convert.ToString(error, 16) + ")"; + } + + return errorMsg; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/ExternDll.cs b/src/Microsoft.AspNet.WebSockets/fx/System/ExternDll.cs new file mode 100644 index 0000000000..98303d48b8 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/ExternDll.cs @@ -0,0 +1,16 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +namespace System +{ + internal static class ExternDll + { + public const string Kernel32 = "kernel32.dll"; + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/Runtime/InteropServices/ExternalException.cs b/src/Microsoft.AspNet.WebSockets/fx/System/Runtime/InteropServices/ExternalException.cs new file mode 100644 index 0000000000..21d82b1de4 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/Runtime/InteropServices/ExternalException.cs @@ -0,0 +1,98 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== +/*============================================================================= +** +** Class: ExternalException +** +** +** Purpose: Exception base class for all errors from Interop or Structured +** Exception Handling code. +** +** +=============================================================================*/ + +#if !NET45 + +namespace System.Runtime.InteropServices +{ + using System; + using System.Globalization; + + // Base exception for COM Interop errors &; Structured Exception Handler + // exceptions. + // + internal class ExternalException : Exception + { + public ExternalException() + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message) + : base(message) + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message, Exception inner) + : base(message, inner) + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message, int errorCode) + : base(message) + { + SetErrorCode(errorCode); + } + + private void SetErrorCode(int errorCode) + { + HResult = ErrorCode; + } + + private static class __HResults + { + internal const int E_FAIL = unchecked((int)0x80004005); + } + + public virtual int ErrorCode + { + get + { + return HResult; + } + } + + public override String ToString() + { + String message = Message; + String s; + String _className = GetType().ToString(); + s = _className + " (0x" + HResult.ToString("X8", CultureInfo.InvariantCulture) + ")"; + + if (!(String.IsNullOrEmpty(message))) + { + s = s + ": " + message; + } + + Exception _innerException = InnerException; + + if (_innerException != null) + { + s = s + " ---> " + _innerException.ToString(); + } + + + if (StackTrace != null) + s += Environment.NewLine + StackTrace; + + return s; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/SafeNativeMethods.cs b/src/Microsoft.AspNet.WebSockets/fx/System/SafeNativeMethods.cs new file mode 100644 index 0000000000..969d273fa8 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/SafeNativeMethods.cs @@ -0,0 +1,28 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 +using System.Runtime.InteropServices; +using System.Text; + +namespace System +{ + internal static class SafeNativeMethods + { + public const int + FORMAT_MESSAGE_ALLOCATE_BUFFER = 0x00000100, + FORMAT_MESSAGE_IGNORE_INSERTS = 0x00000200, + FORMAT_MESSAGE_FROM_STRING = 0x00000400, + FORMAT_MESSAGE_FROM_SYSTEM = 0x00001000, + FORMAT_MESSAGE_ARGUMENT_ARRAY = 0x00002000; + + [DllImport(ExternDll.Kernel32, CharSet = System.Runtime.InteropServices.CharSet.Unicode, SetLastError = true, BestFitMapping = true)] + public static unsafe extern int FormatMessage(int dwFlags, IntPtr lpSource_mustBeNull, uint dwMessageId, + int dwLanguageId, StringBuilder lpBuffer, int nSize, IntPtr[] arguments); + + } +} +#endif diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/SystemException.cs b/src/Microsoft.AspNet.WebSockets/fx/System/SystemException.cs new file mode 100644 index 0000000000..4b6c66b4a4 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/SystemException.cs @@ -0,0 +1,16 @@ +#if !NET45 + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace System +{ + internal class SystemException : Exception + { + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/packages.config b/src/Microsoft.AspNet.WebSockets/packages.config new file mode 100644 index 0000000000..fb47f58ced --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/project.json b/src/Microsoft.AspNet.WebSockets/project.json new file mode 100644 index 0000000000..943893dbd2 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/project.json @@ -0,0 +1,17 @@ +{ + "version": "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Abstractions" : "0.1-alpha-*", + "Microsoft.AspNet.HttpFeature" : "0.1-alpha-*" + }, + "compilationOptions" : { "allowUnsafe": true }, + "configurations": + { + "net45" : { + "dependencies": { + "Owin": "1.0", + "Microsoft.Owin": "2.1.0" + } + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/DenyAnonymous.cs b/test/Microsoft.AspNet.Security.Windows.Test/DenyAnonymous.cs new file mode 100644 index 0000000000..3bd1cc9f00 --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/DenyAnonymous.cs @@ -0,0 +1,47 @@ +// +// Copyright 2011-2012 Katana contributors +// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Security.Principal; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Security.Windows.Tests +{ + using AppFunc = Func, Task>; + + // This middleware can be placed at the end of a chain of pass-through auth schemes if at least one type of auth is required. + public class DenyAnonymous + { + private readonly AppFunc _nextApp; + + public DenyAnonymous(AppFunc nextApp) + { + _nextApp = nextApp; + } + + public async Task Invoke(IDictionary env) + { + if (env.Get("server.User") == null) + { + env["owin.ResponseStatusCode"] = 401; + return; + } + + await _nextApp(env); + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/DictionaryExtensions.cs b/test/Microsoft.AspNet.Security.Windows.Test/DictionaryExtensions.cs new file mode 100644 index 0000000000..a1a6e3577a --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/DictionaryExtensions.cs @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Linq; +using System.Text; + +namespace System.Collections.Generic +{ + internal static class DictionaryExtensions + { + internal static void Append(this IDictionary dictionary, string key, string value) + { + string[] orriginalValues; + if (dictionary.TryGetValue(key, out orriginalValues)) + { + string[] newValues = new string[orriginalValues.Length + 1]; + orriginalValues.CopyTo(newValues, 0); + newValues[newValues.Length - 1] = value; + dictionary[key] = newValues; + } + else + { + dictionary[key] = new string[] { value }; + } + } + + internal static string Get(this IDictionary dictionary, string key) + { + string[] values; + if (dictionary.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + internal static T Get(this IDictionary dictionary, string key, T fallback = default(T)) + { + object values; + if (dictionary.TryGetValue(key, out values)) + { + return (T)values; + } + return fallback; + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/DigestTests.cs b/test/Microsoft.AspNet.Security.Windows.Test/DigestTests.cs new file mode 100644 index 0000000000..e097a7e65b --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/DigestTests.cs @@ -0,0 +1,244 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Security.Authentication.ExtendedProtection; +using System.Threading.Tasks; +using Microsoft.AspNet.Server.WebListener; +using Xunit; + +namespace Microsoft.AspNet.Security.Windows.Tests +{ + using AppFunc = Func, Task>; + + public class DigestTests + { + private const string Address = "http://localhost:8080/"; + private const string SecureAddress = "https://localhost:9090/"; + private const int DefaultStatusCode = 201; + + [Fact] + public async Task Digest_PartialMatch_PassedThrough() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", "Digestion blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(DefaultStatusCode, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Fact] + public async Task Digest_BadData_400() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", "Digest blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(400, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Fact] + public async Task Digest_AppSets401_401WithChallenge() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + IDictionary emptyEnv = CreateEmptyRequest(); + await windowsAuth.Invoke(emptyEnv); + FireOnSendingHeadersActions(emptyEnv); + + Assert.Equal(401, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(1, responseHeaders.Count); + Assert.NotNull(responseHeaders.Get("www-authenticate")); + Assert.True(responseHeaders.Get("www-authenticate").StartsWith("Digest ")); + } + + [Fact] + public async Task Digest_CbtOptionalButNotPresent_401WithChallenge() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + windowsAuth.ExtendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.WhenSupported); + IDictionary emptyEnv = CreateEmptyRequest(); + emptyEnv["owin.RequestScheme"] = "https"; + await windowsAuth.Invoke(emptyEnv); + FireOnSendingHeadersActions(emptyEnv); + + Assert.Equal(401, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + Assert.Null(responseHeaders.Get("www-authenticate")); + } + + [Fact] + public async Task Digest_CbtRequiredButNotPresent_400() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + windowsAuth.ExtendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Always); + IDictionary emptyEnv = CreateEmptyRequest(); + emptyEnv["owin.RequestScheme"] = "https"; + await windowsAuth.Invoke(emptyEnv); + FireOnSendingHeadersActions(emptyEnv); + + Assert.Equal(401, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + Assert.Null(responseHeaders.Get("www-authenticate")); + } + + [Fact(Skip = "Broken")] + public async Task Digest_ClientAuthenticates_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + + using (CreateServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendAuthRequestAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + + [Fact(Skip = "Broken")] + public async Task Digest_ClientAuthenticatesMultipleTimes_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + + using (CreateServer(windowsAuth.Invoke)) + { + for (int i = 0; i < 10; i++) + { + HttpResponseMessage response = await SendAuthRequestAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + } + + [Fact] + public async Task Digest_AnonmousClient_401() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + + using (CreateServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(401, (int)response.StatusCode); + Assert.True(response.Headers.WwwAuthenticate.ToString().StartsWith("Digest ")); + } + } + + [Fact(Skip = "Broken")] + public async Task Digest_ClientAuthenticatesWithCbt_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + windowsAuth.ExtendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Always); + + using (CreateSecureServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendAuthRequestAsync(SecureAddress); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + + private IDictionary CreateEmptyRequest(string header = null, string value = null) + { + IDictionary env = new Dictionary(); + var requestHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + env["owin.RequestHeaders"] = requestHeaders; + if (header != null) + { + requestHeaders[header] = new string[] { value }; + } + env["owin.ResponseHeaders"] = new Dictionary(StringComparer.OrdinalIgnoreCase); + + var onSendingHeadersActions = new List, object>>(); + env["server.OnSendingHeaders"] = new Action, object>( + (a, b) => onSendingHeadersActions.Add(new Tuple, object>(a, b))); + + env["test.OnSendingHeadersActions"] = onSendingHeadersActions; + return env; + } + + private void FireOnSendingHeadersActions(IDictionary env) + { + var onSendingHeadersActions = env.Get, object>>>("test.OnSendingHeadersActions"); + foreach (var actionPair in onSendingHeadersActions.Reverse()) + { + actionPair.Item1(actionPair.Item2); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private IDisposable CreateSecureServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "https"; + address["host"] = "localhost"; + address["port"] = "9090"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + + private async Task SendAuthRequestAsync(string uri) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.UseDefaultCredentials = true; + handler.ServerCertificateValidationCallback = (a, b, c, d) => true; + using (HttpClient client = new HttpClient(handler)) + { + return await client.GetAsync(uri); + } + } + + private Task SimpleApp(IDictionary env) + { + env["owin.ResponseStatusCode"] = DefaultStatusCode; + return Task.FromResult(null); + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/NegotiateTests.cs b/test/Microsoft.AspNet.Security.Windows.Test/NegotiateTests.cs new file mode 100644 index 0000000000..572650b9b0 --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/NegotiateTests.cs @@ -0,0 +1,261 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Security.Authentication.ExtendedProtection; +using System.Threading.Tasks; +using Microsoft.AspNet.Server.WebListener; +using Xunit; +using Xunit.Extensions; + +namespace Microsoft.AspNet.Security.Windows.Tests +{ + using AppFunc = Func, Task>; + + public class NegotiateTests + { + private const string Address = "http://localhost:8080/"; + private const string SecureAddress = "https://localhost:9090/"; + private const int DefaultStatusCode = 201; + + [Theory] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_PartialMatch_PassedThrough(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", package + "ion blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(DefaultStatusCode, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Theory] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_BadData_400(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", package + " blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(400, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Theory] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_AppSets401_401WithChallenge(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp401); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + IDictionary emptyEnv = CreateEmptyRequest(); + await windowsAuth.Invoke(emptyEnv); + FireOnSendingHeadersActions(emptyEnv); + + Assert.Equal(401, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(1, responseHeaders.Count); + Assert.NotNull(responseHeaders.Get("www-authenticate")); + Assert.Equal(package, responseHeaders.Get("www-authenticate")); + } + + [Theory(Skip = "Broken")] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_ClientAuthenticates_Success(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + + using (CreateServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendAuthRequestAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + + [Theory(Skip = "Broken")] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_ClientAuthenticatesMultipleTimes_Success(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + + using (CreateServer(windowsAuth.Invoke)) + { + for (int i = 0; i < 10; i++) + { + HttpResponseMessage response = await SendAuthRequestAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + } + + [Theory] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_AnonmousClient_401(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + + using (CreateServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(401, (int)response.StatusCode); + Assert.Equal(package, response.Headers.WwwAuthenticate.ToString()); + } + } + + [Fact(Skip = "Broken")] + public async Task UnsafeSharedNTLM_AuthenticatedClient_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Ntlm; + windowsAuth.UnsafeConnectionNtlmAuthentication = true; + + using (CreateServer(windowsAuth.Invoke)) + { + WebRequestHandler handler = new WebRequestHandler(); + CredentialCache cache = new CredentialCache(); + cache.Add(new Uri(Address), "NTLM", CredentialCache.DefaultNetworkCredentials); + handler.Credentials = cache; + handler.UnsafeAuthenticatedConnectionSharing = true; + using (HttpClient client = new HttpClient(handler)) + { + HttpResponseMessage response = await client.GetAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + response.EnsureSuccessStatusCode(); + + // Remove the credentials before try two just to prove they aren't used. + cache.Remove(new Uri(Address), "NTLM"); + response = await client.GetAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + } + + [Theory(Skip = "Broken")] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_ClientAuthenticatesWithCbt_Success(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + windowsAuth.ExtendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Always); + + using (CreateSecureServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendAuthRequestAsync(SecureAddress); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + + private IDictionary CreateEmptyRequest(string header = null, string value = null, string connectionId = "Random") + { + IDictionary env = new Dictionary(); + var requestHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + env["owin.RequestHeaders"] = requestHeaders; + if (header != null) + { + requestHeaders[header] = new string[] { value }; + } + env["owin.ResponseHeaders"] = new Dictionary(StringComparer.OrdinalIgnoreCase); + + var onSendingHeadersActions = new List, object>>(); + env["server.OnSendingHeaders"] = new Action, object>( + (a, b) => onSendingHeadersActions.Add(new Tuple, object>(a, b))); + + env["test.OnSendingHeadersActions"] = onSendingHeadersActions; + env["server.ConnectionId"] = connectionId; + return env; + } + + private void FireOnSendingHeadersActions(IDictionary env) + { + var onSendingHeadersActions = env.Get, object>>>("test.OnSendingHeadersActions"); + foreach (var actionPair in onSendingHeadersActions.Reverse()) + { + actionPair.Item1(actionPair.Item2); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private IDisposable CreateSecureServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "https"; + address["host"] = "localhost"; + address["port"] = "9090"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + + private async Task SendAuthRequestAsync(string uri) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.UseDefaultCredentials = true; + handler.ServerCertificateValidationCallback = (a, b, c, d) => true; + using (HttpClient client = new HttpClient(handler)) + { + return await client.GetAsync(uri); + } + } + + private Task SimpleApp(IDictionary env) + { + env["owin.ResponseStatusCode"] = DefaultStatusCode; + return Task.FromResult(null); + } + + private Task SimpleApp401(IDictionary env) + { + env["owin.ResponseStatusCode"] = 401; + return Task.FromResult(null); + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/PassThroughTests.cs b/test/Microsoft.AspNet.Security.Windows.Test/PassThroughTests.cs new file mode 100644 index 0000000000..6ffcf2b8a8 --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/PassThroughTests.cs @@ -0,0 +1,64 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Security.Windows.Tests +{ + public class PassThroughTests + { + private const int DefaultStatusCode = 201; + + [Fact] + public async Task PassThrough_EmptyRequest_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest(); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(DefaultStatusCode, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Fact] + public async Task PassThrough_BasicAuth_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", "Basic blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(DefaultStatusCode, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + private IDictionary CreateEmptyRequest(string header = null, string value = null) + { + IDictionary env = new Dictionary(); + var requestHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + env["owin.RequestHeaders"] = requestHeaders; + if (header != null) + { + requestHeaders[header] = new string[] { value }; + } + env["owin.ResponseHeaders"] = new Dictionary(StringComparer.OrdinalIgnoreCase); + env["server.OnSendingHeaders"] = new Action, object>((a, b) => { }); + return env; + } + + private Task SimpleApp(IDictionary env) + { + env["owin.ResponseStatusCode"] = DefaultStatusCode; + return Task.FromResult(null); + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/Properties/AssemblyInfo.cs b/test/Microsoft.AspNet.Security.Windows.Test/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..681f7d66cd --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.Security.Windows.Tests")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.Security.Windows.Tests")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("334c99b0-a718-4cda-9ca0-d5a45c3a32b0")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/test/Microsoft.AspNet.Security.Windows.Test/packages.config b/test/Microsoft.AspNet.Security.Windows.Test/packages.config new file mode 100644 index 0000000000..67a23e70da --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/test/Microsoft.AspNet.Security.Windows.Test/project.json b/test/Microsoft.AspNet.Security.Windows.Test/project.json new file mode 100644 index 0000000000..ecbdca9ad7 --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/project.json @@ -0,0 +1,17 @@ +{ + "version" : "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Server.WebListener" : "", + "Microsoft.AspNet.Security.Windows" : "" + }, + "configurations": { + "net45": { + "dependencies": { + "XUnit": "1.9.2", + "XUnit.Extensions": "1.9.2", + "System.Net.Http": "", + "System.Net.Http.WebRequest": "" + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/AuthenticationTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/AuthenticationTests.cs new file mode 100644 index 0000000000..d85855d66a --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/AuthenticationTests.cs @@ -0,0 +1,136 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Xunit; +using Xunit.Extensions; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class AuthenticationTests + { + private const string Address = "http://localhost:8080/"; + + [Theory] + [InlineData(AuthenticationType.Kerberos)] + [InlineData(AuthenticationType.Negotiate)] + [InlineData(AuthenticationType.Ntlm)] + [InlineData(AuthenticationType.Digest)] + [InlineData(AuthenticationType.Basic)] + [InlineData(AuthenticationType.Kerberos | AuthenticationType.Negotiate | AuthenticationType.Ntlm | AuthenticationType.Digest | AuthenticationType.Basic)] + public async Task AuthTypes_EnabledButNotChalleneged_PassThrough(AuthenticationType authType) + { + using (CreateServer(authType, env => + { + return Task.FromResult(0); + })) + { + var response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + } + } + + [Theory] + [InlineData(AuthenticationType.Kerberos)] + [InlineData(AuthenticationType.Negotiate)] + [InlineData(AuthenticationType.Ntlm)] + // [InlineData(AuthenticationType.Digest)] // TODO: Not implemented + [InlineData(AuthenticationType.Basic)] + public async Task AuthType_Specify401_ChallengesAdded(AuthenticationType authType) + { + using (CreateServer(authType, env => + { + env["owin.ResponseStatusCode"] = 401; + return Task.FromResult(0); + })) + { + var response = await SendRequestAsync(Address); + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.Equal(authType.ToString(), response.Headers.WwwAuthenticate.ToString(), StringComparer.OrdinalIgnoreCase); + } + } + + [Fact] + public async Task MultipleAuthTypes_Specify401_ChallengesAdded() + { + // TODO: Not implemented - Digest + using (CreateServer(AuthenticationType.Kerberos | AuthenticationType.Negotiate | AuthenticationType.Ntlm | /*AuthenticationType.Digest |*/ AuthenticationType.Basic, env => + { + env["owin.ResponseStatusCode"] = 401; + return Task.FromResult(0); + })) + { + var response = await SendRequestAsync(Address); + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.Equal("Kerberos, Negotiate, NTLM, basic", response.Headers.WwwAuthenticate.ToString(), StringComparer.OrdinalIgnoreCase); + } + } + + [Theory] + [InlineData(AuthenticationType.Kerberos)] + [InlineData(AuthenticationType.Negotiate)] + [InlineData(AuthenticationType.Ntlm)] + // [InlineData(AuthenticationType.Digest)] // TODO: Not implemented + // [InlineData(AuthenticationType.Basic)] // Doesn't work with default creds + [InlineData(AuthenticationType.Kerberos | AuthenticationType.Negotiate | AuthenticationType.Ntlm | /*AuthenticationType.Digest |*/ AuthenticationType.Basic)] + public async Task AuthTypes_Login_Success(AuthenticationType authType) + { + int requestCount = 0; + using (CreateServer(authType, env => + { + requestCount++; + object obj; + if (env.TryGetValue("server.User", out obj) && obj != null) + { + return Task.FromResult(0); + } + env["owin.ResponseStatusCode"] = 401; + return Task.FromResult(0); + })) + { + var response = await SendRequestAsync(Address, useDefaultCredentials: true); + response.EnsureSuccessStatusCode(); + } + } + + private IDisposable CreateServer(AuthenticationType authType, AppFunc app) + { + IDictionary properties = new Dictionary(); + OwinServerFactory.Initialize(properties); + OwinWebListener listener = (OwinWebListener)properties[typeof(OwinWebListener).FullName]; + listener.AuthenticationManager.AuthenticationTypes = authType; + + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri, bool useDefaultCredentials = false) + { + HttpClientHandler handler = new HttpClientHandler(); + handler.UseDefaultCredentials = useDefaultCredentials; + using (HttpClient client = new HttpClient(handler)) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/DictionaryExtensions.cs b/test/Microsoft.AspNet.Server.WebListener.Test/DictionaryExtensions.cs new file mode 100644 index 0000000000..b999b1fc4f --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/DictionaryExtensions.cs @@ -0,0 +1,31 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace System.Collections.Generic +{ + internal static class DictionaryExtensions + { + internal static string Get(this IDictionary dictionary, string key) + { + string[] values; + if (dictionary.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + internal static T Get(this IDictionary dictionary, string key) + { + object values; + if (dictionary.TryGetValue(key, out values)) + { + return (T)values; + } + return default(T); + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/HttpsTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/HttpsTests.cs new file mode 100644 index 0000000000..98661a8380 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/HttpsTests.cs @@ -0,0 +1,191 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class HttpsTests + { + private const string Address = "https://localhost:9090/"; + + [Fact] + public async Task Https_200OK_Success() + { + using (CreateServer(env => + { + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task Https_SendHelloWorld_Success() + { + using (CreateServer(env => + { + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { body.Length.ToString() }; + return env.Get("owin.ResponseBody").WriteAsync(body, 0, body.Length); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task Https_EchoHelloWorld_Success() + { + using (CreateServer(env => + { + string input = new StreamReader(env.Get("owin.RequestBody")).ReadToEnd(); + Assert.Equal("Hello World", input); + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { body.Length.ToString() }; + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task Https_ClientCertNotSent_ClientCertNotPresent() + { + X509Certificate clientCert = null; + using (CreateServer(env => + { + var loadAsync = env.Get>("ssl.LoadClientCertAsync"); + loadAsync().Wait(); + clientCert = env.Get("ssl.ClientCertificate"); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + Assert.Null(clientCert); + } + } + + [Fact] + public async Task Https_ClientCertRequested_ClientCertPresent() + { + X509Certificate clientCert = null; + using (CreateServer(env => + { + var loadAsync = env.Get>("ssl.LoadClientCertAsync"); + loadAsync().Wait(); + clientCert = env.Get("ssl.ClientCertificate"); + return Task.FromResult(0); + })) + { + X509Certificate2 cert = FindClientCert(); + Assert.NotNull(cert); + string response = await SendRequestAsync(Address, cert); + Assert.Equal(string.Empty, response); + Assert.NotNull(clientCert); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "https"; + address["host"] = "localhost"; + address["port"] = "9090"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri, + X509Certificate cert = null) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.ServerCertificateValidationCallback = (a, b, c, d) => true; + if (cert != null) + { + handler.ClientCertificates.Add(cert); + } + using (HttpClient client = new HttpClient(handler)) + { + return await client.GetStringAsync(uri); + } + } + + private async Task SendRequestAsync(string uri, string upload) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.ServerCertificateValidationCallback = (a, b, c, d) => true; + using (HttpClient client = new HttpClient(handler)) + { + HttpResponseMessage response = await client.PostAsync(uri, new StringContent(upload)); + response.EnsureSuccessStatusCode(); + return await response.Content.ReadAsStringAsync(); + } + } + + private X509Certificate2 FindClientCert() + { + var store = new X509Store(); + store.Open(OpenFlags.ReadOnly); + + foreach (var cert in store.Certificates) + { + bool isClientAuth = false; + bool isSmartCard = false; + foreach (var extension in cert.Extensions) + { + var eku = extension as X509EnhancedKeyUsageExtension; + if (eku != null) + { + foreach (var oid in eku.EnhancedKeyUsages) + { + if (oid.FriendlyName == "Client Authentication") + { + isClientAuth = true; + } + else if (oid.FriendlyName == "Smart Card Logon") + { + isSmartCard = true; + break; + } + } + } + } + + if (isClientAuth && !isSmartCard) + { + return cert; + } + } + return null; + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/OpaqueUpgradeTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/OpaqueUpgradeTests.cs new file mode 100644 index 0000000000..07c8610feb --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/OpaqueUpgradeTests.cs @@ -0,0 +1,324 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Extensions; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + using OpaqueUpgrade = Action, Func, Task>>; + + public class OpaqueUpgradeTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task OpaqueUpgrade_SupportKeys_Present() + { + using (CreateServer(env => + { + try + { + IDictionary capabilities = env.Get>("server.Capabilities"); + Assert.NotNull(capabilities); + + Assert.Equal("1.0", capabilities.Get("opaque.Version")); + + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + Assert.NotNull(opaqueUpgrade); + } + catch (Exception ex) + { + byte[] body = Encoding.UTF8.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.False(response.Headers.TransferEncodingChunked.HasValue, "Chunked"); + Assert.Equal(0, response.Content.Headers.ContentLength); + Assert.Equal(string.Empty, response.Content.ReadAsStringAsync().Result); + } + } + + [Fact] + public async Task OpaqueUpgrade_NullCallback_Throws() + { + using (CreateServer(env => + { + try + { + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + opaqueUpgrade(new Dictionary(), null); + } + catch (Exception ex) + { + byte[] body = Encoding.UTF8.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Contains("callback", response.Content.ReadAsStringAsync().Result); + } + } + + [Fact] + public async Task OpaqueUpgrade_AfterHeadersSent_Throws() + { + bool? upgradeThrew = null; + using (CreateServer(env => + { + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + try + { + opaqueUpgrade(null, _ => Task.FromResult(0)); + upgradeThrew = false; + } + catch (InvalidOperationException) + { + upgradeThrew = true; + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.True(upgradeThrew.Value); + } + } + + [Fact] + public async Task OpaqueUpgrade_GetUpgrade_Success() + { + ManualResetEvent waitHandle = new ManualResetEvent(false); + bool? callbackInvoked = null; + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Upgrade"] = new string[] { "websocket" }; // Win8.1 blocks anything but WebSockets + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + opaqueUpgrade(null, opqEnv => + { + callbackInvoked = true; + waitHandle.Set(); + return Task.FromResult(0); + }); + return Task.FromResult(0); + })) + { + using (Stream stream = await SendOpaqueRequestAsync("GET", Address)) + { + Assert.True(waitHandle.WaitOne(TimeSpan.FromSeconds(1)), "Timed out"); + Assert.True(callbackInvoked.HasValue, "CallbackInvoked not set"); + Assert.True(callbackInvoked.Value, "Callback not invoked"); + } + } + } + + [Theory] + // See HTTP_VERB for known verbs + [InlineData("UNKNOWN", null)] + [InlineData("INVALID", null)] + [InlineData("OPTIONS", null)] + [InlineData("GET", null)] + [InlineData("HEAD", null)] + [InlineData("DELETE", null)] + [InlineData("TRACE", null)] + [InlineData("CONNECT", null)] + [InlineData("TRACK", null)] + [InlineData("MOVE", null)] + [InlineData("COPY", null)] + [InlineData("PROPFIND", null)] + [InlineData("PROPPATCH", null)] + [InlineData("MKCOL", null)] + [InlineData("LOCK", null)] + [InlineData("UNLOCK", null)] + [InlineData("SEARCH", null)] + [InlineData("CUSTOMVERB", null)] + [InlineData("PATCH", null)] + [InlineData("POST", "Content-Length: 0")] + [InlineData("PUT", "Content-Length: 0")] + public async Task OpaqueUpgrade_VariousMethodsUpgradeSendAndReceive_Success(string method, string extraHeader) + { + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Upgrade"] = new string[] { "WebSocket" }; // Win8.1 blocks anything but WebSockets + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + opaqueUpgrade(null, async opqEnv => + { + Stream opaqueStream = opqEnv.Get("opaque.Stream"); + + byte[] buffer = new byte[100]; + int read = await opaqueStream.ReadAsync(buffer, 0, buffer.Length); + + await opaqueStream.WriteAsync(buffer, 0, read); + }); + return Task.FromResult(0); + })) + { + using (Stream stream = await SendOpaqueRequestAsync(method, Address, extraHeader)) + { + byte[] data = new byte[100]; + stream.WriteAsync(data, 0, 49).Wait(); + int read = stream.ReadAsync(data, 0, data.Length).Result; + Assert.Equal(49, read); + } + } + } + + [Theory] + // Http.Sys returns a 411 Length Required if PUT or POST does not specify content-length or chunked. + [InlineData("POST", "Content-Length: 10")] + [InlineData("POST", "Transfer-Encoding: chunked")] + [InlineData("PUT", "Content-Length: 10")] + [InlineData("PUT", "Transfer-Encoding: chunked")] + [InlineData("CUSTOMVERB", "Content-Length: 10")] + [InlineData("CUSTOMVERB", "Transfer-Encoding: chunked")] + public void OpaqueUpgrade_InvalidMethodUpgrade_Disconnected(string method, string extraHeader) + { + OpaqueUpgrade opaqueUpgrade = null; + using (CreateServer(env => + { + opaqueUpgrade = env.Get("opaque.Upgrade"); + if (opaqueUpgrade == null) + { + throw new NotImplementedException(); + } + opaqueUpgrade(null, opqEnv => Task.FromResult(0)); + return Task.FromResult(0); + })) + { + Assert.Throws(() => + { + try + { + return SendOpaqueRequestAsync(method, Address, extraHeader).Result; + } + catch (AggregateException ag) + { + throw ag.GetBaseException(); + } + }); + Assert.Null(opaqueUpgrade); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + OwinServerFactory.Initialize(properties); + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + + // Returns a bidirectional opaque stream or throws if the upgrade fails + private async Task SendOpaqueRequestAsync(string method, string address, string extraHeader = null) + { + // Connect with a socket + Uri uri = new Uri(address); + TcpClient client = new TcpClient(); + try + { + await client.ConnectAsync(uri.Host, uri.Port); + NetworkStream stream = client.GetStream(); + + // Send an HTTP GET request + byte[] requestBytes = BuildGetRequest(method, uri, extraHeader); + await stream.WriteAsync(requestBytes, 0, requestBytes.Length); + + // Read the response headers, fail if it's not a 101 + await ParseResponseAsync(stream); + + // Return the opaque network stream + return stream; + } + catch (Exception) + { + client.Close(); + throw; + } + } + + private byte[] BuildGetRequest(string method, Uri uri, string extraHeader) + { + StringBuilder builder = new StringBuilder(); + builder.Append(method); + builder.Append(" "); + builder.Append(uri.PathAndQuery); + builder.Append(" HTTP/1.1"); + builder.AppendLine(); + + builder.Append("Host: "); + builder.Append(uri.Host); + builder.Append(':'); + builder.Append(uri.Port); + builder.AppendLine(); + + if (!string.IsNullOrEmpty(extraHeader)) + { + builder.AppendLine(extraHeader); + } + + builder.AppendLine(); + return Encoding.ASCII.GetBytes(builder.ToString()); + } + + // Read the response headers, fail if it's not a 101 + private async Task ParseResponseAsync(NetworkStream stream) + { + StreamReader reader = new StreamReader(stream); + string statusLine = await reader.ReadLineAsync(); + string[] parts = statusLine.Split(' '); + if (int.Parse(parts[1]) != 101) + { + throw new InvalidOperationException("The response status code was incorrect: " + statusLine); + } + + // Scan to the end of the headers + while (!string.IsNullOrEmpty(reader.ReadLine())) + { + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/Properties/AssemblyInfo.cs b/test/Microsoft.AspNet.Server.WebListener.Test/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..486c5d20cd --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.Server.WebListener.Tests")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.Server.WebListener.Tests")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("a265fcd6-3542-4f59-a1dd-ad423d40ddde")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/RequestBodyTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/RequestBodyTests.cs new file mode 100644 index 0000000000..3687766ca1 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/RequestBodyTests.cs @@ -0,0 +1,176 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class RequestBodyTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task RequestBody_ReadSync_Success() + { + using (CreateServer(env => + { + byte[] input = new byte[100]; + int read = env.Get("owin.RequestBody").Read(input, 0, input.Length); + + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { read.ToString() }; + env.Get("owin.ResponseBody").Write(input, 0, read); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task RequestBody_ReadAync_Success() + { + using (CreateServer(async env => + { + byte[] input = new byte[100]; + int read = await env.Get("owin.RequestBody").ReadAsync(input, 0, input.Length); + + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { read.ToString() }; + await env.Get("owin.ResponseBody").WriteAsync(input, 0, read); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task RequestBody_ReadBeginEnd_Success() + { + using (CreateServer(env => + { + Stream requestStream = env.Get("owin.RequestBody"); + byte[] input = new byte[100]; + int read = requestStream.EndRead(requestStream.BeginRead(input, 0, input.Length, null, null)); + + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { read.ToString() }; + Stream responseStream = env.Get("owin.ResponseBody"); + responseStream.EndWrite(responseStream.BeginWrite(input, 0, read, null, null)); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task RequestBody_ReadSyncPartialBody_Success() + { + StaggardContent content = new StaggardContent(); + using (CreateServer(env => + { + byte[] input = new byte[10]; + int read = env.Get("owin.RequestBody").Read(input, 0, input.Length); + Assert.Equal(5, read); + content.Block.Release(); + read = env.Get("owin.RequestBody").Read(input, 0, input.Length); + Assert.Equal(5, read); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, content); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task RequestBody_ReadAsyncPartialBody_Success() + { + StaggardContent content = new StaggardContent(); + using (CreateServer(async env => + { + byte[] input = new byte[10]; + int read = await env.Get("owin.RequestBody").ReadAsync(input, 0, input.Length); + Assert.Equal(5, read); + content.Block.Release(); + read = await env.Get("owin.RequestBody").ReadAsync(input, 0, input.Length); + Assert.Equal(5, read); + })) + { + string response = await SendRequestAsync(Address, content); + Assert.Equal(string.Empty, response); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private Task SendRequestAsync(string uri, string upload) + { + return SendRequestAsync(uri, new StringContent(upload)); + } + + private async Task SendRequestAsync(string uri, HttpContent content) + { + using (HttpClient client = new HttpClient()) + { + HttpResponseMessage response = await client.PostAsync(uri, content); + response.EnsureSuccessStatusCode(); + return await response.Content.ReadAsStringAsync(); + } + } + + private class StaggardContent : HttpContent + { + public StaggardContent() + { + Block = new SemaphoreSlim(0, 1); + } + + public SemaphoreSlim Block { get; private set; } + + protected async override Task SerializeToStreamAsync(Stream stream, TransportContext context) + { + await stream.WriteAsync(new byte[5], 0, 5); + await Block.WaitAsync(); + await stream.WriteAsync(new byte[5], 0, 5); + } + + protected override bool TryComputeLength(out long length) + { + length = 10; + return true; + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/RequestHeaderTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/RequestHeaderTests.cs new file mode 100644 index 0000000000..a322ea06d0 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/RequestHeaderTests.cs @@ -0,0 +1,120 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class RequestHeaderTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task RequestHeaders_ClientSendsDefaultHeaders_Success() + { + using (CreateServer(env => + { + var requestHeaders = env.Get>("owin.RequestHeaders"); + // NOTE: The System.Net client only sends the Connection: keep-alive header on the first connection per service-point. + // Assert.Equal(2, requestHeaders.Count); + // Assert.Equal("Keep-Alive", requestHeaders.Get("Connection")); + Assert.Equal("localhost:8080", requestHeaders.Get("Host")); + Assert.Equal(null, requestHeaders.Get("Accept")); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task RequestHeaders_ClientSendsCustomHeaders_Success() + { + using (CreateServer(env => + { + var requestHeaders = env.Get>("owin.RequestHeaders"); + Assert.Equal(4, requestHeaders.Count); + Assert.Equal("localhost:8080", requestHeaders.Get("Host")); + Assert.Equal("close", requestHeaders.Get("Connection")); + Assert.Equal(1, requestHeaders["Custom-Header"].Length); + // Apparently Http.Sys squashes request headers together. + Assert.Equal("custom1, and custom2, custom3", requestHeaders.Get("Custom-Header")); + Assert.Equal(1, requestHeaders["Spacer-Header"].Length); + Assert.Equal("spacervalue, spacervalue", requestHeaders.Get("Spacer-Header")); + return Task.FromResult(0); + })) + { + string[] customValues = new string[] { "custom1, and custom2", "custom3" }; + + await SendRequestAsync("localhost", 8080, "Custom-Header", customValues); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetStringAsync(uri); + } + } + + private async Task SendRequestAsync(string host, int port, string customHeader, string[] customValues) + { + StringBuilder builder = new StringBuilder(); + builder.AppendLine("GET / HTTP/1.1"); + builder.AppendLine("Connection: close"); + builder.Append("HOST: "); + builder.Append(host); + builder.Append(':'); + builder.AppendLine(port.ToString()); + foreach (string value in customValues) + { + builder.Append(customHeader); + builder.Append(": "); + builder.AppendLine(value); + builder.AppendLine("Spacer-Header: spacervalue"); + } + builder.AppendLine(); + + byte[] request = Encoding.ASCII.GetBytes(builder.ToString()); + + Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + socket.Connect(host, port); + + socket.Send(request); + + byte[] response = new byte[1024 * 5]; + await Task.Run(() => socket.Receive(response)); + socket.Close(); + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/RequestTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/RequestTests.cs new file mode 100644 index 0000000000..a99bdafa86 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/RequestTests.cs @@ -0,0 +1,185 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Extensions; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class RequestTests + { + private const string Address = "http://localhost:8080"; + + [Fact] + public async Task Request_SimpleGet_Success() + { + using (CreateServer(env => + { + try + { + // General keys + Assert.Equal("1.0", env.Get("owin.Version")); + Assert.True(env.Get("owin.CallCancelled").CanBeCanceled); + + // Request Keys + Assert.Equal("GET", env.Get("owin.RequestMethod")); + Assert.Equal(Stream.Null, env.Get("owin.RequestBody")); + Assert.NotNull(env.Get>("owin.RequestHeaders")); + Assert.Equal("http", env.Get("owin.RequestScheme")); + Assert.Equal("/basepath", env.Get("owin.RequestPathBase")); + Assert.Equal("/SomePath", env.Get("owin.RequestPath")); + Assert.Equal("SomeQuery", env.Get("owin.RequestQueryString")); + Assert.Equal("HTTP/1.1", env.Get("owin.RequestProtocol")); + + // Server Keys + Assert.NotNull(env.Get>("server.Capabilities")); + Assert.Equal("::1", env.Get("server.RemoteIpAddress")); + Assert.NotNull(env.Get("server.RemotePort")); + Assert.Equal("::1", env.Get("server.LocalIpAddress")); + Assert.Equal("8080", env.Get("server.LocalPort")); + Assert.True(env.Get("server.IsLocal")); + + // Note: Response keys are validated in the ResponseTests + } + catch (Exception ex) + { + byte[] body = Encoding.ASCII.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + }, "http", "localhost", "8080", "/basepath")) + { + string response = await SendRequestAsync(Address + "/basepath/SomePath?SomeQuery"); + Assert.Equal(string.Empty, response); + } + } + + [Theory] + [InlineData("http", "localhost", "8080", "/", "http://localhost:8080/", "", "/")] + [InlineData("http", "localhost", "8080", "/basepath/", "http://localhost:8080/basepath", "/basepath", "")] + [InlineData("http", "localhost", "8080", "/basepath/", "http://localhost:8080/basepath/", "/basepath", "/")] + [InlineData("http", "localhost", "8080", "/basepath/", "http://localhost:8080/basepath/subpath", "/basepath", "/subpath")] + [InlineData("http", "localhost", "8080", "/base path/", "http://localhost:8080/base%20path/sub path", "/base path", "/sub path")] + [InlineData("http", "localhost", "8080", "/base葉path/", "http://localhost:8080/base%E8%91%89path/sub%E8%91%89path", "/base葉path", "/sub葉path")] + public async Task Request_PathSplitting(string scheme, string host, string port, string pathBase, string requestUri, + string expectedPathBase, string expectedPath) + { + using (CreateServer(env => + { + try + { + Uri uri = new Uri(requestUri); + string expectedQuery = uri.Query.Length > 0 ? uri.Query.Substring(1) : string.Empty; + // Request Keys + Assert.Equal(scheme, env.Get("owin.RequestScheme")); + Assert.Equal(expectedPath, env.Get("owin.RequestPath")); + Assert.Equal(expectedPathBase, env.Get("owin.RequestPathBase")); + Assert.Equal(expectedQuery, env.Get("owin.RequestQueryString")); + Assert.Equal(port, env.Get("server.LocalPort")); + } + catch (Exception ex) + { + byte[] body = Encoding.ASCII.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + }, scheme, host, port, pathBase)) + { + string response = await SendRequestAsync(requestUri); + Assert.Equal(string.Empty, response); + } + } + + [Theory] + // The test server defines these prefixes: "/", "/11", "/2/3", "/2", "/11/2" + [InlineData("/", "", "/")] + [InlineData("/random", "", "/random")] + [InlineData("/11", "/11", "")] + [InlineData("/11/", "/11", "/")] + [InlineData("/11/random", "/11", "/random")] + [InlineData("/2", "/2", "")] + [InlineData("/2/", "/2", "/")] + [InlineData("/2/random", "/2", "/random")] + [InlineData("/2/3", "/2/3", "")] + [InlineData("/2/3/", "/2/3", "/")] + [InlineData("/2/3/random", "/2/3", "/random")] + public async Task Request_MultiplePrefixes(string requestUri, string expectedPathBase, string expectedPath) + { + using (CreateServer(env => + { + try + { + Assert.Equal(expectedPath, env.Get("owin.RequestPath")); + Assert.Equal(expectedPathBase, env.Get("owin.RequestPathBase")); + } + catch (Exception ex) + { + byte[] body = Encoding.ASCII.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address + requestUri); + Assert.Equal(string.Empty, response); + } + } + + private IDisposable CreateServer(AppFunc app, string scheme, string host, string port, string path) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = scheme; + address["host"] = host; + address["port"] = port; + address["path"] = path; + + return OwinServerFactory.Create(app, properties); + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + foreach (string path in new[] { "/", "/11", "/2/3", "/2", "/11/2" }) + { + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = path; + } + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetStringAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ResponseBodyTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseBodyTests.cs new file mode 100644 index 0000000000..bc95cb305f --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseBodyTests.cs @@ -0,0 +1,213 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class ResponseBodyTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task ResponseBody_WriteNoHeaders_DefaultsToChunked() + { + using (CreateServer(env => + { + env.Get("owin.ResponseBody").Write(new byte[10], 0, 10); + return env.Get("owin.ResponseBody").WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Equal(new byte[20], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_WriteChunked_Chunked() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["transfeR-Encoding"] = new string[] { " CHunked " }; + Stream stream = env.Get("owin.ResponseBody"); + stream.EndWrite(stream.BeginWrite(new byte[10], 0, 10, null, null)); + stream.Write(new byte[10], 0, 10); + return stream.WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Equal(new byte[30], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_WriteContentLength_PassedThrough() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 30 " }; + Stream stream = env.Get("owin.ResponseBody"); + stream.EndWrite(stream.BeginWrite(new byte[10], 0, 10, null, null)); + stream.Write(new byte[10], 0, 10); + return stream.WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal("30", contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(new byte[30], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_Http10WriteNoHeaders_DefaultsConnectionClose() + { + using (CreateServer(env => + { + env["owin.ResponseProtocol"] = "HTTP/1.0"; + env.Get("owin.ResponseBody").Write(new byte[10], 0, 10); + return env.Get("owin.ResponseBody").WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); // Http.Sys won't transmit 1.0 + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(new byte[20], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public void ResponseBody_WriteContentLengthNoneWritten_Throws() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 20 " }; + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + [Fact] + public void ResponseBody_WriteContentLengthNotEnoughWritten_Throws() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 20 " }; + env.Get("owin.ResponseBody").Write(new byte[5], 0, 5); + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + [Fact] + public void ResponseBody_WriteContentLengthTooMuchWritten_Throws() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 10 " }; + env.Get("owin.ResponseBody").Write(new byte[5], 0, 5); + env.Get("owin.ResponseBody").Write(new byte[6], 0, 6); + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + [Fact] + public async Task ResponseBody_WriteContentLengthExtraWritten_Throws() + { + ManualResetEvent waitHandle = new ManualResetEvent(false); + bool? appThrew = null; + using (CreateServer(env => + { + try + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 10 " }; + env.Get("owin.ResponseBody").Write(new byte[10], 0, 10); + env.Get("owin.ResponseBody").Write(new byte[9], 0, 9); + appThrew = false; + } + catch (Exception) + { + appThrew = true; + } + waitHandle.Set(); + return Task.FromResult(0); + })) + { + // The full response is received. + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal("10", contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(new byte[10], await response.Content.ReadAsByteArrayAsync()); + + Assert.True(waitHandle.WaitOne(100)); + Assert.True(appThrew.HasValue, "appThrew.HasValue"); + Assert.True(appThrew.Value, "appThrew.Value"); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ResponseHeaderTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseHeaderTests.cs new file mode 100644 index 0000000000..56b3bb48f4 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseHeaderTests.cs @@ -0,0 +1,243 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class ResponseHeaderTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task ResponseHeaders_ServerSendsDefaultHeaders_Success() + { + using (CreateServer(env => + { + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(2, response.Headers.Count()); + Assert.False(response.Headers.TransferEncodingChunked.HasValue); + Assert.True(response.Headers.Date.HasValue); + Assert.Equal("Microsoft-HTTPAPI/2.0", response.Headers.Server.ToString()); + Assert.Equal(1, response.Content.Headers.Count()); + Assert.Equal(0, response.Content.Headers.ContentLength); + } + } + + [Fact] + public async Task ResponseHeaders_ServerSendsCustomHeaders_Success() + { + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Custom-Header1"] = new string[] { "custom1, and custom2", "custom3" }; + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(3, response.Headers.Count()); + Assert.False(response.Headers.TransferEncodingChunked.HasValue); + Assert.True(response.Headers.Date.HasValue); + Assert.Equal("Microsoft-HTTPAPI/2.0", response.Headers.Server.ToString()); + Assert.Equal(new string[] { "custom1, and custom2", "custom3" }, response.Headers.GetValues("Custom-Header1")); + Assert.Equal(1, response.Content.Headers.Count()); + Assert.Equal(0, response.Content.Headers.ContentLength); + } + } + + [Fact] + public async Task ResponseHeaders_ServerSendsConnectionClose_Closed() + { + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Connection"] = new string[] { "Close" }; + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + + [Fact] + public async Task ResponseHeaders_SendsHttp10_Gets11Close() + { + using (CreateServer(env => + { + env["owin.ResponseProtocol"] = "HTTP/1.0"; + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(new Version(1, 1), response.Version); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + + [Fact] + public async Task ResponseHeaders_SendsHttp10WithBody_Gets11Close() + { + using (CreateServer(env => + { + env["owin.ResponseProtocol"] = "HTTP/1.0"; + return env.Get("owin.ResponseBody").WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(new Version(1, 1), response.Version); + Assert.False(response.Headers.TransferEncodingChunked.HasValue); + Assert.False(response.Content.Headers.Contains("Content-Length")); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + + [Fact] + public async Task ResponseHeaders_HTTP10Request_Gets11Close() + { + using (CreateServer(env => + { + return Task.FromResult(0); + })) + { + using (HttpClient client = new HttpClient()) + { + HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Get, Address); + request.Version = new Version(1, 0); + HttpResponseMessage response = await client.SendAsync(request); + response.EnsureSuccessStatusCode(); + Assert.Equal(new Version(1, 1), response.Version); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + } + + [Fact] + public async Task ResponseHeaders_HTTP10Request_RemovesChunkedHeader() + { + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Transfer-Encoding"] = new string[] { "chunked" }; + return env.Get("owin.ResponseBody").WriteAsync(new byte[10], 0, 10); + })) + { + using (HttpClient client = new HttpClient()) + { + HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Get, Address); + request.Version = new Version(1, 0); + HttpResponseMessage response = await client.SendAsync(request); + response.EnsureSuccessStatusCode(); + Assert.Equal(new Version(1, 1), response.Version); + Assert.False(response.Headers.TransferEncodingChunked.HasValue); + Assert.False(response.Content.Headers.Contains("Content-Length")); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + } + + [Fact] + public async Task Headers_FlushSendsHeaders_Success() + { + using (CreateServer( + env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders.Add("Custom1", new string[] { "value1a", "value1b" }); + responseHeaders.Add("Custom2", new string[] { "value2a, value2b" }); + var body = env.Get("owin.ResponseBody"); + body.Flush(); + env["owin.ResponseStatusCode"] = 404; // Ignored + responseHeaders.Add("Custom3", new string[] { "value3a, value3b", "value3c" }); // Ignored + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(5, response.Headers.Count()); // Date, Server, Chunked + + Assert.Equal(2, response.Headers.GetValues("Custom1").Count()); + Assert.Equal("value1a", response.Headers.GetValues("Custom1").First()); + Assert.Equal("value1b", response.Headers.GetValues("Custom1").Skip(1).First()); + Assert.Equal(1, response.Headers.GetValues("Custom2").Count()); + Assert.Equal("value2a, value2b", response.Headers.GetValues("Custom2").First()); + } + } + + [Fact] + public async Task Headers_FlushAsyncSendsHeaders_Success() + { + using (CreateServer( + async env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders.Add("Custom1", new string[] { "value1a", "value1b" }); + responseHeaders.Add("Custom2", new string[] { "value2a, value2b" }); + var body = env.Get("owin.ResponseBody"); + await body.FlushAsync(); + env["owin.ResponseStatusCode"] = 404; // Ignored + responseHeaders.Add("Custom3", new string[] { "value3a, value3b", "value3c" }); // Ignored + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(5, response.Headers.Count()); // Date, Server, Chunked + + Assert.Equal(2, response.Headers.GetValues("Custom1").Count()); + Assert.Equal("value1a", response.Headers.GetValues("Custom1").First()); + Assert.Equal("value1b", response.Headers.GetValues("Custom1").Skip(1).First()); + Assert.Equal(1, response.Headers.GetValues("Custom2").Count()); + Assert.Equal("value2a, value2b", response.Headers.GetValues("Custom2").First()); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ResponseSendFileTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseSendFileTests.cs new file mode 100644 index 0000000000..d176dc2a1a --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseSendFileTests.cs @@ -0,0 +1,327 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + using SendFileFunc = Func; + + public class ResponseSendFileTests + { + private const string Address = "http://localhost:8080/"; + private static readonly string AbsoluteFilePath = Environment.CurrentDirectory + "\\Microsoft.AspNet.Server.WebListener.dll"; + private static readonly string RelativeFilePath = "Microsoft.AspNet.Server.WebListener.dll"; + private static readonly long FileLength = new FileInfo(AbsoluteFilePath).Length; + + [Fact] + public async Task ResponseSendFile_SupportKeys_Present() + { + using (CreateServer(env => + { + try + { + IDictionary capabilities = env.Get>("server.Capabilities"); + Assert.NotNull(capabilities); + + Assert.Equal("1.0", capabilities.Get("sendfile.Version")); + + IDictionary support = capabilities.Get>("sendfile.Support"); + Assert.NotNull(support); + + Assert.Equal("Overlapped", support.Get("sendfile.Concurrency")); + + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + Assert.NotNull(sendFileAsync); + } + catch (Exception ex) + { + byte[] body = Encoding.UTF8.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable ignored; + Assert.True(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.False(response.Headers.TransferEncodingChunked.HasValue, "Chunked"); + Assert.Equal(0, response.Content.Headers.ContentLength); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task ResponseSendFile_MissingFile_Throws() + { + ManualResetEvent waitHandle = new ManualResetEvent(false); + bool? appThrew = null; + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + try + { + sendFileAsync(string.Empty, 0, null, CancellationToken.None).Wait(); + appThrew = false; + } + catch (Exception) + { + appThrew = true; + throw; + } + finally + { + waitHandle.Set(); + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + Assert.True(waitHandle.WaitOne(100)); + Assert.True(appThrew.HasValue, "appThrew.HasValue"); + Assert.True(appThrew.Value, "appThrew.Value"); + } + } + + [Fact] + public async Task ResponseSendFile_NoHeaders_DefaultsToChunked() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Equal(FileLength, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_RelativeFile_Success() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(RelativeFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Equal(FileLength, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_Chunked_Chunked() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Transfer-EncodinG"] = new string[] { "CHUNKED" }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.False(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value); + Assert.Equal(FileLength, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_MultipleChunks_Chunked() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Transfer-EncodinG"] = new string[] { "CHUNKED" }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None).Wait(); + return sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.False(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value); + Assert.Equal(FileLength * 2, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_ChunkedHalfOfFile_Chunked() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, FileLength / 2, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.False(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value); + Assert.Equal(FileLength / 2, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_ChunkedOffsetOutOfRange_Throws() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 1234567, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(500, (int)response.StatusCode); + } + } + + [Fact] + public async Task ResponseSendFile_ChunkedCountOutOfRange_Throws() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, 1234567, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(500, (int)response.StatusCode); + } + } + + [Fact] + public async Task ResponseSendFile_ChunkedCount0_Chunked() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, 0, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.False(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value); + Assert.Equal(0, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_ContentLength_PassedThrough() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { FileLength.ToString() }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal(FileLength.ToString(), contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(FileLength, response.Content.ReadAsByteArrayAsync().Result.Length); + } + } + + [Fact] + public async Task ResponseSendFile_ContentLengthSpecific_PassedThrough() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { "10" }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, 10, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal("10", contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(10, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_ContentLength0_PassedThrough() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { "0" }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, 0, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal("0", contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(0, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IList> addresses = new List>(); + IDictionary properties = new Dictionary(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + OwinServerFactory.Initialize(properties); + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ResponseTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseTests.cs new file mode 100644 index 0000000000..537ad8c6c9 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseTests.cs @@ -0,0 +1,142 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class ResponseTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task Response_ServerSendsDefaultResponse_ServerProvidesStatusCodeAndReasonPhrase() + { + using (CreateServer(env => + { + Assert.Equal(200, env["owin.ResponseStatusCode"]); + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal("OK", response.ReasonPhrase); + Assert.Equal(new Version(1, 1), response.Version); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task Response_ServerSendsSpecificStatus_ServerProvidesReasonPhrase() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 201; + env["owin.ResponseProtocol"] = "HTTP/1.0"; // Http.Sys ignores this value + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(201, (int)response.StatusCode); + Assert.Equal("Created", response.ReasonPhrase); + Assert.Equal(new Version(1, 1), response.Version); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task Response_ServerSendsSpecificStatusAndReasonPhrase_PassedThrough() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 201; + env["owin.ResponseReasonPhrase"] = "CustomReasonPhrase"; + env["owin.ResponseProtocol"] = "HTTP/1.0"; // Http.Sys ignores this value + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(201, (int)response.StatusCode); + Assert.Equal("CustomReasonPhrase", response.ReasonPhrase); + Assert.Equal(new Version(1, 1), response.Version); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task Response_ServerSendsCustomStatus_NoReasonPhrase() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 901; + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(901, (int)response.StatusCode); + Assert.Equal(string.Empty, response.ReasonPhrase); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public void Response_100_Throws() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 100; + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + [Fact] + public void Response_0_Throws() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 0; + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ServerTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ServerTests.cs new file mode 100644 index 0000000000..84f76d9a52 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ServerTests.cs @@ -0,0 +1,287 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class ServerTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task Server_200OK_Success() + { + using (CreateServer(env => + { + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task Server_SendHelloWorld_Success() + { + using (CreateServer(env => + { + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { body.Length.ToString() }; + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task Server_EchoHelloWorld_Success() + { + using (CreateServer(env => + { + string input = new StreamReader(env.Get("owin.RequestBody")).ReadToEnd(); + Assert.Equal("Hello World", input); + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { body.Length.ToString() }; + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public void Server_AppException_ClientReset() + { + using (CreateServer(env => + { + throw new InvalidOperationException(); + })) + { + Task requestTask = SendRequestAsync(Address); + Assert.Throws(() => requestTask.Result); + + // Do it again to make sure the server didn't crash + requestTask = SendRequestAsync(Address); + Assert.Throws(() => requestTask.Result); + } + } + + [Fact] + public void Server_MultipleOutstandingSyncRequests_Success() + { + int requestLimit = 10; + int requestCount = 0; + TaskCompletionSource tcs = new TaskCompletionSource(); + + using (CreateServer(env => + { + if (Interlocked.Increment(ref requestCount) == requestLimit) + { + tcs.TrySetResult(null); + } + else + { + tcs.Task.Wait(); + } + + return Task.FromResult(0); + })) + { + List requestTasks = new List(); + for (int i = 0; i < requestLimit; i++) + { + Task requestTask = SendRequestAsync(Address); + requestTasks.Add(requestTask); + } + + bool success = Task.WaitAll(requestTasks.ToArray(), TimeSpan.FromSeconds(5)); + if (!success) + { + Console.WriteLine(); + } + Assert.True(success, "Timed out"); + } + } + + [Fact] + public void Server_MultipleOutstandingAsyncRequests_Success() + { + int requestLimit = 10; + int requestCount = 0; + TaskCompletionSource tcs = new TaskCompletionSource(); + + using (CreateServer(async env => + { + if (Interlocked.Increment(ref requestCount) == requestLimit) + { + tcs.TrySetResult(null); + } + else + { + await tcs.Task; + } + })) + { + List requestTasks = new List(); + for (int i = 0; i < requestLimit; i++) + { + Task requestTask = SendRequestAsync(Address); + requestTasks.Add(requestTask); + } + Assert.True(Task.WaitAll(requestTasks.ToArray(), TimeSpan.FromSeconds(2)), "Timed out"); + } + } + + [Fact] + public async Task Server_ClientDisconnects_CallCancelled() + { + TimeSpan interval = TimeSpan.FromSeconds(1); + ManualResetEvent received = new ManualResetEvent(false); + ManualResetEvent aborted = new ManualResetEvent(false); + ManualResetEvent canceled = new ManualResetEvent(false); + + using (CreateServer(env => + { + CancellationToken ct = env.Get("owin.CallCancelled"); + Assert.True(ct.CanBeCanceled, "CanBeCanceled"); + Assert.False(ct.IsCancellationRequested, "IsCancellationRequested"); + ct.Register(() => canceled.Set()); + received.Set(); + Assert.True(aborted.WaitOne(interval), "Aborted"); + Assert.True(ct.WaitHandle.WaitOne(interval), "CT Wait"); + Assert.True(ct.IsCancellationRequested, "IsCancellationRequested"); + return Task.FromResult(0); + })) + { + // Note: System.Net.Sockets does not RST the connection by default, it just FINs. + // Http.Sys's disconnect notice requires a RST. + using (Socket socket = await SendHungRequestAsync("GET", Address)) + { + Assert.True(received.WaitOne(interval), "Receive Timeout"); + socket.Close(0); // Force a RST + aborted.Set(); + } + Assert.True(canceled.WaitOne(interval), "canceled"); + } + } + + [Fact] + public async Task Server_SetQueueLimit_Success() + { + using (CreateServer(env => + { + // There's no good way to validate this in code. Just execute it to make sure it doesn't crash. + // Run "netsh http show servicestate" to see the current value + var listener = env.Get("Microsoft.AspNet.Server.WebListener.OwinWebListener"); + listener.SetRequestQueueLimit(1001); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + ServicePointManager.DefaultConnectionLimit = 100; + using (HttpClient client = new HttpClient()) + { + return await client.GetStringAsync(uri); + } + } + + private async Task SendRequestAsync(string uri, string upload) + { + using (HttpClient client = new HttpClient()) + { + HttpResponseMessage response = await client.PostAsync(uri, new StringContent(upload)); + response.EnsureSuccessStatusCode(); + return await response.Content.ReadAsStringAsync(); + } + } + + private async Task SendHungRequestAsync(string method, string address) + { + // Connect with a socket + Uri uri = new Uri(address); + TcpClient client = new TcpClient(); + try + { + await client.ConnectAsync(uri.Host, uri.Port); + NetworkStream stream = client.GetStream(); + + // Send an HTTP GET request + byte[] requestBytes = BuildGetRequest(method, uri); + await stream.WriteAsync(requestBytes, 0, requestBytes.Length); + + // Return the opaque network stream + return client.Client; + } + catch (Exception) + { + client.Close(); + throw; + } + } + + private byte[] BuildGetRequest(string method, Uri uri) + { + StringBuilder builder = new StringBuilder(); + builder.Append(method); + builder.Append(" "); + builder.Append(uri.PathAndQuery); + builder.Append(" HTTP/1.1"); + builder.AppendLine(); + + builder.Append("Host: "); + builder.Append(uri.Host); + builder.Append(':'); + builder.Append(uri.Port); + builder.AppendLine(); + + builder.AppendLine(); + return Encoding.ASCII.GetBytes(builder.ToString()); + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/project.json b/test/Microsoft.AspNet.Server.WebListener.Test/project.json new file mode 100644 index 0000000000..b11449f836 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/project.json @@ -0,0 +1,16 @@ +{ + "version" : "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Server.WebListener" : "" + }, + "configurations": { + "net45": { + "dependencies": { + "XUnit": "1.9.2", + "XUnit.Extensions": "1.9.2", + "System.Net.Http": "", + "System.Net.Http.WebRequest": "" + } + } + } +}