From 251630590d1e666da57443bf12fdcbba6c8e653f Mon Sep 17 00:00:00 2001 From: Chris Ross Date: Fri, 7 Feb 2014 17:01:08 -0800 Subject: [PATCH] Initial port. --- .gitattributes | 50 + .gitignore | 22 + KatanaInternal.sln | 94 + NuGet.Config | 13 + build.cmd | 23 + global.json | 3 + makefile.shade | 7 + samples/HelloWorld/Program.cs | 55 + samples/HelloWorld/project.json | 10 + samples/SelfHostServer/App.config | 9 + samples/SelfHostServer/Program.cs | 138 + .../SelfHostServer/Properties/AssemblyInfo.cs | 42 + samples/SelfHostServer/Public/1kb.txt | 1 + samples/SelfHostServer/packages.config | 8 + samples/SelfHostServer/project.json | 19 + samples/TestClient/App.config | 6 + samples/TestClient/Program.cs | 78 + samples/TestClient/Properties/AssemblyInfo.cs | 36 + samples/TestClient/TestClient.csproj | 62 + .../AuthTypes.cs | 37 + .../ComNetOS.cs | 38 + .../Constants.cs | 44 + .../DictionaryExtensions.cs | 67 + .../DigestCache.cs | 153 + .../DisconnectAsyncResult.cs | 93 + .../HeaderEncoding.cs | 139 + .../HttpKnownHeaderNames.cs | 75 + .../HttpStatusCode.cs | 314 ++ .../Legacy/CaseInsinsitiveAscii.cs | 133 + .../Legacy/GlobalLog.cs | 689 +++++ .../Legacy/HttpListenerContext.cs | 67 + .../Legacy/Internal.cs | 286 ++ .../Legacy/Logging.cs | 657 +++++ .../Legacy/LoggingObject.cs | 580 ++++ .../Legacy/SR.cs | 630 ++++ .../Legacy/ValidationHelper.cs | 74 + .../NTAuthentication.cs | 721 +++++ .../NativeInterop/AuthIdentity.cs | 40 + .../NativeInterop/ContextFlags.cs | 124 + .../NativeInterop/NativeSSPI.cs | 42 + .../NativeInterop/SSPIAuthType.cs | 302 ++ .../NativeInterop/SSPIHandle.cs | 36 + .../NativeInterop/SSPISessionCache.cs | 49 + .../NativeInterop/SSPIWrapper.cs | 462 +++ .../NativeInterop/SafeCloseHandle.cs | 51 + .../NativeInterop/SafeCredentialReference.cs | 69 + .../NativeInterop/SafeDeleteContext.cs | 707 +++++ .../NativeInterop/SafeFreeCertContext.cs | 35 + .../NativeInterop/SafeFreeContextBuffer.cs | 148 + .../SafeFreeContextBufferChannelBinding.cs | 89 + .../NativeInterop/SafeFreeCredentials.cs | 199 ++ .../NativeInterop/SafeLocalFree.cs | 48 + .../NativeInterop/SafeSspiAuthDataHandle.cs | 32 + .../NativeInterop/SchProtocols.cs | 41 + .../NativeInterop/SecSizes.cs | 44 + .../NativeInterop/SecurityBuffer.cs | 56 + .../NativeInterop/SecurityBufferDescriptor.cs | 33 + .../NativeInterop/SecurityPackageInfoClass.cs | 80 + .../NativeInterop/SslConnectionInfo.cs | 38 + .../NativeInterop/StreamSizes.cs | 43 + .../NativeInterop/UnsafeNativeMethods.cs | 222 ++ .../NegotiationInfoClass.cs | 70 + .../PrefixCollection.cs | 109 + .../PrefixEnumerator.cs | 51 + .../Properties/AssemblyInfo.cs | 42 + .../ServiceNameStore.cs | 368 +++ .../WindowsAuthMiddleware.cs | 1281 +++++++++ .../packages.config | 4 + .../project.json | 14 + .../AsyncAcceptContext.cs | 224 ++ .../AuthenticationManager.cs | 129 + .../AuthenticationTypes.cs | 19 + .../Constants.cs | 63 + .../CustomDictionary.xml | 10 + .../DictionaryExtensions.cs | 51 + .../GlobalSuppressions.cs | Bin 0 -> 1926 bytes .../Helpers.cs | 29 + .../LogHelper.cs | 82 + .../NativeInterop/AddressFamily.cs | 172 ++ .../NativeInterop/ComNetOS.cs | 26 + .../NativeInterop/ContextAttribute.cs | 56 + .../NativeInterop/HttpRequestQueueV2Handle.cs | 25 + .../NativeInterop/HttpServerSessionHandle.cs | 48 + .../NativeInterop/HttpSysRequestHeader.cs | 54 + .../NativeInterop/HttpSysResponseHeader.cs | 43 + .../NativeInterop/HttpSysSettings.cs | 125 + .../NativeInterop/IntPtrHelper.cs | 23 + .../NativeInterop/NclUtilities.cs | 25 + .../NativeInterop/SSPIHandle.cs | 34 + .../NativeInterop/SafeLoadLibrary.cs | 42 + .../NativeInterop/SafeLocalFree.cs | 46 + .../SafeLocalFreeChannelBinding.cs | 42 + .../NativeInterop/SafeLocalMemHandle.cs | 30 + .../NativeInterop/SafeNativeOverlapped.cs | 68 + .../NativeInterop/SchProtocols.cs | 51 + .../NativeInterop/SecurityStatus.cs | 54 + .../NativeInterop/SocketAddress.cs | 342 +++ .../NativeInterop/UnsafeNativeMethods.cs | 1129 ++++++++ .../OwinServerFactory.cs | 108 + .../OwinWebListener.cs | 1107 +++++++ .../Prefix.cs | 102 + .../Properties/AssemblyInfo.cs | 44 + .../PumpLimits.cs | 31 + .../RequestProcessing/BoundaryType.cs | 18 + .../CallEnvironment.Generated.cs | 1913 +++++++++++++ .../CallEnvironment.Generated.tt | 316 ++ .../RequestProcessing/CallEnvironment.cs | 165 ++ .../RequestProcessing/ClientCertLoader.cs | 347 +++ .../RequestProcessing/EntitySendFormat.cs | 17 + .../RequestProcessing/HeaderEncoding.cs | 97 + .../RequestProcessing/HttpKnownHeaderNames.cs | 75 + .../RequestProcessing/HttpReasonPhrase.cs | 109 + .../RequestProcessing/HttpStatusCode.cs | 314 ++ .../RequestProcessing/NativeRequestContext.cs | 170 ++ .../RequestProcessing/NilEnvDictionary.cs | 115 + .../RequestProcessing/OpaqueStream.cs | 168 ++ .../RequestProcessing/Request.cs | 456 +++ .../RequestProcessing/RequestContext.cs | 388 +++ .../RequestHeaders.Generated.cs | 2544 +++++++++++++++++ .../RequestHeaders.Generated.tt | 218 ++ .../RequestProcessing/RequestHeaders.cs | 167 ++ .../RequestProcessing/RequestStream.cs | 614 ++++ .../RequestProcessing/RequestUriBuilder.cs | 569 ++++ .../RequestProcessing/Response.cs | 756 +++++ .../RequestProcessing/ResponseStream.cs | 826 ++++++ .../ResponseStreamAsyncResult.cs | 441 +++ .../RequestProcessing/SslStatus.cs | 15 + .../Resources.Designer.cs | 153 + .../Resources.resx | 150 + .../TimeoutManager.cs | 267 ++ .../ValidationHelper.cs | 67 + .../WebListenerException.cs | 44 + .../CriticalHandleZeroOrMinusOneIsInvalid.cs | 32 + .../SafeHandleZeroOrMinusOneIsInvalid.cs | 31 + .../System/ComponentModel/Win32Exception.cs | 112 + .../fx/System/Diagnostics/TraceEventType.cs | 35 + .../fx/System/ExternDll.cs | 16 + .../InteropServices/ExternalException.cs | 98 + .../fx/System/SafeNativeMethods.cs | 28 + .../ExtendedProtection/ChannelBinding.cs | 33 + .../project.json | 8 + src/Microsoft.AspNet.WebSockets/Constants.cs | 21 + .../HttpKnownHeaderNames.cs | 76 + .../Legacy/HttpListenerContext.cs | 58 + .../Legacy/HttpListenerRequest.cs | 66 + src/Microsoft.AspNet.WebSockets/Legacy/SR.cs | 66 + .../WebSocketHttpListenerDuplexStream.cs | 1262 ++++++++ .../NativeInterop/SafeLoadLibrary.cs | 40 + .../NativeInterop/SafeNativeOverlapped.cs | 80 + .../NativeInterop/SafeWebSocketHandle.cs | 32 + .../NativeInterop/UnsafeNativeMethods.cs | 842 ++++++ .../OwinWebSocketWrapper.cs | 137 + .../Properties/AssemblyInfo.cs | 36 + .../ServerWebSocket.cs | 62 + src/Microsoft.AspNet.WebSockets/WebSocket.cs | 124 + .../WebSocketBase.cs | 2481 ++++++++++++++++ .../WebSocketBuffer.cs | 698 +++++ .../WebSocketCloseStatus.cs | 40 + .../WebSocketError.cs | 22 + .../WebSocketException.cs | 155 + .../WebSocketExtensions.cs | 14 + .../WebSocketHelpers.cs | 522 ++++ .../WebSocketMessageType.cs | 15 + .../WebSocketMiddleware.cs | 167 ++ .../WebSocketReceiveResult.cs | 55 + .../WebSocketState.cs | 19 + src/Microsoft.AspNet.WebSockets/build.cmd | 3 + .../SafeHandleZeroOrMinusOneIsInvalid.cs | 31 + .../fx/System/AccessViolationException.cs | 16 + .../System/ComponentModel/Win32Exception.cs | 112 + .../fx/System/ExternDll.cs | 16 + .../InteropServices/ExternalException.cs | 98 + .../fx/System/SafeNativeMethods.cs | 28 + .../fx/System/SystemException.cs | 16 + .../packages.config | 5 + src/Microsoft.AspNet.WebSockets/project.json | 17 + .../DenyAnonymous.cs | 47 + .../DictionaryExtensions.cs | 51 + .../DigestTests.cs | 244 ++ .../NegotiateTests.cs | 261 ++ .../PassThroughTests.cs | 64 + .../Properties/AssemblyInfo.cs | 42 + .../packages.config | 5 + .../project.json | 17 + .../AuthenticationTests.cs | 136 + .../DictionaryExtensions.cs | 31 + .../HttpsTests.cs | 191 ++ .../OpaqueUpgradeTests.cs | 324 +++ .../Properties/AssemblyInfo.cs | 42 + .../RequestBodyTests.cs | 176 ++ .../RequestHeaderTests.cs | 120 + .../RequestTests.cs | 185 ++ .../ResponseBodyTests.cs | 213 ++ .../ResponseHeaderTests.cs | 243 ++ .../ResponseSendFileTests.cs | 327 +++ .../ResponseTests.cs | 142 + .../ServerTests.cs | 287 ++ .../project.json | 16 + 198 files changed, 37324 insertions(+) create mode 100644 .gitattributes create mode 100644 .gitignore create mode 100644 KatanaInternal.sln create mode 100644 NuGet.Config create mode 100644 build.cmd create mode 100644 global.json create mode 100644 makefile.shade create mode 100644 samples/HelloWorld/Program.cs create mode 100644 samples/HelloWorld/project.json create mode 100644 samples/SelfHostServer/App.config create mode 100644 samples/SelfHostServer/Program.cs create mode 100644 samples/SelfHostServer/Properties/AssemblyInfo.cs create mode 100644 samples/SelfHostServer/Public/1kb.txt create mode 100644 samples/SelfHostServer/packages.config create mode 100644 samples/SelfHostServer/project.json create mode 100644 samples/TestClient/App.config create mode 100644 samples/TestClient/Program.cs create mode 100644 samples/TestClient/Properties/AssemblyInfo.cs create mode 100644 samples/TestClient/TestClient.csproj create mode 100644 src/Microsoft.AspNet.Security.Windows/AuthTypes.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/ComNetOS.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Constants.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/DictionaryExtensions.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/DigestCache.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/DisconnectAsyncResult.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/HeaderEncoding.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/HttpKnownHeaderNames.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/HttpStatusCode.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Legacy/CaseInsinsitiveAscii.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Legacy/GlobalLog.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Legacy/HttpListenerContext.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Legacy/Internal.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Legacy/Logging.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Legacy/LoggingObject.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Legacy/SR.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Legacy/ValidationHelper.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NTAuthentication.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/AuthIdentity.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/ContextFlags.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/NativeSSPI.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIAuthType.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIHandle.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPISessionCache.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIWrapper.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCloseHandle.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCredentialReference.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeDeleteContext.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCertContext.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBuffer.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBufferChannelBinding.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCredentials.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeLocalFree.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeSspiAuthDataHandle.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SchProtocols.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SecSizes.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBuffer.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBufferDescriptor.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityPackageInfoClass.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/SslConnectionInfo.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/StreamSizes.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NativeInterop/UnsafeNativeMethods.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/NegotiationInfoClass.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/PrefixCollection.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/PrefixEnumerator.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/Properties/AssemblyInfo.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/ServiceNameStore.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/WindowsAuthMiddleware.cs create mode 100644 src/Microsoft.AspNet.Security.Windows/packages.config create mode 100644 src/Microsoft.AspNet.Security.Windows/project.json create mode 100644 src/Microsoft.AspNet.Server.WebListener/AsyncAcceptContext.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/AuthenticationManager.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/AuthenticationTypes.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/Constants.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/CustomDictionary.xml create mode 100644 src/Microsoft.AspNet.Server.WebListener/DictionaryExtensions.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/GlobalSuppressions.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/Helpers.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/LogHelper.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/AddressFamily.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/ComNetOS.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/ContextAttribute.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpRequestQueueV2Handle.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpServerSessionHandle.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysRequestHeader.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysResponseHeader.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysSettings.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/IntPtrHelper.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/NclUtilities.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SSPIHandle.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLoadLibrary.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFree.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFreeChannelBinding.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalMemHandle.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeNativeOverlapped.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SchProtocols.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SecurityStatus.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/SocketAddress.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/NativeInterop/UnsafeNativeMethods.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/OwinServerFactory.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/OwinWebListener.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/Prefix.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/Properties/AssemblyInfo.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/PumpLimits.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/BoundaryType.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.tt create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ClientCertLoader.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/EntitySendFormat.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HeaderEncoding.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpKnownHeaderNames.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpReasonPhrase.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpStatusCode.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NativeRequestContext.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NilEnvDictionary.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/OpaqueStream.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Request.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestContext.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.tt create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestStream.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestUriBuilder.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Response.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStream.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStreamAsyncResult.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/RequestProcessing/SslStatus.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/Resources.Designer.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/Resources.resx create mode 100644 src/Microsoft.AspNet.Server.WebListener/TimeoutManager.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/ValidationHelper.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/WebListenerException.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/CriticalHandleZeroOrMinusOneIsInvalid.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/fx/System/ComponentModel/Win32Exception.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/fx/System/Diagnostics/TraceEventType.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/fx/System/ExternDll.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/fx/System/Runtime/InteropServices/ExternalException.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/fx/System/SafeNativeMethods.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/fx/System/Security/Authentication/ExtendedProtection/ChannelBinding.cs create mode 100644 src/Microsoft.AspNet.Server.WebListener/project.json create mode 100644 src/Microsoft.AspNet.WebSockets/Constants.cs create mode 100644 src/Microsoft.AspNet.WebSockets/HttpKnownHeaderNames.cs create mode 100644 src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerContext.cs create mode 100644 src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerRequest.cs create mode 100644 src/Microsoft.AspNet.WebSockets/Legacy/SR.cs create mode 100644 src/Microsoft.AspNet.WebSockets/Legacy/WebSocketHttpListenerDuplexStream.cs create mode 100644 src/Microsoft.AspNet.WebSockets/NativeInterop/SafeLoadLibrary.cs create mode 100644 src/Microsoft.AspNet.WebSockets/NativeInterop/SafeNativeOverlapped.cs create mode 100644 src/Microsoft.AspNet.WebSockets/NativeInterop/SafeWebSocketHandle.cs create mode 100644 src/Microsoft.AspNet.WebSockets/NativeInterop/UnsafeNativeMethods.cs create mode 100644 src/Microsoft.AspNet.WebSockets/OwinWebSocketWrapper.cs create mode 100644 src/Microsoft.AspNet.WebSockets/Properties/AssemblyInfo.cs create mode 100644 src/Microsoft.AspNet.WebSockets/ServerWebSocket.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocket.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketBase.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketBuffer.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketCloseStatus.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketError.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketException.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketExtensions.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketHelpers.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketMessageType.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketMiddleware.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketReceiveResult.cs create mode 100644 src/Microsoft.AspNet.WebSockets/WebSocketState.cs create mode 100644 src/Microsoft.AspNet.WebSockets/build.cmd create mode 100644 src/Microsoft.AspNet.WebSockets/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs create mode 100644 src/Microsoft.AspNet.WebSockets/fx/System/AccessViolationException.cs create mode 100644 src/Microsoft.AspNet.WebSockets/fx/System/ComponentModel/Win32Exception.cs create mode 100644 src/Microsoft.AspNet.WebSockets/fx/System/ExternDll.cs create mode 100644 src/Microsoft.AspNet.WebSockets/fx/System/Runtime/InteropServices/ExternalException.cs create mode 100644 src/Microsoft.AspNet.WebSockets/fx/System/SafeNativeMethods.cs create mode 100644 src/Microsoft.AspNet.WebSockets/fx/System/SystemException.cs create mode 100644 src/Microsoft.AspNet.WebSockets/packages.config create mode 100644 src/Microsoft.AspNet.WebSockets/project.json create mode 100644 test/Microsoft.AspNet.Security.Windows.Test/DenyAnonymous.cs create mode 100644 test/Microsoft.AspNet.Security.Windows.Test/DictionaryExtensions.cs create mode 100644 test/Microsoft.AspNet.Security.Windows.Test/DigestTests.cs create mode 100644 test/Microsoft.AspNet.Security.Windows.Test/NegotiateTests.cs create mode 100644 test/Microsoft.AspNet.Security.Windows.Test/PassThroughTests.cs create mode 100644 test/Microsoft.AspNet.Security.Windows.Test/Properties/AssemblyInfo.cs create mode 100644 test/Microsoft.AspNet.Security.Windows.Test/packages.config create mode 100644 test/Microsoft.AspNet.Security.Windows.Test/project.json create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/AuthenticationTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/DictionaryExtensions.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/HttpsTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/OpaqueUpgradeTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/Properties/AssemblyInfo.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/RequestBodyTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/RequestHeaderTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/RequestTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/ResponseBodyTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/ResponseHeaderTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/ResponseSendFileTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/ResponseTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/ServerTests.cs create mode 100644 test/Microsoft.AspNet.Server.WebListener.Test/project.json diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000..bdaa5ba982 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,50 @@ +*.doc diff=astextplain +*.DOC diff=astextplain +*.docx diff=astextplain +*.DOCX diff=astextplain +*.dot diff=astextplain +*.DOT diff=astextplain +*.pdf diff=astextplain +*.PDF diff=astextplain +*.rtf diff=astextplain +*.RTF diff=astextplain + +*.jpg binary +*.png binary +*.gif binary + +*.cs text=auto diff=csharp +*.vb text=auto +*.resx text=auto +*.c text=auto +*.cpp text=auto +*.cxx text=auto +*.h text=auto +*.hxx text=auto +*.py text=auto +*.rb text=auto +*.java text=auto +*.html text=auto +*.htm text=auto +*.css text=auto +*.scss text=auto +*.sass text=auto +*.less text=auto +*.js text=auto +*.lisp text=auto +*.clj text=auto +*.sql text=auto +*.php text=auto +*.lua text=auto +*.m text=auto +*.asm text=auto +*.erl text=auto +*.fs text=auto +*.fsx text=auto +*.hs text=auto + +*.csproj text=auto +*.vbproj text=auto +*.fsproj text=auto +*.dbproj text=auto +*.sln text=auto eol=crlf diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..2554a1fc23 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +[Oo]bj/ +[Bb]in/ +TestResults/ +.nuget/ +_ReSharper.*/ +packages/ +artifacts/ +PublishProfiles/ +*.user +*.suo +*.cache +*.docstates +_ReSharper.* +nuget.exe +*net45.csproj +*k10.csproj +*.psess +*.vsp +*.pidb +*.userprefs +*DS_Store +*.ncrunchsolution diff --git a/KatanaInternal.sln b/KatanaInternal.sln new file mode 100644 index 0000000000..9f18ac98c5 --- /dev/null +++ b/KatanaInternal.sln @@ -0,0 +1,94 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 2013 +VisualStudioVersion = 12.0.30110.0 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestClient", "samples\TestClient\TestClient.csproj", "{8B828433-B333-4C19-96AE-00BFFF9D8841}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Server.WebListener.k10", "src\Microsoft.AspNet.Server.WebListener\Microsoft.AspNet.Server.WebListener.k10.csproj", "{6D9D3023-3ED7-4C95-80F0-347843ABD759}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Server.WebListener.net45", "src\Microsoft.AspNet.Server.WebListener\Microsoft.AspNet.Server.WebListener.net45.csproj", "{253B9134-B6EB-4E59-8725-D983FD941A21}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.WebSockets.net45", "src\Microsoft.AspNet.WebSockets\Microsoft.AspNet.WebSockets.net45.csproj", "{00C6A882-1FE2-4769-901C-023D8DC175C4}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{99D5E5F3-88F5-4CCF-8D8C-717C8925DF09}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{E183C826-1360-4DFF-9994-F33CED5C8525}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "samples", "samples", "{3A1E31E3-2794-4CA3-B8E2-253E96BDE514}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloWorld.net45", "samples\HelloWorld\HelloWorld.net45.csproj", "{BF335732-BB09-49A1-8676-F074047E7DB2}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SelfHostServer.net45", "samples\SelfHostServer\SelfHostServer.net45.csproj", "{96C67B2F-9913-4E8D-B2E8-969BE66B71B6}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Server.WebListener.Test.net45", "test\Microsoft.AspNet.Server.WebListener.Test\Microsoft.AspNet.Server.WebListener.Test.net45.csproj", "{485DAC59-A1F1-4D47-98BF-B448C994E05B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "HelloWorld.k10", "samples\HelloWorld\HelloWorld.k10.csproj", "{A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Security.Windows.net45", "src\Microsoft.AspNet.Security.Windows\Microsoft.AspNet.Security.Windows.net45.csproj", "{8B4EF749-251D-4222-AD18-DE5A1E7D321A}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNet.Security.Windows.Test.net45", "test\Microsoft.AspNet.Security.Windows.Test\Microsoft.AspNet.Security.Windows.Test.net45.csproj", "{3EC418D5-C8FD-47AA-BFED-F524358EC3DD}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {8B828433-B333-4C19-96AE-00BFFF9D8841}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8B828433-B333-4C19-96AE-00BFFF9D8841}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8B828433-B333-4C19-96AE-00BFFF9D8841}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8B828433-B333-4C19-96AE-00BFFF9D8841}.Release|Any CPU.Build.0 = Release|Any CPU + {6D9D3023-3ED7-4C95-80F0-347843ABD759}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6D9D3023-3ED7-4C95-80F0-347843ABD759}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6D9D3023-3ED7-4C95-80F0-347843ABD759}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6D9D3023-3ED7-4C95-80F0-347843ABD759}.Release|Any CPU.Build.0 = Release|Any CPU + {253B9134-B6EB-4E59-8725-D983FD941A21}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {253B9134-B6EB-4E59-8725-D983FD941A21}.Debug|Any CPU.Build.0 = Debug|Any CPU + {253B9134-B6EB-4E59-8725-D983FD941A21}.Release|Any CPU.ActiveCfg = Release|Any CPU + {253B9134-B6EB-4E59-8725-D983FD941A21}.Release|Any CPU.Build.0 = Release|Any CPU + {00C6A882-1FE2-4769-901C-023D8DC175C4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {00C6A882-1FE2-4769-901C-023D8DC175C4}.Debug|Any CPU.Build.0 = Debug|Any CPU + {00C6A882-1FE2-4769-901C-023D8DC175C4}.Release|Any CPU.ActiveCfg = Release|Any CPU + {00C6A882-1FE2-4769-901C-023D8DC175C4}.Release|Any CPU.Build.0 = Release|Any CPU + {BF335732-BB09-49A1-8676-F074047E7DB2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {BF335732-BB09-49A1-8676-F074047E7DB2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {BF335732-BB09-49A1-8676-F074047E7DB2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {BF335732-BB09-49A1-8676-F074047E7DB2}.Release|Any CPU.Build.0 = Release|Any CPU + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6}.Release|Any CPU.Build.0 = Release|Any CPU + {485DAC59-A1F1-4D47-98BF-B448C994E05B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {485DAC59-A1F1-4D47-98BF-B448C994E05B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {485DAC59-A1F1-4D47-98BF-B448C994E05B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {485DAC59-A1F1-4D47-98BF-B448C994E05B}.Release|Any CPU.Build.0 = Release|Any CPU + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B}.Release|Any CPU.Build.0 = Release|Any CPU + {8B4EF749-251D-4222-AD18-DE5A1E7D321A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8B4EF749-251D-4222-AD18-DE5A1E7D321A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8B4EF749-251D-4222-AD18-DE5A1E7D321A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8B4EF749-251D-4222-AD18-DE5A1E7D321A}.Release|Any CPU.Build.0 = Release|Any CPU + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {485DAC59-A1F1-4D47-98BF-B448C994E05B} = {E183C826-1360-4DFF-9994-F33CED5C8525} + {3EC418D5-C8FD-47AA-BFED-F524358EC3DD} = {E183C826-1360-4DFF-9994-F33CED5C8525} + {8B828433-B333-4C19-96AE-00BFFF9D8841} = {3A1E31E3-2794-4CA3-B8E2-253E96BDE514} + {BF335732-BB09-49A1-8676-F074047E7DB2} = {3A1E31E3-2794-4CA3-B8E2-253E96BDE514} + {96C67B2F-9913-4E8D-B2E8-969BE66B71B6} = {3A1E31E3-2794-4CA3-B8E2-253E96BDE514} + {A1F2CA12-3F08-4DE2-B3D9-52DBE267936B} = {3A1E31E3-2794-4CA3-B8E2-253E96BDE514} + {6D9D3023-3ED7-4C95-80F0-347843ABD759} = {99D5E5F3-88F5-4CCF-8D8C-717C8925DF09} + {253B9134-B6EB-4E59-8725-D983FD941A21} = {99D5E5F3-88F5-4CCF-8D8C-717C8925DF09} + {00C6A882-1FE2-4769-901C-023D8DC175C4} = {99D5E5F3-88F5-4CCF-8D8C-717C8925DF09} + {8B4EF749-251D-4222-AD18-DE5A1E7D321A} = {99D5E5F3-88F5-4CCF-8D8C-717C8925DF09} + EndGlobalSection +EndGlobal diff --git a/NuGet.Config b/NuGet.Config new file mode 100644 index 0000000000..ab583b0ff7 --- /dev/null +++ b/NuGet.Config @@ -0,0 +1,13 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/build.cmd b/build.cmd new file mode 100644 index 0000000000..7045ee1f84 --- /dev/null +++ b/build.cmd @@ -0,0 +1,23 @@ +@echo off +cd %~dp0 + +SETLOCAL +SET CACHED_NUGET=%LocalAppData%\NuGet\NuGet.exe + +IF EXIST %CACHED_NUGET% goto copynuget +echo Downloading latest version of NuGet.exe... +IF NOT EXIST %LocalAppData%\NuGet md %LocalAppData%\NuGet +@powershell -NoProfile -ExecutionPolicy unrestricted -Command "$ProgressPreference = 'SilentlyContinue'; Invoke-WebRequest 'https://www.nuget.org/nuget.exe' -OutFile '%CACHED_NUGET%'" + +:copynuget +IF EXIST .nuget\nuget.exe goto restore +md .nuget +copy %CACHED_NUGET% .nuget\nuget.exe > nul + +:restore +IF EXIST packages\KoreBuild goto run +.nuget\NuGet.exe install KoreBuild -ExcludeVersion -o packages -nocache -pre +.nuget\NuGet.exe install Sake -version 0.2 -o packages -ExcludeVersion + +:run +packages\Sake\tools\Sake.exe -I packages\KoreBuild\build -f makefile.shade %* diff --git a/global.json b/global.json new file mode 100644 index 0000000000..840c36f6ad --- /dev/null +++ b/global.json @@ -0,0 +1,3 @@ +{ + "sources": ["src"] +} \ No newline at end of file diff --git a/makefile.shade b/makefile.shade new file mode 100644 index 0000000000..6357ea2841 --- /dev/null +++ b/makefile.shade @@ -0,0 +1,7 @@ + +var VERSION='0.1' +var FULL_VERSION='0.1' +var AUTHORS='Microsoft' + +use-standard-lifecycle +k-standard-goals diff --git a/samples/HelloWorld/Program.cs b/samples/HelloWorld/Program.cs new file mode 100644 index 0000000000..5e176ff7ee --- /dev/null +++ b/samples/HelloWorld/Program.cs @@ -0,0 +1,55 @@ + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Text; +using System.Threading.Tasks; +using Microsoft.AspNet.Server.WebListener; + +using AppFunc = System.Func, System.Threading.Tasks.Task>; + +public class Program +{ + public static void Main(string[] args) + { + using (CreateServer(new AppFunc(HelloWorldApp))) + { + Console.WriteLine("Running, press enter to exit..."); + Console.ReadLine(); + } + } + + private static IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + public static Task HelloWorldApp(IDictionary environment) + { + string responseText = "Hello World"; + byte[] responseBytes = Encoding.UTF8.GetBytes(responseText); + + // See http://owin.org/spec/owin-1.0.0.html for standard environment keys. + Stream responseStream = (Stream)environment["owin.ResponseBody"]; + IDictionary responseHeaders = + (IDictionary)environment["owin.ResponseHeaders"]; + + responseHeaders["Content-Length"] = new string[] { responseBytes.Length.ToString(CultureInfo.InvariantCulture) }; + responseHeaders["Content-Type"] = new string[] { "text/plain" }; + + return responseStream.WriteAsync(responseBytes, 0, responseBytes.Length); + } +} diff --git a/samples/HelloWorld/project.json b/samples/HelloWorld/project.json new file mode 100644 index 0000000000..63b0cd1595 --- /dev/null +++ b/samples/HelloWorld/project.json @@ -0,0 +1,10 @@ +{ + "version" : "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Server.WebListener" : "" + }, + "configurations": { + "net45": { }, + "k10" : { } + } +} diff --git a/samples/SelfHostServer/App.config b/samples/SelfHostServer/App.config new file mode 100644 index 0000000000..73c9881fb6 --- /dev/null +++ b/samples/SelfHostServer/App.config @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/samples/SelfHostServer/Program.cs b/samples/SelfHostServer/Program.cs new file mode 100644 index 0000000000..f01efdd682 --- /dev/null +++ b/samples/SelfHostServer/Program.cs @@ -0,0 +1,138 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Owin; +using Microsoft.AspNet.Server.WebListener; +using Microsoft.Owin.Hosting; +using Owin; + +namespace SelfHostServer +{ + // http://owin.org/extensions/owin-WebSocket-Extension-v0.4.0.htm + using WebSocketAccept = Action, // options + Func, Task>>; // callback + using WebSocketCloseAsync = + Func; + using WebSocketReceiveAsync = + Func /* data */, + CancellationToken /* cancel */, + Task>>; + using WebSocketReceiveResult = Tuple; // count + using WebSocketSendAsync = + Func /* data */, + int /* messageType */, + bool /* endOfMessage */, + CancellationToken /* cancel */, + Task>; + + public class Program + { + private static byte[] Data = new byte[1024]; + + public static void Main(string[] args) + { + using (WebApp.Start(new StartOptions( + // "http://localhost:5000/" + "https://localhost:9090/" + ) + { + ServerFactory = "Microsoft.AspNet.Server.WebListener" + })) + { + Console.WriteLine("Running, press any key to exit"); + // System.Diagnostics.Process.Start("http://localhost:5000/"); + Console.ReadKey(); + } + } + + public void Configuration(IAppBuilder app) + { + OwinWebListener listener = (OwinWebListener)app.Properties["Microsoft.AspNet.Server.WebListener.OwinWebListener"]; + listener.AuthenticationManager.AuthenticationTypes = + AuthenticationType.Basic | + AuthenticationType.Digest | + AuthenticationType.Negotiate | + AuthenticationType.Ntlm | + AuthenticationType.Kerberos; + + app.Use((context, next) => + { + Console.WriteLine("Request: " + context.Request.Uri); + return next(); + }); + app.Use((context, next) => + { + if (context.Request.User == null) + { + context.Response.StatusCode = 401; + return Task.FromResult(0); + } + else + { + Console.WriteLine(context.Request.User.Identity.AuthenticationType); + } + return next(); + }); + app.UseWebSockets(); + app.Use(UpgradeToWebSockets); + app.Run(Invoke); + } + + public Task Invoke(IOwinContext context) + { + context.Response.ContentLength = Data.Length; + return context.Response.WriteAsync(Data); + } + + // Run once per request + private Task UpgradeToWebSockets(IOwinContext context, Func next) + { + WebSocketAccept accept = context.Get("websocket.Accept"); + if (accept == null) + { + // Not a websocket request + return next(); + } + + accept(null, WebSocketEcho); + + return Task.FromResult(null); + } + + private async Task WebSocketEcho(IDictionary websocketContext) + { + var sendAsync = (WebSocketSendAsync)websocketContext["websocket.SendAsync"]; + var receiveAsync = (WebSocketReceiveAsync)websocketContext["websocket.ReceiveAsync"]; + var closeAsync = (WebSocketCloseAsync)websocketContext["websocket.CloseAsync"]; + var callCancelled = (CancellationToken)websocketContext["websocket.CallCancelled"]; + + byte[] buffer = new byte[1024]; + WebSocketReceiveResult received = await receiveAsync(new ArraySegment(buffer), callCancelled); + + object status; + while (!websocketContext.TryGetValue("websocket.ClientCloseStatus", out status) || (int)status == 0) + { + // Echo anything we receive + await sendAsync(new ArraySegment(buffer, 0, received.Item3), received.Item1, received.Item2, callCancelled); + + received = await receiveAsync(new ArraySegment(buffer), callCancelled); + } + + await closeAsync((int)websocketContext["websocket.ClientCloseStatus"], (string)websocketContext["websocket.ClientCloseDescription"], callCancelled); + } + } +} diff --git a/samples/SelfHostServer/Properties/AssemblyInfo.cs b/samples/SelfHostServer/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..86b415a1b3 --- /dev/null +++ b/samples/SelfHostServer/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("SelfHostServer")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("SelfHostServer")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("ec50ddb4-9ec6-4cbd-96ac-15de948247cc")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/samples/SelfHostServer/Public/1kb.txt b/samples/SelfHostServer/Public/1kb.txt new file mode 100644 index 0000000000..1d43866603 --- /dev/null +++ b/samples/SelfHostServer/Public/1kb.txt @@ -0,0 +1 @@ +asdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfqweruoiasdfnsngdfioenrglknsgilhasdgha;gu;agnaknusgnjkadfgnknjksdfk asdhfhasdf nklasdgnasg njagnjasdfasdfasdfasdfasdfasd \ No newline at end of file diff --git a/samples/SelfHostServer/packages.config b/samples/SelfHostServer/packages.config new file mode 100644 index 0000000000..b452be1749 --- /dev/null +++ b/samples/SelfHostServer/packages.config @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/samples/SelfHostServer/project.json b/samples/SelfHostServer/project.json new file mode 100644 index 0000000000..ed203d2aeb --- /dev/null +++ b/samples/SelfHostServer/project.json @@ -0,0 +1,19 @@ +{ + "version" : "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Server.WebListener" : "", + "Microsoft.AspNet.WebSockets" : "" + }, + "configurations": { + "net45": { + "dependencies": { + "Owin": "1.0", + "Microsoft.Owin": "2.1.0", + "Microsoft.Owin.Diagnostics": "2.1.0", + "Microsoft.Owin.Hosting": "2.1.0", + "Microsoft.Owin.Host.HttpListener": "2.1.0", + "Microsoft.AspNet.AppBuilderSupport": "0.1-alpha-*" + } + } + } +} diff --git a/samples/TestClient/App.config b/samples/TestClient/App.config new file mode 100644 index 0000000000..8e15646352 --- /dev/null +++ b/samples/TestClient/App.config @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/samples/TestClient/Program.cs b/samples/TestClient/Program.cs new file mode 100644 index 0000000000..b1f93759aa --- /dev/null +++ b/samples/TestClient/Program.cs @@ -0,0 +1,78 @@ +using System; +using System.Net; +using System.Net.Http; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace TestClient +{ + public class Program + { + private const string Address = + // "http://localhost:5000/public/1kb.txt"; + "https://localhost:9090/public/1kb.txt"; + + public static void Main(string[] args) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.ServerCertificateValidationCallback = (_, __, ___, ____) => true; + // handler.UseDefaultCredentials = true; + handler.Credentials = new NetworkCredential(@"redmond\chrross", "passwird"); + HttpClient client = new HttpClient(handler); + + /* + int completionCount = 0; + int itterations = 30000; + for (int i = 0; i < itterations; i++) + { + client.GetAsync(Address) + .ContinueWith(t => Interlocked.Increment(ref completionCount)); + } + + while (completionCount < itterations) + { + Thread.Sleep(10); + }*/ + + while (true) + { + Console.WriteLine("Press any key to send request"); + Console.ReadKey(); + var result = client.GetAsync(Address).Result; + Console.WriteLine(result); + } + + // RunWebSocketClient().Wait(); + Console.WriteLine("Done"); + Console.ReadKey(); + } + + public static async Task RunWebSocketClient() + { + ClientWebSocket websocket = new ClientWebSocket(); + + string url = "ws://localhost:5000/"; + Console.WriteLine("Connecting to: " + url); + await websocket.ConnectAsync(new Uri(url), CancellationToken.None); + + string message = "Hello World"; + Console.WriteLine("Sending message: " + message); + byte[] messageBytes = Encoding.UTF8.GetBytes(message); + await websocket.SendAsync(new ArraySegment(messageBytes), WebSocketMessageType.Text, true, CancellationToken.None); + + byte[] incomingData = new byte[1024]; + WebSocketReceiveResult result = await websocket.ReceiveAsync(new ArraySegment(incomingData), CancellationToken.None); + + if (result.CloseStatus.HasValue) + { + Console.WriteLine("Closed; Status: " + result.CloseStatus + ", " + result.CloseStatusDescription); + } + else + { + Console.WriteLine("Received message: " + Encoding.UTF8.GetString(incomingData, 0, result.Count)); + } + } + } +} diff --git a/samples/TestClient/Properties/AssemblyInfo.cs b/samples/TestClient/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..6919b6d3ce --- /dev/null +++ b/samples/TestClient/Properties/AssemblyInfo.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("TestClient")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("TestClient")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("8db62eb3-48c0-4049-b33e-271c738140a0")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/samples/TestClient/TestClient.csproj b/samples/TestClient/TestClient.csproj new file mode 100644 index 0000000000..eb38a21555 --- /dev/null +++ b/samples/TestClient/TestClient.csproj @@ -0,0 +1,62 @@ + + + + + Debug + AnyCPU + {8B828433-B333-4C19-96AE-00BFFF9D8841} + Exe + Properties + TestClient + TestClient + v4.5 + 512 + ..\..\ + true + + + AnyCPU + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + + + AnyCPU + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/Microsoft.AspNet.Security.Windows/AuthTypes.cs b/src/Microsoft.AspNet.Security.Windows/AuthTypes.cs new file mode 100644 index 0000000000..90a063b5fa --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/AuthTypes.cs @@ -0,0 +1,37 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Security.Windows +{ + /// + /// Types of Windows Authentication supported. + /// + [Flags] + public enum AuthTypes + { + /// + /// Default + /// + None = 0, + + /// + /// Digest authentication using Windows credentials + /// + Digest = 1, + + /// + /// Negotiates Kerberos or NTLM + /// + Negotiate = 2, + + /// + /// NTLM Windows authentication + /// + Ntlm = 4, + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/ComNetOS.cs b/src/Microsoft.AspNet.Security.Windows/ComNetOS.cs new file mode 100644 index 0000000000..ca5902dbd6 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/ComNetOS.cs @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Diagnostics; +using System.Runtime.Versioning; +using System.Security.Permissions; + +namespace Microsoft.AspNet.Security.Windows +{ + internal static class ComNetOS + { + // Minimum support for Windows 2008 is assumed. + internal static readonly bool IsWin7orLater; // Is Windows 7 or later + internal static readonly bool IsWin8orLater; // Is Windows 8 or later + + // We use it safe so assert + [EnvironmentPermission(SecurityAction.Assert, Unrestricted = true)] + [ResourceExposure(ResourceScope.None)] + [ResourceConsumption(ResourceScope.AppDomain, ResourceScope.AppDomain)] + static ComNetOS() + { + OperatingSystem operatingSystem = Environment.OSVersion; + + GlobalLog.Print("ComNetOS::.ctor(): " + operatingSystem.ToString()); + + Debug.Assert(operatingSystem.Platform != PlatformID.Win32Windows, "Windows 9x is not supported"); + + var Win7Version = new Version(6, 1); + var Win8Version = new Version(6, 2); + IsWin7orLater = (operatingSystem.Version >= Win7Version); + IsWin8orLater = (operatingSystem.Version >= Win8Version); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Constants.cs b/src/Microsoft.AspNet.Security.Windows/Constants.cs new file mode 100644 index 0000000000..28e47e83fc --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Constants.cs @@ -0,0 +1,44 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + internal static class Constants + { + internal const string VersionKey = "owin.Version"; + internal const string OwinVersion = "1.0"; + internal const string CallCancelledKey = "owin.CallCancelled"; + + internal const string RequestBodyKey = "owin.RequestBody"; + internal const string RequestHeadersKey = "owin.RequestHeaders"; + internal const string RequestSchemeKey = "owin.RequestScheme"; + internal const string RequestMethodKey = "owin.RequestMethod"; + internal const string RequestPathBaseKey = "owin.RequestPathBase"; + internal const string RequestPathKey = "owin.RequestPath"; + internal const string RequestQueryStringKey = "owin.RequestQueryString"; + internal const string HttpRequestProtocolKey = "owin.RequestProtocol"; + + internal const string HttpResponseProtocolKey = "owin.ResponseProtocol"; + internal const string ResponseStatusCodeKey = "owin.ResponseStatusCode"; + internal const string ResponseReasonPhraseKey = "owin.ResponseReasonPhrase"; + internal const string ResponseHeadersKey = "owin.ResponseHeaders"; + internal const string ResponseBodyKey = "owin.ResponseBody"; + + internal const string ClientCertifiateKey = "ssl.ClientCertificate"; + internal const string SslSpnKey = "ssl.Spn"; + internal const string SslChannelBindingKey = "ssl.ChannelBinding"; + + internal const string RemoteIpAddressKey = "server.RemoteIpAddress"; + internal const string RemotePortKey = "server.RemotePort"; + internal const string LocalIpAddressKey = "server.LocalIpAddress"; + internal const string LocalPortKey = "server.LocalPort"; + internal const string IsLocalKey = "server.IsLocal"; + internal const string ServerOnSendingHeadersKey = "server.OnSendingHeaders"; + internal const string ServerUserKey = "server.User"; + internal const string ServerConnectionIdKey = "server.ConnectionId"; + internal const string ServerConnectionDisconnectKey = "server.ConnectionDisconnect"; + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/DictionaryExtensions.cs b/src/Microsoft.AspNet.Security.Windows/DictionaryExtensions.cs new file mode 100644 index 0000000000..2d1b86ad04 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/DictionaryExtensions.cs @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Linq; +using System.Text; + +namespace System.Collections.Generic +{ + internal static class DictionaryExtensions + { + internal static void Append(this IDictionary dictionary, string key, string value) + { + string[] orriginalValues; + if (dictionary.TryGetValue(key, out orriginalValues)) + { + string[] newValues = new string[orriginalValues.Length + 1]; + orriginalValues.CopyTo(newValues, 0); + newValues[newValues.Length - 1] = value; + dictionary[key] = newValues; + } + else + { + dictionary[key] = new string[] { value }; + } + } + + internal static void Append(this IDictionary dictionary, string key, IList values) + { + string[] orriginalValues; + if (dictionary.TryGetValue(key, out orriginalValues)) + { + string[] newValues = new string[orriginalValues.Length + values.Count]; + orriginalValues.CopyTo(newValues, 0); + values.CopyTo(newValues, orriginalValues.Length); + dictionary[key] = newValues; + } + else + { + dictionary[key] = values.ToArray(); + } + } + + internal static string Get(this IDictionary dictionary, string key) + { + string[] values; + if (dictionary.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + internal static T Get(this IDictionary dictionary, string key, T fallback = default(T)) + { + object values; + if (dictionary.TryGetValue(key, out values)) + { + return (T)values; + } + return fallback; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/DigestCache.cs b/src/Microsoft.AspNet.Security.Windows/DigestCache.cs new file mode 100644 index 0000000000..0de3d84475 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/DigestCache.cs @@ -0,0 +1,153 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + // Saves generated digest challenges so that they are still valid when the authenticated request arrives. + internal class DigestCache : IDisposable + { + private const int DigestLifetimeSeconds = 300; + private const int MaximumDigests = 1024; // Must be a power of two. + private const int MinimumDigestLifetimeSeconds = 10; + + private DigestContext[] _savedDigests; + private ArrayList _extraSavedDigests; + private ArrayList _extraSavedDigestsBaking; + private int _extraSavedDigestsTimestamp; + private int _newestContext; + private int _oldestContext; + + internal DigestCache() + { + } + + internal void SaveDigestContext(NTAuthentication digestContext) + { + if (_savedDigests == null) + { + Interlocked.CompareExchange(ref _savedDigests, new DigestContext[MaximumDigests], null); + } + + // We want to actually close the contexts outside the lock. + NTAuthentication oldContext = null; + ArrayList digestsToClose = null; + lock (_savedDigests) + { + int now = ((now = Environment.TickCount) == 0 ? 1 : now); + + _newestContext = (_newestContext + 1) & (MaximumDigests - 1); + + int oldTimestamp = _savedDigests[_newestContext].timestamp; + oldContext = _savedDigests[_newestContext].context; + _savedDigests[_newestContext].timestamp = now; + _savedDigests[_newestContext].context = digestContext; + + // May need to move this up. + if (_oldestContext == _newestContext) + { + _oldestContext = (_newestContext + 1) & (MaximumDigests - 1); + } + + // Delete additional contexts older than five minutes. + while (unchecked(now - _savedDigests[_oldestContext].timestamp) >= DigestLifetimeSeconds && _savedDigests[_oldestContext].context != null) + { + if (digestsToClose == null) + { + digestsToClose = new ArrayList(); + } + digestsToClose.Add(_savedDigests[_oldestContext].context); + _savedDigests[_oldestContext].context = null; + _oldestContext = (_oldestContext + 1) & (MaximumDigests - 1); + } + + // If the old context is younger than 10 seconds, put it in the backup pile. + if (oldContext != null && unchecked(now - oldTimestamp) <= MinimumDigestLifetimeSeconds * 1000) + { + // Use a two-tier ArrayList system to guarantee each entry lives at least 10 seconds. + if (_extraSavedDigests == null || + unchecked(now - _extraSavedDigestsTimestamp) > MinimumDigestLifetimeSeconds * 1000) + { + digestsToClose = _extraSavedDigestsBaking; + _extraSavedDigestsBaking = _extraSavedDigests; + _extraSavedDigestsTimestamp = now; + _extraSavedDigests = new ArrayList(); + } + _extraSavedDigests.Add(oldContext); + oldContext = null; + } + } + + if (oldContext != null) + { + oldContext.CloseContext(); + } + if (digestsToClose != null) + { + for (int i = 0; i < digestsToClose.Count; i++) + { + ((NTAuthentication)digestsToClose[i]).CloseContext(); + } + } + } + + private void ClearDigestCache() + { + if (_savedDigests == null) + { + return; + } + + ArrayList[] toClose = new ArrayList[3]; + lock (_savedDigests) + { + toClose[0] = _extraSavedDigestsBaking; + _extraSavedDigestsBaking = null; + toClose[1] = _extraSavedDigests; + _extraSavedDigests = null; + + _newestContext = 0; + _oldestContext = 0; + + toClose[2] = new ArrayList(); + for (int i = 0; i < MaximumDigests; i++) + { + if (_savedDigests[i].context != null) + { + toClose[2].Add(_savedDigests[i].context); + _savedDigests[i].context = null; + } + _savedDigests[i].timestamp = 0; + } + } + + for (int j = 0; j < toClose.Length; j++) + { + if (toClose[j] != null) + { + for (int k = 0; k < toClose[j].Count; k++) + { + ((NTAuthentication)toClose[j][k]).CloseContext(); + } + } + } + } + + public void Dispose() + { + ClearDigestCache(); + } + + private struct DigestContext + { + internal NTAuthentication context; + internal int timestamp; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/DisconnectAsyncResult.cs b/src/Microsoft.AspNet.Security.Windows/DisconnectAsyncResult.cs new file mode 100644 index 0000000000..6064872060 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/DisconnectAsyncResult.cs @@ -0,0 +1,93 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Security.Principal; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + // Keeps NTLM/Negotiate auth contexts alive until the connection is broken. + internal class DisconnectAsyncResult + { + private const string NTLM = "NTLM"; + + private object _connectionId; + private WindowsAuthMiddleware _winAuth; + private CancellationTokenRegistration _disconnectRegistration; + + private WindowsPrincipal _authenticatedUser; + private NTAuthentication _session; + + internal DisconnectAsyncResult(WindowsAuthMiddleware winAuth, object connectionId, CancellationToken connectionDisconnect) + { + GlobalLog.Print("DisconnectAsyncResult#" + ValidationHelper.HashString(this) + "::.ctor() httpListener#" + ValidationHelper.HashString(winAuth) + " connectionId:" + connectionId); + _winAuth = winAuth; + _connectionId = connectionId; + _winAuth.DisconnectResults[_connectionId] = this; + + // Register with a connection specific CancellationToken. Without this notification, the contexts will leak indefinitely. + // Alternatively we could attempt some kind of LRU storage, but this will either have to be larger than your expected connection limit, + // or will fail at unexpected moments under stress. + try + { + _disconnectRegistration = connectionDisconnect.Register(HandleDisconnect); + } + catch (ObjectDisposedException) + { + _winAuth.DisconnectResults.Remove(_connectionId); + } + } + + internal WindowsPrincipal AuthenticatedUser + { + get + { + return _authenticatedUser; + } + set + { + // The previous value can't be disposed because it may be in use by the app. + _authenticatedUser = value; + } + } + + internal NTAuthentication Session + { + get + { + return _session; + } + set + { + _session = value; + } + } + + private void HandleDisconnect() + { + GlobalLog.Print("DisconnectAsyncResult#" + ValidationHelper.HashString(this) + "::HandleDisconnect() DisconnectResults#" + ValidationHelper.HashString(_winAuth.DisconnectResults) + " removing for m_ConnectionId:" + _connectionId); + _winAuth.DisconnectResults.Remove(_connectionId); + if (_session != null) + { + _session.CloseContext(); + } + + // Clean up the identity. This is for scenarios where identity was not cleaned up before due to + // identity caching for unsafe ntlm authentication + + IDisposable identity = _authenticatedUser == null ? null : _authenticatedUser.Identity as IDisposable; + if ((identity != null) && + (NTLM.Equals(_authenticatedUser.Identity.AuthenticationType, StringComparison.OrdinalIgnoreCase)) && + (_winAuth.UnsafeConnectionNtlmAuthentication)) + { + identity.Dispose(); + } + + _disconnectRegistration.Dispose(); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/HeaderEncoding.cs b/src/Microsoft.AspNet.Security.Windows/HeaderEncoding.cs new file mode 100644 index 0000000000..7c41535d5e --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/HeaderEncoding.cs @@ -0,0 +1,139 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Text; + +namespace Microsoft.AspNet.Security.Windows +{ + // we use this static class as a helper class to encode/decode HTTP headers. + // what we need is a 1-1 correspondence between a char in the range U+0000-U+00FF + // and a byte in the range 0x00-0xFF (which is the range that can hit the network). + // The Latin-1 encoding (ISO-88591-1) (GetEncoding(28591)) works for byte[] to string, but is a little slow. + // It doesn't work for string -> byte[] because of best-fit-mapping problems. + internal static class HeaderEncoding + { + internal static unsafe string GetString(byte[] bytes, int byteIndex, int byteCount) + { + fixed (byte* pBytes = bytes) + return GetString(pBytes + byteIndex, byteCount); + } + + internal static unsafe string GetString(byte* pBytes, int byteCount) + { + if (byteCount < 1) + { + return string.Empty; + } + + string s = new String('\0', byteCount); + + fixed (char* pStr = s) + { + char* pString = pStr; + while (byteCount >= 8) + { + pString[0] = (char)pBytes[0]; + pString[1] = (char)pBytes[1]; + pString[2] = (char)pBytes[2]; + pString[3] = (char)pBytes[3]; + pString[4] = (char)pBytes[4]; + pString[5] = (char)pBytes[5]; + pString[6] = (char)pBytes[6]; + pString[7] = (char)pBytes[7]; + pString += 8; + pBytes += 8; + byteCount -= 8; + } + for (int i = 0; i < byteCount; i++) + { + pString[i] = (char)pBytes[i]; + } + } + + return s; + } + + internal static int GetByteCount(string myString) + { + return myString.Length; + } + internal static unsafe void GetBytes(string myString, int charIndex, int charCount, byte[] bytes, int byteIndex) + { + if (myString.Length == 0) + { + return; + } + fixed (byte* bufferPointer = bytes) + { + byte* newBufferPointer = bufferPointer + byteIndex; + int finalIndex = charIndex + charCount; + while (charIndex < finalIndex) + { + *newBufferPointer++ = (byte)myString[charIndex++]; + } + } + } + internal static unsafe byte[] GetBytes(string myString) + { + byte[] bytes = new byte[myString.Length]; + if (myString.Length != 0) + { + GetBytes(myString, 0, myString.Length, bytes, 0); + } + return bytes; + } + + // The normal client header parser just casts bytes to chars (see GetString). + // Check if those bytes were actually utf-8 instead of ASCII. + // If not, just return the input value. + + internal static string DecodeUtf8FromString(string input) + { + if (string.IsNullOrWhiteSpace(input)) + { + return input; + } + + bool possibleUtf8 = false; + for (int i = 0; i < input.Length; i++) + { + if (input[i] > (char)255) + { + return input; // This couldn't have come from the wire, someone assigned it directly. + } + else if (input[i] > (char)127) + { + possibleUtf8 = true; + break; + } + } + if (possibleUtf8) + { + byte[] rawBytes = new byte[input.Length]; + for (int i = 0; i < input.Length; i++) + { + if (input[i] > (char)255) + { + return input; // This couldn't have come from the wire, someone assigned it directly. + } + rawBytes[i] = (byte)input[i]; + } + try + { + // We don't want '?' replacement characters, just fail. + Encoding decoder = Encoding.GetEncoding("utf-8", EncoderFallback.ExceptionFallback, + DecoderFallback.ExceptionFallback); + return decoder.GetString(rawBytes); + } + catch (ArgumentException) + { + } // Not actually Utf-8 + } + return input; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/HttpKnownHeaderNames.cs b/src/Microsoft.AspNet.Security.Windows/HttpKnownHeaderNames.cs new file mode 100644 index 0000000000..29ab395d03 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/HttpKnownHeaderNames.cs @@ -0,0 +1,75 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + internal static class HttpKnownHeaderNames + { + public const string CacheControl = "Cache-Control"; + public const string Connection = "Connection"; + public const string Date = "Date"; + public const string KeepAlive = "Keep-Alive"; + public const string Pragma = "Pragma"; + public const string ProxyConnection = "Proxy-Connection"; + public const string Trailer = "Trailer"; + public const string TransferEncoding = "Transfer-Encoding"; + public const string Upgrade = "Upgrade"; + public const string Via = "Via"; + public const string Warning = "Warning"; + public const string ContentLength = "Content-Length"; + public const string ContentType = "Content-Type"; + public const string ContentDisposition = "Content-Disposition"; + public const string ContentEncoding = "Content-Encoding"; + public const string ContentLanguage = "Content-Language"; + public const string ContentLocation = "Content-Location"; + public const string ContentRange = "Content-Range"; + public const string Expires = "Expires"; + public const string LastModified = "Last-Modified"; + public const string Age = "Age"; + public const string Location = "Location"; + public const string ProxyAuthenticate = "Proxy-Authenticate"; + public const string RetryAfter = "Retry-After"; + public const string Server = "Server"; + public const string SetCookie = "Set-Cookie"; + public const string SetCookie2 = "Set-Cookie2"; + public const string Vary = "Vary"; + public const string WWWAuthenticate = "WWW-Authenticate"; + public const string Accept = "Accept"; + public const string AcceptCharset = "Accept-Charset"; + public const string AcceptEncoding = "Accept-Encoding"; + public const string AcceptLanguage = "Accept-Language"; + public const string Authorization = "Authorization"; + public const string Cookie = "Cookie"; + public const string Cookie2 = "Cookie2"; + public const string Expect = "Expect"; + public const string From = "From"; + public const string Host = "Host"; + public const string IfMatch = "If-Match"; + public const string IfModifiedSince = "If-Modified-Since"; + public const string IfNoneMatch = "If-None-Match"; + public const string IfRange = "If-Range"; + public const string IfUnmodifiedSince = "If-Unmodified-Since"; + public const string MaxForwards = "Max-Forwards"; + public const string ProxyAuthorization = "Proxy-Authorization"; + public const string Referer = "Referer"; + public const string Range = "Range"; + public const string UserAgent = "User-Agent"; + public const string ContentMD5 = "Content-MD5"; + public const string ETag = "ETag"; + public const string TE = "TE"; + public const string Allow = "Allow"; + public const string AcceptRanges = "Accept-Ranges"; + public const string P3P = "P3P"; + public const string XPoweredBy = "X-Powered-By"; + public const string XAspNetVersion = "X-AspNet-Version"; + public const string SecWebSocketKey = "Sec-WebSocket-Key"; + public const string SecWebSocketExtensions = "Sec-WebSocket-Extensions"; + public const string SecWebSocketAccept = "Sec-WebSocket-Accept"; + public const string Origin = "Origin"; + public const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; + public const string SecWebSocketVersion = "Sec-WebSocket-Version"; + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/HttpStatusCode.cs b/src/Microsoft.AspNet.Security.Windows/HttpStatusCode.cs new file mode 100644 index 0000000000..d60d640f2d --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/HttpStatusCode.cs @@ -0,0 +1,314 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Security.Windows +{ + // Redirect Status code numbers that need to be defined. + + /// + /// Contains the values of status + /// codes defined for the HTTP protocol. + /// + // UEUE : Any int can be cast to a HttpStatusCode to allow checking for non http1.1 codes. + internal enum HttpStatusCode + { + // Informational 1xx + + /// + /// [To be supplied.] + /// + Continue = 100, + + /// + /// [To be supplied.] + /// + SwitchingProtocols = 101, + + // Successful 2xx + + /// + /// [To be supplied.] + /// + OK = 200, + + /// + /// [To be supplied.] + /// + Created = 201, + + /// + /// [To be supplied.] + /// + Accepted = 202, + + /// + /// [To be supplied.] + /// + NonAuthoritativeInformation = 203, + + /// + /// [To be supplied.] + /// + NoContent = 204, + + /// + /// [To be supplied.] + /// + ResetContent = 205, + + /// + /// [To be supplied.] + /// + PartialContent = 206, + + // Redirection 3xx + + /// + /// [To be supplied.] + /// + MultipleChoices = 300, + + /// + /// [To be supplied.] + /// + Ambiguous = 300, + + /// + /// [To be supplied.] + /// + MovedPermanently = 301, + + /// + /// [To be supplied.] + /// + Moved = 301, + + /// + /// [To be supplied.] + /// + Found = 302, + + /// + /// [To be supplied.] + /// + Redirect = 302, + + /// + /// [To be supplied.] + /// + SeeOther = 303, + + /// + /// [To be supplied.] + /// + RedirectMethod = 303, + + /// + /// [To be supplied.] + /// + NotModified = 304, + + /// + /// [To be supplied.] + /// + UseProxy = 305, + + /// + /// [To be supplied.] + /// + Unused = 306, + + /// + /// [To be supplied.] + /// + TemporaryRedirect = 307, + + /// + /// [To be supplied.] + /// + RedirectKeepVerb = 307, + + // Client Error 4xx + + /// + /// [To be supplied.] + /// + BadRequest = 400, + + /// + /// [To be supplied.] + /// + Unauthorized = 401, + + /// + /// [To be supplied.] + /// + PaymentRequired = 402, + + /// + /// [To be supplied.] + /// + Forbidden = 403, + + /// + /// [To be supplied.] + /// + NotFound = 404, + + /// + /// [To be supplied.] + /// + MethodNotAllowed = 405, + + /// + /// [To be supplied.] + /// + NotAcceptable = 406, + + /// + /// [To be supplied.] + /// + ProxyAuthenticationRequired = 407, + + /// + /// [To be supplied.] + /// + RequestTimeout = 408, + + /// + /// [To be supplied.] + /// + Conflict = 409, + + /// + /// [To be supplied.] + /// + Gone = 410, + + /// + /// [To be supplied.] + /// + LengthRequired = 411, + + /// + /// [To be supplied.] + /// + PreconditionFailed = 412, + + /// + /// [To be supplied.] + /// + RequestEntityTooLarge = 413, + + /// + /// [To be supplied.] + /// + RequestUriTooLong = 414, + + /// + /// [To be supplied.] + /// + UnsupportedMediaType = 415, + + /// + /// [To be supplied.] + /// + RequestedRangeNotSatisfiable = 416, + + /// + /// [To be supplied.] + /// + ExpectationFailed = 417, + + UpgradeRequired = 426, + + // Server Error 5xx + + /// + /// [To be supplied.] + /// + InternalServerError = 500, + + /// + /// [To be supplied.] + /// + NotImplemented = 501, + + /// + /// [To be supplied.] + /// + BadGateway = 502, + + /// + /// [To be supplied.] + /// + ServiceUnavailable = 503, + + /// + /// [To be supplied.] + /// + GatewayTimeout = 504, + + /// + /// [To be supplied.] + /// + HttpVersionNotSupported = 505, + } // enum HttpStatusCode + +/* +Fielding, et al. Standards Track [Page 3] + +RFC 2616 HTTP/1.1 June 1999 + + + 10.1 Informational 1xx ...........................................57 + 10.1.1 100 Continue .............................................58 + 10.1.2 101 Switching Protocols ..................................58 + 10.2 Successful 2xx ..............................................58 + 10.2.1 200 OK ...................................................58 + 10.2.2 201 Created ..............................................59 + 10.2.3 202 Accepted .............................................59 + 10.2.4 203 Non-Authoritative Information ........................59 + 10.2.5 204 No Content ...........................................60 + 10.2.6 205 Reset Content ........................................60 + 10.2.7 206 Partial Content ......................................60 + 10.3 Redirection 3xx .............................................61 + 10.3.1 300 Multiple Choices .....................................61 + 10.3.2 301 Moved Permanently ....................................62 + 10.3.3 302 Found ................................................62 + 10.3.4 303 See Other ............................................63 + 10.3.5 304 Not Modified .........................................63 + 10.3.6 305 Use Proxy ............................................64 + 10.3.7 306 (Unused) .............................................64 + 10.3.8 307 Temporary Redirect ...................................65 + 10.4 Client Error 4xx ............................................65 + 10.4.1 400 Bad Request .........................................65 + 10.4.2 401 Unauthorized ........................................66 + 10.4.3 402 Payment Required ....................................66 + 10.4.4 403 Forbidden ...........................................66 + 10.4.5 404 Not Found ...........................................66 + 10.4.6 405 Method Not Allowed ..................................66 + 10.4.7 406 Not Acceptable ......................................67 + 10.4.8 407 Proxy Authentication Required .......................67 + 10.4.9 408 Request Timeout .....................................67 + 10.4.10 409 Conflict ............................................67 + 10.4.11 410 Gone ................................................68 + 10.4.12 411 Length Required .....................................68 + 10.4.13 412 Precondition Failed .................................68 + 10.4.14 413 Request Entity Too Large ............................69 + 10.4.15 414 Request-URI Too Long ................................69 + 10.4.16 415 Unsupported Media Type ..............................69 + 10.4.17 416 Requested Range Not Satisfiable .....................69 + 10.4.18 417 Expectation Failed ..................................70 + 10.5 Server Error 5xx ............................................70 + 10.5.1 500 Internal Server Error ................................70 + 10.5.2 501 Not Implemented ......................................70 + 10.5.3 502 Bad Gateway ..........................................70 + 10.5.4 503 Service Unavailable ..................................70 + 10.5.5 504 Gateway Timeout ......................................71 + 10.5.6 505 HTTP Version Not Supported ...........................71 +*/ +} // namespace System.Net diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/CaseInsinsitiveAscii.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/CaseInsinsitiveAscii.cs new file mode 100644 index 0000000000..8818694ff8 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/CaseInsinsitiveAscii.cs @@ -0,0 +1,133 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Collections; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class CaseInsensitiveAscii : IEqualityComparer, IComparer + { + // ASCII char ToLower table + internal static readonly byte[] AsciiToLower = new byte[] + { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, + 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, // 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, + 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, // 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, // 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, // 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, + 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, + 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, + 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, + 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, + 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, + 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, + 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, + 250, 251, 252, 253, 254, 255 + }; + + // ASCII string case insensitive hash function + public int GetHashCode(object myObject) + { + string myString = myObject as string; + if (myObject == null) + { + return 0; + } + int myHashCode = myString.Length; + if (myHashCode == 0) + { + return 0; + } + myHashCode ^= AsciiToLower[(byte)myString[0]] << 24 ^ AsciiToLower[(byte)myString[myHashCode - 1]] << 16; + return myHashCode; + } + + // ASCII string case insensitive comparer + public int Compare(object firstObject, object secondObject) + { + string firstString = firstObject as string; + string secondString = secondObject as string; + if (firstString == null) + { + return secondString == null ? 0 : -1; + } + if (secondString == null) + { + return 1; + } + int result = firstString.Length - secondString.Length; + int comparisons = result > 0 ? secondString.Length : firstString.Length; + int difference, index = 0; + while (index < comparisons) + { + difference = (int)(AsciiToLower[firstString[index]] - AsciiToLower[secondString[index]]); + if (difference != 0) + { + result = difference; + break; + } + index++; + } + return result; + } + + // ASCII string case insensitive hash function + private int FastGetHashCode(string myString) + { + int myHashCode = myString.Length; + if (myHashCode != 0) + { + myHashCode ^= AsciiToLower[(byte)myString[0]] << 24 ^ AsciiToLower[(byte)myString[myHashCode - 1]] << 16; + } + return myHashCode; + } + + // ASCII string case insensitive comparer + public new bool Equals(object firstObject, object secondObject) + { + string firstString = firstObject as string; + string secondString = secondObject as string; + if (firstString == null) + { + return secondString == null; + } + if (secondString != null) + { + int index = firstString.Length; + if (index == secondString.Length) + { + if (FastGetHashCode(firstString) == FastGetHashCode(secondString)) + { + int comparisons = firstString.Length; + while (index > 0) + { + index--; + if (AsciiToLower[firstString[index]] != AsciiToLower[secondString[index]]) + { + return false; + } + } + return true; + } + } + } + return false; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/GlobalLog.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/GlobalLog.cs new file mode 100644 index 0000000000..259dc13382 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/GlobalLog.cs @@ -0,0 +1,689 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Configuration; + using System.Diagnostics; + using System.Globalization; + using System.Net; + using System.Runtime.ConstrainedExecution; + using System.Security.Permissions; + using System.Threading; + + /// + /// + /// + internal static class GlobalLog + { + // Logging Initalization - I need to disable Logging code, and limit + // the effect it has when it is dissabled, so I use a bool here. + // + // This can only be set when the logging code is built and enabled. + // By specifing the "CSC_DEFINES=/D:TRAVE" in the build environment, + // this code will be built and then checks against an enviroment variable + // and a BooleanSwitch to see if any of the two have enabled logging. + + private static BaseLoggingObject Logobject = GlobalLog.LoggingInitialize(); +#if TRAVE + internal static LocalDataStoreSlot s_ThreadIdSlot; + internal static bool s_UseThreadId; + internal static bool s_UseTimeSpan; + internal static bool s_DumpWebData; + internal static bool s_UsePerfCounter; + internal static bool s_DebugCallNesting; + internal static bool s_DumpToConsole; + internal static int s_MaxDumpSize; + internal static string s_RootDirectory; + + // + // Logging Config Variables - below are list of consts that can be used to config + // the logging, + // + + // Max number of lines written into a buffer, before a save is invoked + // s_DumpToConsole disables. + public const int MaxLinesBeforeSave = 0; + +#endif + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + private static BaseLoggingObject LoggingInitialize() + { +#if DEBUG + if (GetSwitchValue("SystemNetLogging", "System.Net logging module", false) && + GetSwitchValue("SystemNetLog_ConnectionMonitor", "System.Net connection monitor thread", false)) + { + InitConnectionMonitor(); + } +#endif // DEBUG +#if TRAVE + // by default we'll log to c:\temp\ so that non interactive services (like w3wp.exe) that don't have environment + // variables can easily be debugged, note that the ACLs of the directory might need to be adjusted + if (!GetSwitchValue("SystemNetLog_OverrideDefaults", "System.Net log override default settings", false)) { + s_ThreadIdSlot = Thread.AllocateDataSlot(); + s_UseThreadId = true; + s_UseTimeSpan = true; + s_DumpWebData = true; + s_MaxDumpSize = 256; + s_UsePerfCounter = true; + s_DebugCallNesting = true; + s_DumpToConsole = false; + s_RootDirectory = "C:\\Temp\\"; + return new LoggingObject(); + } + if (GetSwitchValue("SystemNetLogging", "System.Net logging module", false)) { + s_ThreadIdSlot = Thread.AllocateDataSlot(); + s_UseThreadId = GetSwitchValue("SystemNetLog_UseThreadId", "System.Net log display system thread id", false); + s_UseTimeSpan = GetSwitchValue("SystemNetLog_UseTimeSpan", "System.Net log display ticks as TimeSpan", false); + s_DumpWebData = GetSwitchValue("SystemNetLog_DumpWebData", "System.Net log display HTTP send/receive data", false); + s_MaxDumpSize = GetSwitchValue("SystemNetLog_MaxDumpSize", "System.Net log max size of display data", 256); + s_UsePerfCounter = GetSwitchValue("SystemNetLog_UsePerfCounter", "System.Net log use QueryPerformanceCounter() to get ticks ", false); + s_DebugCallNesting = GetSwitchValue("SystemNetLog_DebugCallNesting", "System.Net used to debug call nesting", false); + s_DumpToConsole = GetSwitchValue("SystemNetLog_DumpToConsole", "System.Net log to console", false); + s_RootDirectory = GetSwitchValue("SystemNetLog_RootDirectory", "System.Net root directory of log file", string.Empty); + return new LoggingObject(); + } +#endif // TRAVE + return new BaseLoggingObject(); + } + +#if TRAVE + private static string GetSwitchValue(string switchName, string switchDescription, string defaultValue) { + new EnvironmentPermission(PermissionState.Unrestricted).Assert(); + try { + defaultValue = Environment.GetEnvironmentVariable(switchName); + } + finally { + EnvironmentPermission.RevertAssert(); + } + return defaultValue; + } + + private static int GetSwitchValue(string switchName, string switchDescription, int defaultValue) { + IntegerSwitch theSwitch = new IntegerSwitch(switchName, switchDescription); + if (theSwitch.Enabled) { + return theSwitch.Value; + } + new EnvironmentPermission(PermissionState.Unrestricted).Assert(); + try { + string environmentVar = Environment.GetEnvironmentVariable(switchName); + if (environmentVar!=null) { + defaultValue = Int32.Parse(environmentVar.Trim(), CultureInfo.InvariantCulture); + } + } + finally { + EnvironmentPermission.RevertAssert(); + } + return defaultValue; + } + +#endif + +#if TRAVE || DEBUG + private static bool GetSwitchValue(string switchName, string switchDescription, bool defaultValue) + { + BooleanSwitch theSwitch = new BooleanSwitch(switchName, switchDescription); + new EnvironmentPermission(PermissionState.Unrestricted).Assert(); + try + { + if (theSwitch.Enabled) + { + return true; + } + string environmentVar = Environment.GetEnvironmentVariable(switchName); + defaultValue = environmentVar != null && environmentVar.Trim() == "1"; + } + catch (ConfigurationException) + { + } + finally + { + EnvironmentPermission.RevertAssert(); + } + return defaultValue; + } +#endif // TRAVE || DEBUG + + // Enables thread tracing, detects mis-use of threads. +#if DEBUG + [ThreadStatic] + private static Stack t_ThreadKindStack; + + private static Stack ThreadKindStack + { + get + { + if (t_ThreadKindStack == null) + { + t_ThreadKindStack = new Stack(); + } + return t_ThreadKindStack; + } + } +#endif + + internal static ThreadKinds CurrentThreadKind + { + get + { +#if DEBUG + return ThreadKindStack.Count > 0 ? ThreadKindStack.Peek() : ThreadKinds.Other; +#else + return ThreadKinds.Unknown; +#endif + } + } + + private static bool HasShutdownStarted + { + get + { + return Environment.HasShutdownStarted || AppDomain.CurrentDomain.IsFinalizingForUnload(); + } + } + +#if DEBUG + // ifdef'd instead of conditional since people are forced to handle the return value. + // [Conditional("DEBUG")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static IDisposable SetThreadKind(ThreadKinds kind) + { + if ((kind & ThreadKinds.SourceMask) != ThreadKinds.Unknown) + { + throw new InvalidOperationException(); + } + + // Ignore during shutdown. + if (HasShutdownStarted) + { + return null; + } + + ThreadKinds threadKind = CurrentThreadKind; + ThreadKinds source = threadKind & ThreadKinds.SourceMask; + +#if TRAVE + // Special warnings when doing dangerous things on a thread. + if ((threadKind & ThreadKinds.User) != 0 && (kind & ThreadKinds.System) != 0) + { + Print("WARNING: Thread changed from User to System; user's thread shouldn't be hijacked."); + } + + if ((threadKind & ThreadKinds.Async) != 0 && (kind & ThreadKinds.Sync) != 0) + { + Print("WARNING: Thread changed from Async to Sync, may block an Async thread."); + } + else if ((threadKind & (ThreadKinds.Other | ThreadKinds.CompletionPort)) == 0 && (kind & ThreadKinds.Sync) != 0) + { + Print("WARNING: Thread from a limited resource changed to Sync, may deadlock or bottleneck."); + } +#endif + + ThreadKindStack.Push( + (((kind & ThreadKinds.OwnerMask) == 0 ? threadKind : kind) & ThreadKinds.OwnerMask) | + (((kind & ThreadKinds.SyncMask) == 0 ? threadKind : kind) & ThreadKinds.SyncMask) | + (kind & ~(ThreadKinds.OwnerMask | ThreadKinds.SyncMask)) | + source); + +#if TRAVE + if (CurrentThreadKind != threadKind) + { + Print("Thread becomes:(" + CurrentThreadKind.ToString() + ")"); + } +#endif + + return new ThreadKindFrame(); + } + + private class ThreadKindFrame : IDisposable + { + private int m_FrameNumber; + + internal ThreadKindFrame() + { + m_FrameNumber = ThreadKindStack.Count; + } + + void IDisposable.Dispose() + { + // Ignore during shutdown. + if (GlobalLog.HasShutdownStarted) + { + return; + } + + if (m_FrameNumber != ThreadKindStack.Count) + { + throw new InvalidOperationException(); + } + + ThreadKinds previous = ThreadKindStack.Pop(); + +#if TRAVE + if (CurrentThreadKind != previous) + { + Print("Thread reverts:(" + CurrentThreadKind.ToString() + ")"); + } +#endif + } + } +#endif + + [Conditional("DEBUG")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static void SetThreadSource(ThreadKinds source) + { +#if DEBUG + if ((source & ThreadKinds.SourceMask) != source || source == ThreadKinds.Unknown) + { + throw new ArgumentException("Must specify the thread source.", "source"); + } + + if (ThreadKindStack.Count == 0) + { + ThreadKindStack.Push(source); + return; + } + + if (ThreadKindStack.Count > 1) + { + Print("WARNING: SetThreadSource must be called at the base of the stack, or the stack has been corrupted."); + while (ThreadKindStack.Count > 1) + { + ThreadKindStack.Pop(); + } + } + + if (ThreadKindStack.Peek() != source) + { + // SQL can fail to clean up the stack, leaving the default Other at the bottom. Replace it. + Print("WARNING: The stack has been corrupted."); + ThreadKinds last = ThreadKindStack.Pop() & ThreadKinds.SourceMask; + Assert(last == source || last == ThreadKinds.Other, "Thread source changed.|Was:({0}) Now:({1})", last, source); + ThreadKindStack.Push(source); + } +#endif + } + + [Conditional("DEBUG")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static void ThreadContract(ThreadKinds kind, string errorMsg) + { + ThreadContract(kind, ThreadKinds.SafeSources, errorMsg); + } + + [Conditional("DEBUG")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static void ThreadContract(ThreadKinds kind, ThreadKinds allowedSources, string errorMsg) + { + if ((kind & ThreadKinds.SourceMask) != ThreadKinds.Unknown || (allowedSources & ThreadKinds.SourceMask) != allowedSources) + { + throw new InvalidOperationException(); + } + + ThreadKinds threadKind = CurrentThreadKind; + Assert((threadKind & allowedSources) != 0, errorMsg, "Thread Contract Violation.|Expected source:({0}) Actual source:({1})", allowedSources, threadKind & ThreadKinds.SourceMask); + Assert((threadKind & kind) == kind, errorMsg, "Thread Contract Violation.|Expected kind:({0}) Actual kind:({1})", kind, threadKind & ~ThreadKinds.SourceMask); + } + +#if DEBUG + // Enables auto-hang detection, which will "snap" a log on hang + internal static bool EnableMonitorThread = false; + + // Default value for hang timer +#if FEATURE_PAL // ROTORTODO - after speedups (like real JIT and GC) remove this + public const int DefaultTickValue = 1000*60*5; // 5 minutes +#else + public const int DefaultTickValue = 1000 * 60; // 60 secs +#endif // FEATURE_PAL +#endif // DEBUG + + [System.Diagnostics.Conditional("TRAVE")] + public static void AddToArray(string msg) + { +#if TRAVE + GlobalLog.Logobject.PrintLine(msg); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Ignore(object msg) + { + } + + [System.Diagnostics.Conditional("TRAVE")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + public static void Print(string msg) + { +#if TRAVE + GlobalLog.Logobject.PrintLine(msg); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void PrintHex(string msg, object value) + { +#if TRAVE + GlobalLog.Logobject.PrintLine(msg+TraveHelper.ToHex(value)); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Enter(string func) + { +#if TRAVE + GlobalLog.Logobject.EnterFunc(func + "(*none*)"); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Enter(string func, string parms) + { +#if TRAVE + GlobalLog.Logobject.EnterFunc(func + "(" + parms + ")"); +#endif + } + + [Conditional("DEBUG")] + [Conditional("_FORCE_ASSERTS")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + public static void Assert(bool condition, string messageFormat, params object[] data) + { + if (!condition) + { + string fullMessage = string.Format(CultureInfo.InvariantCulture, messageFormat, data); + int pipeIndex = fullMessage.IndexOf('|'); + if (pipeIndex == -1) + { + Assert(fullMessage); + } + else + { + int detailLength = fullMessage.Length - pipeIndex - 1; + Assert(fullMessage.Substring(0, pipeIndex), detailLength > 0 ? fullMessage.Substring(pipeIndex + 1, detailLength) : null); + } + } + } + + [Conditional("DEBUG")] + [Conditional("_FORCE_ASSERTS")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + public static void Assert(string message) + { + Assert(message, null); + } + + [Conditional("DEBUG")] + [Conditional("_FORCE_ASSERTS")] + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + public static void Assert(string message, string detailMessage) + { + try + { + Print("Assert: " + message + (!string.IsNullOrEmpty(detailMessage) ? ": " + detailMessage : string.Empty)); + Print("*******"); + Logobject.DumpArray(false); + } + finally + { +#if DEBUG && !STRESS + Debug.Assert(false, message, detailMessage); +#endif + } + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void LeaveException(string func, Exception exception) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " exception " + ((exception!=null) ? exception.Message : String.Empty)); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Leave(string func) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " returns "); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Leave(string func, string result) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " returns " + result); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Leave(string func, int returnval) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " returns " + returnval.ToString()); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Leave(string func, bool returnval) + { +#if TRAVE + GlobalLog.Logobject.LeaveFunc(func + " returns " + returnval.ToString()); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void DumpArray() + { +#if TRAVE + GlobalLog.Logobject.DumpArray(true); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Dump(byte[] buffer) + { +#if TRAVE + Logobject.Dump(buffer, 0, buffer!=null ? buffer.Length : -1); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Dump(byte[] buffer, int length) + { +#if TRAVE + Logobject.Dump(buffer, 0, length); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Dump(byte[] buffer, int offset, int length) + { +#if TRAVE + Logobject.Dump(buffer, offset, length); +#endif + } + + [System.Diagnostics.Conditional("TRAVE")] + public static void Dump(IntPtr buffer, int offset, int length) + { +#if TRAVE + Logobject.Dump(buffer, offset, length); +#endif + } + +#if DEBUG + private class HttpWebRequestComparer : IComparer + { + public int Compare( + object x1, + object y1) + { + HttpWebRequest x = (HttpWebRequest)x1; + HttpWebRequest y = (HttpWebRequest)y1; + + if (x.GetHashCode() == y.GetHashCode()) + { + return 0; + } + else if (x.GetHashCode() < y.GetHashCode()) + { + return -1; + } + else if (x.GetHashCode() > y.GetHashCode()) + { + return 1; + } + + return 0; + } + } + /* + private class ConnectionMonitorEntry { + public HttpWebRequest m_Request; + public int m_Flags; + public DateTime m_TimeAdded; + public Connection m_Connection; + + public ConnectionMonitorEntry(HttpWebRequest request, Connection connection, int flags) { + m_Request = request; + m_Connection = connection; + m_Flags = flags; + m_TimeAdded = DateTime.Now; + } + } + */ + private static volatile ManualResetEvent s_ShutdownEvent; + private static volatile SortedList s_RequestList; + + internal const int WaitingForReadDoneFlag = 0x1; +#endif + /* +#if DEBUG + private static void ConnectionMonitor() { + while(! s_ShutdownEvent.WaitOne(DefaultTickValue, false)) { + if (GlobalLog.EnableMonitorThread) { +#if TRAVE + GlobalLog.Logobject.LoggingMonitorTick(); +#endif + } + + int hungCount = 0; + lock (s_RequestList) { + DateTime dateNow = DateTime.Now; + DateTime dateExpired = dateNow.AddSeconds(-DefaultTickValue); + foreach (ConnectionMonitorEntry monitorEntry in s_RequestList.GetValueList() ) { + if (monitorEntry != null && + (dateExpired > monitorEntry.m_TimeAdded)) + { + hungCount++; +#if TRAVE + GlobalLog.Print("delay:" + (dateNow - monitorEntry.m_TimeAdded).TotalSeconds + + " req#" + monitorEntry.m_Request.GetHashCode() + + " cnt#" + monitorEntry.m_Connection.GetHashCode() + + " flags:" + monitorEntry.m_Flags); + +#endif + monitorEntry.m_Connection.Debug(monitorEntry.m_Request.GetHashCode()); + } + } + } + Assert(hungCount == 0, "Warning: Hang Detected on Connection(s) of greater than {0} ms. {1} request(s) hung.|Please Dump System.Net.GlobalLog.s_RequestList for pending requests, make sure your streams are calling Close(), and that your destination server is up.", DefaultTickValue, hungCount); + } + } +#endif // DEBUG + **/ +#if DEBUG + [ReliabilityContract(Consistency.MayCorruptAppDomain, Cer.None)] + internal static void AppDomainUnloadEvent(object sender, EventArgs e) + { + s_ShutdownEvent.Set(); + } +#endif + +#if DEBUG + [System.Diagnostics.Conditional("DEBUG")] + private static void InitConnectionMonitor() + { + s_RequestList = new SortedList(new HttpWebRequestComparer(), 10); + s_ShutdownEvent = new ManualResetEvent(false); + AppDomain.CurrentDomain.DomainUnload += new EventHandler(AppDomainUnloadEvent); + AppDomain.CurrentDomain.ProcessExit += new EventHandler(AppDomainUnloadEvent); + // Thread threadMonitor = new Thread(new ThreadStart(ConnectionMonitor)); + // threadMonitor.IsBackground = true; + // threadMonitor.Start(); + } +#endif + /* + [System.Diagnostics.Conditional("DEBUG")] + internal static void DebugAddRequest(HttpWebRequest request, Connection connection, int flags) { +#if DEBUG + // null if the connection monitor is off + if(s_RequestList == null) + return; + + lock(s_RequestList) { + Assert(!s_RequestList.ContainsKey(request), "s_RequestList.ContainsKey(request)|A HttpWebRequest should not be submitted twice."); + + ConnectionMonitorEntry requestEntry = + new ConnectionMonitorEntry(request, connection, flags); + + try { + s_RequestList.Add(request, requestEntry); + } catch { + } + } +#endif + } +*/ + /* + [System.Diagnostics.Conditional("DEBUG")] + internal static void DebugRemoveRequest(HttpWebRequest request) { + #if DEBUG + // null if the connection monitor is off + if(s_RequestList == null) + return; + + lock(s_RequestList) { + Assert(s_RequestList.ContainsKey(request), "!s_RequestList.ContainsKey(request)|A HttpWebRequest should not be removed twice."); + + try { + s_RequestList.Remove(request); + } catch { + } + } + #endif + } + */ + /* + [System.Diagnostics.Conditional("DEBUG")] + internal static void DebugUpdateRequest(HttpWebRequest request, Connection connection, int flags) { +#if DEBUG + // null if the connection monitor is off + if(s_RequestList == null) + return; + + lock(s_RequestList) { + if(!s_RequestList.ContainsKey(request)) { + return; + } + + ConnectionMonitorEntry requestEntry = + new ConnectionMonitorEntry(request, connection, flags); + + try { + s_RequestList.Remove(request); + s_RequestList.Add(request, requestEntry); + } catch { + } + } +#endif + }*/ + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/HttpListenerContext.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/HttpListenerContext.cs new file mode 100644 index 0000000000..9b9335906f --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/HttpListenerContext.cs @@ -0,0 +1,67 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Security.Principal; + +namespace Microsoft.AspNet.Security.Windows +{ + // TODO: At what point does a user need to be cleaned up? + internal sealed class HttpListenerContext + { + private WindowsAuthMiddleware _winAuth; + private IPrincipal _user = null; + + internal const string NTLM = "NTLM"; + + internal HttpListenerContext(WindowsAuthMiddleware httpListener) + { + _winAuth = httpListener; + } + + internal void Close() + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "Close()", string.Empty); + } + + IDisposable user = _user == null ? null : _user.Identity as IDisposable; + + // TODO: At what point does a user need to be cleaned up? + + // For unsafe connection ntlm auth we dont dispose this identity as yet since its cached + if ((user != null) && + (_user.Identity.AuthenticationType != NTLM) && + (!_winAuth.UnsafeConnectionNtlmAuthentication)) + { + user.Dispose(); + } + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "Close", string.Empty); + } + } + + internal void Abort() + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "Abort", string.Empty); + } + + IDisposable user = _user == null ? null : _user.Identity as IDisposable; + if (user != null) + { + user.Dispose(); + } + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "Abort", string.Empty); + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/Internal.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/Internal.cs new file mode 100644 index 0000000000..4d922591dd --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/Internal.cs @@ -0,0 +1,286 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Net.Security; +using System.Runtime.InteropServices; +using System.Runtime.Versioning; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Cryptography.X509Certificates; +using System.Security.Permissions; + +namespace Microsoft.AspNet.Security.Windows +{ + internal enum SecurityStatus + { + // Success / Informational + OK = 0x00000000, + ContinueNeeded = unchecked((int)0x00090312), + CompleteNeeded = unchecked((int)0x00090313), + CompAndContinue = unchecked((int)0x00090314), + ContextExpired = unchecked((int)0x00090317), + CredentialsNeeded = unchecked((int)0x00090320), + Renegotiate = unchecked((int)0x00090321), + + // Errors + OutOfMemory = unchecked((int)0x80090300), + InvalidHandle = unchecked((int)0x80090301), + Unsupported = unchecked((int)0x80090302), + TargetUnknown = unchecked((int)0x80090303), + InternalError = unchecked((int)0x80090304), + PackageNotFound = unchecked((int)0x80090305), + NotOwner = unchecked((int)0x80090306), + CannotInstall = unchecked((int)0x80090307), + InvalidToken = unchecked((int)0x80090308), + CannotPack = unchecked((int)0x80090309), + QopNotSupported = unchecked((int)0x8009030A), + NoImpersonation = unchecked((int)0x8009030B), + LogonDenied = unchecked((int)0x8009030C), + UnknownCredentials = unchecked((int)0x8009030D), + NoCredentials = unchecked((int)0x8009030E), + MessageAltered = unchecked((int)0x8009030F), + OutOfSequence = unchecked((int)0x80090310), + NoAuthenticatingAuthority = unchecked((int)0x80090311), + IncompleteMessage = unchecked((int)0x80090318), + IncompleteCredentials = unchecked((int)0x80090320), + BufferNotEnough = unchecked((int)0x80090321), + WrongPrincipal = unchecked((int)0x80090322), + TimeSkew = unchecked((int)0x80090324), + UntrustedRoot = unchecked((int)0x80090325), + IllegalMessage = unchecked((int)0x80090326), + CertUnknown = unchecked((int)0x80090327), + CertExpired = unchecked((int)0x80090328), + AlgorithmMismatch = unchecked((int)0x80090331), + SecurityQosFailed = unchecked((int)0x80090332), + SmartcardLogonRequired = unchecked((int)0x8009033E), + UnsupportedPreauth = unchecked((int)0x80090343), + BadBinding = unchecked((int)0x80090346) + } + + internal enum ContextAttribute + { + // look into and + Sizes = 0x00, + Names = 0x01, + Lifespan = 0x02, + DceInfo = 0x03, + StreamSizes = 0x04, + // KeyInfo = 0x05, must not be used, see ConnectionInfo instead + Authority = 0x06, + // SECPKG_ATTR_PROTO_INFO = 7, + // SECPKG_ATTR_PASSWORD_EXPIRY = 8, + // SECPKG_ATTR_SESSION_KEY = 9, + PackageInfo = 0x0A, + // SECPKG_ATTR_USER_FLAGS = 11, + NegotiationInfo = 0x0C, + // SECPKG_ATTR_NATIVE_NAMES = 13, + // SECPKG_ATTR_FLAGS = 14, + // SECPKG_ATTR_USE_VALIDATED = 15, + // SECPKG_ATTR_CREDENTIAL_NAME = 16, + // SECPKG_ATTR_TARGET_INFORMATION = 17, + // SECPKG_ATTR_ACCESS_TOKEN = 18, + // SECPKG_ATTR_TARGET = 19, + // SECPKG_ATTR_AUTHENTICATION_ID = 20, + UniqueBindings = 0x19, + EndpointBindings = 0x1A, + ClientSpecifiedSpn = 0x1B, // SECPKG_ATTR_CLIENT_SPECIFIED_TARGET = 27 + RemoteCertificate = 0x53, + LocalCertificate = 0x54, + RootStore = 0x55, + IssuerListInfoEx = 0x59, + ConnectionInfo = 0x5A, + // SECPKG_ATTR_EAP_KEY_BLOCK 0x5b // returns SecPkgContext_EapKeyBlock + // SECPKG_ATTR_MAPPED_CRED_ATTR 0x5c // returns SecPkgContext_MappedCredAttr + // SECPKG_ATTR_SESSION_INFO 0x5d // returns SecPkgContext_SessionInfo + // SECPKG_ATTR_APP_DATA 0x5e // sets/returns SecPkgContext_SessionAppData + // SECPKG_ATTR_REMOTE_CERTIFICATES 0x5F // returns SecPkgContext_Certificates + // SECPKG_ATTR_CLIENT_CERT_POLICY 0x60 // sets SecPkgCred_ClientCertCtlPolicy + // SECPKG_ATTR_CC_POLICY_RESULT 0x61 // returns SecPkgContext_ClientCertPolicyResult + // SECPKG_ATTR_USE_NCRYPT 0x62 // Sets the CRED_FLAG_USE_NCRYPT_PROVIDER FLAG on cred group + // SECPKG_ATTR_LOCAL_CERT_INFO 0x63 // returns SecPkgContext_CertInfo + // SECPKG_ATTR_CIPHER_INFO 0x64 // returns new CNG SecPkgContext_CipherInfo + // SECPKG_ATTR_EAP_PRF_INFO 0x65 // sets SecPkgContext_EapPrfInfo + // SECPKG_ATTR_SUPPORTED_SIGNATURES 0x66 // returns SecPkgContext_SupportedSignatures + // SECPKG_ATTR_REMOTE_CERT_CHAIN 0x67 // returns PCCERT_CONTEXT + UiInfo = 0x68, // sets SEcPkgContext_UiInfo + } + + internal enum Endianness + { + Network = 0x00, + Native = 0x10, + } + + internal enum CredentialUse + { + Inbound = 0x1, + Outbound = 0x2, + Both = 0x3, + } + + internal enum BufferType + { + Empty = 0x00, + Data = 0x01, + Token = 0x02, + Parameters = 0x03, + Missing = 0x04, + Extra = 0x05, + Trailer = 0x06, + Header = 0x07, + Padding = 0x09, // non-data padding + Stream = 0x0A, + ChannelBindings = 0x0E, + TargetHost = 0x10, + ReadOnlyFlag = unchecked((int)0x80000000), + ReadOnlyWithChecksum = 0x10000000 + } + + // SecPkgContext_IssuerListInfoEx + [StructLayout(LayoutKind.Sequential)] + internal struct IssuerListInfoEx + { + public SafeHandle aIssuers; + public uint cIssuers; + + public unsafe IssuerListInfoEx(SafeHandle handle, byte[] nativeBuffer) + { + aIssuers = handle; + fixed (byte* voidPtr = nativeBuffer) + { + // if this breaks on 64 bit, do the sizeof(IntPtr) trick + cIssuers = *((uint*)(voidPtr + IntPtr.Size)); + } + } + } + + [StructLayout(LayoutKind.Sequential)] + internal struct SecureCredential + { + /* + typedef struct _SCHANNEL_CRED + { + DWORD dwVersion; // always SCHANNEL_CRED_VERSION + DWORD cCreds; + PCCERT_CONTEXT *paCred; + HCERTSTORE hRootStore; + + DWORD cMappers; + struct _HMAPPER **aphMappers; + + DWORD cSupportedAlgs; + ALG_ID * palgSupportedAlgs; + + DWORD grbitEnabledProtocols; + DWORD dwMinimumCipherStrength; + DWORD dwMaximumCipherStrength; + DWORD dwSessionLifespan; + DWORD dwFlags; + DWORD reserved; + } SCHANNEL_CRED, *PSCHANNEL_CRED; + */ + + public const int CurrentVersion = 0x4; + + public int version; + public int cCreds; + + // ptr to an array of pointers + // There is a hack done with this field. AcquireCredentialsHandle requires an array of + // certificate handles; we only ever use one. In order to avoid pinning a one element array, + // we copy this value onto the stack, create a pointer on the stack to the copied value, + // and replace this field with the pointer, during the call to AcquireCredentialsHandle. + // Then we fix it up afterwards. Fine as long as all the SSPI credentials are not + // supposed to be threadsafe. + public IntPtr certContextArray; + + private readonly IntPtr rootStore; // == always null, OTHERWISE NOT RELIABLE + public int cMappers; + private readonly IntPtr phMappers; // == always null, OTHERWISE NOT RELIABLE + public int cSupportedAlgs; + private readonly IntPtr palgSupportedAlgs; // == always null, OTHERWISE NOT RELIABLE + public SchProtocols grbitEnabledProtocols; + public int dwMinimumCipherStrength; + public int dwMaximumCipherStrength; + public int dwSessionLifespan; + public SecureCredential.Flags dwFlags; + public int reserved; + + public SecureCredential(int version, X509Certificate certificate, SecureCredential.Flags flags, SchProtocols protocols, EncryptionPolicy policy) + { + // default values required for a struct + rootStore = phMappers = palgSupportedAlgs = certContextArray = IntPtr.Zero; + cCreds = cMappers = cSupportedAlgs = 0; + + if (policy == EncryptionPolicy.RequireEncryption) + { + // Prohibit null encryption cipher + dwMinimumCipherStrength = 0; + dwMaximumCipherStrength = 0; + } + else if (policy == EncryptionPolicy.AllowNoEncryption) + { + // Allow null encryption cipher in addition to other ciphers + dwMinimumCipherStrength = -1; + dwMaximumCipherStrength = 0; + } + else if (policy == EncryptionPolicy.NoEncryption) + { + // Suppress all encryption and require null encryption cipher only + dwMinimumCipherStrength = -1; + dwMaximumCipherStrength = -1; + } + else + { + throw new ArgumentException(SR.GetString(SR.net_invalid_enum, "EncryptionPolicy"), "policy"); + } + + dwSessionLifespan = reserved = 0; + this.version = version; + dwFlags = flags; + grbitEnabledProtocols = protocols; + if (certificate != null) + { + certContextArray = certificate.Handle; + cCreds = 1; + } + } + + [Flags] + public enum Flags + { + Zero = 0, + NoSystemMapper = 0x02, + NoNameCheck = 0x04, + ValidateManual = 0x08, + NoDefaultCred = 0x10, + ValidateAuto = 0x20 + } + } // SecureCredential + + [SuppressMessage("Microsoft.Design", "CA1049:TypesThatOwnNativeResourcesShouldBeDisposable", + Justification = "This structure does not own the native resource.")] + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct SecurityBufferStruct + { + public int count; + public BufferType type; + public IntPtr token; + + public static readonly int Size = sizeof(SecurityBufferStruct); + } + + internal static class IntPtrHelper + { + internal static IntPtr Add(IntPtr a, int b) + { + return (IntPtr)((long)a + (long)b); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/Logging.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/Logging.cs new file mode 100644 index 0000000000..81e66a3c30 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/Logging.cs @@ -0,0 +1,657 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Runtime.InteropServices; +using System.Security; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class Logging + { + private const string AttributeNameMaxSize = "maxdatasize"; + private const string AttributeNameTraceMode = "tracemode"; + private const string AttributeValueProtocolOnly = "protocolonly"; + // private const string AttributeValueIncludeHex = "includehex"; + + private const int DefaultMaxDumpSize = 1024; + private const bool DefaultUseProtocolTextOnly = false; + + private const string TraceSourceWebName = "System.Net"; + private const string TraceSourceHttpListenerName = "System.Net.HttpListener"; + + private static readonly string[] SupportedAttributes = new string[] { AttributeNameMaxSize, AttributeNameTraceMode }; + + private static volatile bool s_LoggingEnabled = true; + private static volatile bool s_LoggingInitialized; + private static volatile bool s_AppDomainShutdown; + + private static TraceSource s_WebTraceSource; + private static TraceSource s_HttpListenerTraceSource; + + private static object s_InternalSyncObject; + + private Logging() + { + } + + private static object InternalSyncObject + { + get + { + if (s_InternalSyncObject == null) + { + object o = new Object(); + Interlocked.CompareExchange(ref s_InternalSyncObject, o, null); + } + return s_InternalSyncObject; + } + } + + internal static bool On + { + get + { + if (!s_LoggingInitialized) + { + InitializeLogging(); + } + return s_LoggingEnabled; + } + } + + internal static bool IsVerbose(TraceSource traceSource) + { + return ValidateSettings(traceSource, TraceEventType.Verbose); + } + + internal static TraceSource Web + { + get + { + if (!s_LoggingInitialized) + { + InitializeLogging(); + } + if (!s_LoggingEnabled) + { + return null; + } + return s_WebTraceSource; + } + } + + internal static TraceSource HttpListener + { + get + { + if (!s_LoggingInitialized) + { + InitializeLogging(); + } + if (!s_LoggingEnabled) + { + return null; + } + return s_HttpListenerTraceSource; + } + } + + private static bool GetUseProtocolTextSetting(TraceSource traceSource) + { + bool useProtocolTextOnly = DefaultUseProtocolTextOnly; + if (traceSource.Attributes[AttributeNameTraceMode] == AttributeValueProtocolOnly) + { + useProtocolTextOnly = true; + } + return useProtocolTextOnly; + } + + private static int GetMaxDumpSizeSetting(TraceSource traceSource) + { + int maxDumpSize = DefaultMaxDumpSize; + if (traceSource.Attributes.ContainsKey(AttributeNameMaxSize)) + { + try + { + maxDumpSize = Int32.Parse(traceSource.Attributes[AttributeNameMaxSize], NumberFormatInfo.InvariantInfo); + } + catch (Exception exception) + { + if (exception is ThreadAbortException || exception is StackOverflowException || exception is OutOfMemoryException) + { + throw; + } + traceSource.Attributes[AttributeNameMaxSize] = maxDumpSize.ToString(NumberFormatInfo.InvariantInfo); + } + } + return maxDumpSize; + } + + /// + /// Sets up internal config settings for logging. (MUST be called under critsec) + /// + private static void InitializeLogging() + { + lock (InternalSyncObject) + { + if (!s_LoggingInitialized) + { + bool loggingEnabled = false; + s_WebTraceSource = new NclTraceSource(TraceSourceWebName); + s_HttpListenerTraceSource = new NclTraceSource(TraceSourceHttpListenerName); + + GlobalLog.Print("Initalizating tracing"); + + try + { + loggingEnabled = (s_WebTraceSource.Switch.ShouldTrace(TraceEventType.Critical) || + s_HttpListenerTraceSource.Switch.ShouldTrace(TraceEventType.Critical)); + } + catch (SecurityException) + { + // These may throw if the caller does not have permission to hook up trace listeners. + // We treat this case as though logging were disabled. + Close(); + loggingEnabled = false; + } + if (loggingEnabled) + { + AppDomain currentDomain = AppDomain.CurrentDomain; + currentDomain.UnhandledException += new UnhandledExceptionEventHandler(UnhandledExceptionHandler); + currentDomain.DomainUnload += new EventHandler(AppDomainUnloadEvent); + currentDomain.ProcessExit += new EventHandler(ProcessExitEvent); + } + s_LoggingEnabled = loggingEnabled; + s_LoggingInitialized = true; + } + } + } + + [SuppressMessage("Microsoft.Security", "CA2122:DoNotIndirectlyExposeMethodsWithLinkDemands", Justification = "Logging functions must work in partial trust mode")] + private static void Close() + { + if (s_WebTraceSource != null) + { + s_WebTraceSource.Close(); + } + if (s_HttpListenerTraceSource != null) + { + s_HttpListenerTraceSource.Close(); + } + } + + /// + /// Logs any unhandled exception through this event handler + /// + private static void UnhandledExceptionHandler(object sender, UnhandledExceptionEventArgs args) + { + Exception e = (Exception)args.ExceptionObject; + Exception(Web, sender, "UnhandledExceptionHandler", e); + } + + private static void ProcessExitEvent(object sender, EventArgs e) + { + Close(); + s_AppDomainShutdown = true; + } + + /// + /// Called when the system is shutting down, used to prevent additional logging post-shutdown + /// + private static void AppDomainUnloadEvent(object sender, EventArgs e) + { + Close(); + s_AppDomainShutdown = true; + } + + /// + /// Confirms logging is enabled, given current logging settings + /// + private static bool ValidateSettings(TraceSource traceSource, TraceEventType traceLevel) + { + if (!s_LoggingEnabled) + { + return false; + } + if (!s_LoggingInitialized) + { + InitializeLogging(); + } + if (traceSource == null || !traceSource.Switch.ShouldTrace(traceLevel)) + { + return false; + } + if (s_AppDomainShutdown) + { + return false; + } + return true; + } + + /// + /// Converts an object to a normalized string that can be printed + /// takes System.Net.ObjectNamedFoo and coverts to ObjectNamedFoo, + /// except IPAddress, IPEndPoint, and Uri, which return ToString() + /// + /// + private static string GetObjectName(object obj) + { + if (obj is Uri || obj is System.Net.IPAddress || obj is System.Net.IPEndPoint) + { + return obj.ToString(); + } + else + { + return obj.GetType().Name; + } + } + + internal static uint GetThreadId() + { + uint threadId = UnsafeNclNativeMethods.GetCurrentThreadId(); + if (threadId == 0) + { + threadId = (uint)Thread.CurrentThread.GetHashCode(); + } + return threadId; + } + + internal static void PrintLine(TraceSource traceSource, TraceEventType eventType, int id, string msg) + { + string logHeader = "[" + GetThreadId().ToString("d4", CultureInfo.InvariantCulture) + "] "; + traceSource.TraceEvent(eventType, id, logHeader + msg); + } + + /// + /// Indicates that two objects are getting used with one another + /// + internal static void Associate(TraceSource traceSource, object objA, object objB) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + + string lineA = GetObjectName(objA) + "#" + ValidationHelper.HashString(objA); + string lineB = GetObjectName(objB) + "#" + ValidationHelper.HashString(objB); + + PrintLine(traceSource, TraceEventType.Information, 0, "Associating " + lineA + " with " + lineB); + } + + /// + /// Logs entrance to a function + /// + internal static void Enter(TraceSource traceSource, object obj, string method, string param) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Enter(traceSource, GetObjectName(obj) + "#" + ValidationHelper.HashString(obj), method, param); + } + + /// + /// Logs entrance to a function + /// + internal static void Enter(TraceSource traceSource, object obj, string method, object paramObject) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Enter(traceSource, GetObjectName(obj) + "#" + ValidationHelper.HashString(obj), method, paramObject); + } + + /// + /// Logs entrance to a function + /// + internal static void Enter(TraceSource traceSource, string obj, string method, string param) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Enter(traceSource, obj + "::" + method + "(" + param + ")"); + } + + /// + /// Logs entrance to a function + /// + internal static void Enter(TraceSource traceSource, string obj, string method, object paramObject) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + string paramObjectValue = string.Empty; + if (paramObject != null) + { + paramObjectValue = GetObjectName(paramObject) + "#" + ValidationHelper.HashString(paramObject); + } + Enter(traceSource, obj + "::" + method + "(" + paramObjectValue + ")"); + } + + /// + /// Logs entrance to a function, indents and points that out + /// + internal static void Enter(TraceSource traceSource, string method, string parameters) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Enter(traceSource, method + "(" + parameters + ")"); + } + + /// + /// Logs entrance to a function, indents and points that out + /// + internal static void Enter(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + // Trace.CorrelationManager.StartLogicalOperation(); + PrintLine(traceSource, TraceEventType.Verbose, 0, msg); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, object obj, string method, object retObject) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + string retValue = string.Empty; + if (retObject != null) + { + retValue = GetObjectName(retObject) + "#" + ValidationHelper.HashString(retObject); + } + Exit(traceSource, obj, method, retValue); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, string obj, string method, object retObject) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + string retValue = string.Empty; + if (retObject != null) + { + retValue = GetObjectName(retObject) + "#" + ValidationHelper.HashString(retObject); + } + Exit(traceSource, obj, method, retValue); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, object obj, string method, string retValue) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Exit(traceSource, GetObjectName(obj) + "#" + ValidationHelper.HashString(obj), method, retValue); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, string obj, string method, string retValue) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + if (!string.IsNullOrEmpty(retValue)) + { + retValue = "\t-> " + retValue; + } + Exit(traceSource, obj + "::" + method + "() " + retValue); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, string method, string parameters) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + Exit(traceSource, method + "() " + parameters); + } + + /// + /// Logs exit from a function + /// + internal static void Exit(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + PrintLine(traceSource, TraceEventType.Verbose, 0, "Exiting " + msg); + // Trace.CorrelationManager.StopLogicalOperation(); + } + + /// + /// Logs Exception, restores indenting + /// + internal static void Exception(TraceSource traceSource, object obj, string method, Exception e) + { + if (!ValidateSettings(traceSource, TraceEventType.Error)) + { + return; + } + + string infoLine = SR.GetString(SR.net_log_exception, GetObjectLogHash(obj), method, e.Message); + if (!string.IsNullOrEmpty(e.StackTrace)) + { + infoLine += "\r\n" + e.StackTrace; + } + PrintLine(traceSource, TraceEventType.Error, 0, infoLine); + } + + /// + /// Logs an Info line + /// + internal static void PrintInfo(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + PrintLine(traceSource, TraceEventType.Information, 0, msg); + } + + /// + /// Logs an Info line + /// + internal static void PrintInfo(TraceSource traceSource, object obj, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + PrintLine(traceSource, TraceEventType.Information, 0, + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + + " - " + msg); + } + + /// + /// Logs an Info line + /// + internal static void PrintInfo(TraceSource traceSource, object obj, string method, string param) + { + if (!ValidateSettings(traceSource, TraceEventType.Information)) + { + return; + } + PrintLine(traceSource, TraceEventType.Information, 0, + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + + "::" + method + "(" + param + ")"); + } + + /// + /// Logs a Warning line + /// + internal static void PrintWarning(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Warning)) + { + return; + } + PrintLine(traceSource, TraceEventType.Warning, 0, msg); + } + + /// + /// Logs a Warning line + /// + internal static void PrintWarning(TraceSource traceSource, object obj, string method, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Warning)) + { + return; + } + PrintLine(traceSource, TraceEventType.Warning, 0, + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + + "::" + method + "() - " + msg); + } + + /// + /// Logs an Error line + /// + internal static void PrintError(TraceSource traceSource, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Error)) + { + return; + } + PrintLine(traceSource, TraceEventType.Error, 0, msg); + } + + /// + /// Logs an Error line + /// + internal static void PrintError(TraceSource traceSource, object obj, string method, string msg) + { + if (!ValidateSettings(traceSource, TraceEventType.Error)) + { + return; + } + PrintLine(traceSource, TraceEventType.Error, 0, + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + + "::" + method + "() - " + msg); + } + + internal static string GetObjectLogHash(object obj) + { + return GetObjectName(obj) + "#" + ValidationHelper.HashString(obj); + } + + /// + /// Marhsalls a buffer ptr to an array and then dumps the byte array to the log + /// + internal static void Dump(TraceSource traceSource, object obj, string method, IntPtr bufferPtr, int length) + { + if (!ValidateSettings(traceSource, TraceEventType.Verbose) || bufferPtr == IntPtr.Zero || length < 0) + { + return; + } + byte[] buffer = new byte[length]; + Marshal.Copy(bufferPtr, buffer, 0, length); + Dump(traceSource, obj, method, buffer, 0, length); + } + + /// + /// Dumps a byte array to the log + /// + internal static void Dump(TraceSource traceSource, object obj, string method, byte[] buffer, int offset, int length) + { + if (!ValidateSettings(traceSource, TraceEventType.Verbose)) + { + return; + } + if (buffer == null) + { + PrintLine(traceSource, TraceEventType.Verbose, 0, "(null)"); + return; + } + if (offset > buffer.Length) + { + PrintLine(traceSource, TraceEventType.Verbose, 0, "(offset out of range)"); + return; + } + PrintLine(traceSource, TraceEventType.Verbose, 0, "Data from " + GetObjectName(obj) + "#" + ValidationHelper.HashString(obj) + "::" + method); + int maxDumpSize = GetMaxDumpSizeSetting(traceSource); + if (length > maxDumpSize) + { + PrintLine(traceSource, TraceEventType.Verbose, 0, "(printing " + maxDumpSize.ToString(NumberFormatInfo.InvariantInfo) + " out of " + length.ToString(NumberFormatInfo.InvariantInfo) + ")"); + length = maxDumpSize; + } + if ((length < 0) || (length > buffer.Length - offset)) + { + length = buffer.Length - offset; + } + if (GetUseProtocolTextSetting(traceSource)) + { + string output = "<<" + HeaderEncoding.GetString(buffer, offset, length) + ">>"; + PrintLine(traceSource, TraceEventType.Verbose, 0, output); + return; + } + do + { + int n = Math.Min(length, 16); + string disp = String.Format(CultureInfo.CurrentCulture, "{0:X8} : ", offset); + for (int i = 0; i < n; ++i) + { + disp += String.Format(CultureInfo.CurrentCulture, "{0:X2}", buffer[offset + i]) + ((i == 7) ? '-' : ' '); + } + for (int i = n; i < 16; ++i) + { + disp += " "; + } + disp += ": "; + for (int i = 0; i < n; ++i) + { + disp += ((buffer[offset + i] < 0x20) || (buffer[offset + i] > 0x7e)) + ? '.' + : (char)(buffer[offset + i]); + } + PrintLine(traceSource, TraceEventType.Verbose, 0, disp); + offset += n; + length -= n; + } + while (length > 0); + } + + private class NclTraceSource : TraceSource + { + internal NclTraceSource(string name) : base(name) + { + } + /* + protected internal override string[] GetSupportedAttributes() + { + return Logging.SupportedAttributes; + }*/ + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/LoggingObject.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/LoggingObject.cs new file mode 100644 index 0000000000..dbbd6d5415 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/LoggingObject.cs @@ -0,0 +1,580 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +// We have function based stack and thread based logging of basic behavior. We +// also now have the ability to run a "watch thread" which does basic hang detection +// and error-event based logging. The logging code buffers the callstack/picture +// of all COMNET threads, and upon error from an assert or a hang, it will open a file +// and dump the snapsnot. Future work will allow this to be configed by registry and +// to use Runtime based logging. We'd also like to have different levels of logging. + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + + // BaseLoggingObject - used to disable logging, + // this is just a base class that does nothing. + + [Flags] + internal enum ThreadKinds + { + Unknown = 0x0000, + + // Mutually exclusive. + User = 0x0001, // Thread has entered via an API. + System = 0x0002, // Thread has entered via a system callback (e.g. completion port) or is our own thread. + + // Mutually exclusive. + Sync = 0x0004, // Thread should block. + Async = 0x0008, // Thread should not block. + + // Mutually exclusive, not always known for a user thread. Never changes. + Timer = 0x0010, // Thread is the timer thread. (Can't call user code.) + CompletionPort = 0x0020, // Thread is a ThreadPool completion-port thread. + Worker = 0x0040, // Thread is a ThreadPool worker thread. + Finalization = 0x0080, // Thread is the finalization thread. + Other = 0x0100, // Unknown source. + + OwnerMask = User | System, + SyncMask = Sync | Async, + SourceMask = Timer | CompletionPort | Worker | Finalization | Other, + + // Useful "macros" + SafeSources = SourceMask & ~(Timer | Finalization), // Methods that "unsafe" sources can call must be explicitly marked. + ThreadPool = CompletionPort | Worker, // Like Thread.CurrentThread.IsThreadPoolThread + } + + internal class BaseLoggingObject + { + internal BaseLoggingObject() + { + } + + internal virtual void EnterFunc(string funcname) + { + } + + internal virtual void LeaveFunc(string funcname) + { + } + + internal virtual void DumpArrayToConsole() + { + } + + internal virtual void PrintLine(string msg) + { + } + + internal virtual void DumpArray(bool shouldClose) + { + } + + internal virtual void DumpArrayToFile(bool shouldClose) + { + } + + internal virtual void Flush() + { + } + + internal virtual void Flush(bool close) + { + } + + internal virtual void LoggingMonitorTick() + { + } + + internal virtual void Dump(byte[] buffer) + { + } + + internal virtual void Dump(byte[] buffer, int length) + { + } + + internal virtual void Dump(byte[] buffer, int offset, int length) + { + } + + internal virtual void Dump(IntPtr pBuffer, int offset, int length) + { + } + } // class BaseLoggingObject + +#if TRAVE + /// + /// + /// + internal class LoggingObject : BaseLoggingObject { + public ArrayList _Logarray; + private Hashtable _ThreadNesting; + private int _AddCount; + private StreamWriter _Stream; + private int _IamAlive; + private int _LastIamAlive; + private bool _Finalized = false; + private double _NanosecondsPerTick; + private int _StartMilliseconds; + private long _StartTicks; + + internal LoggingObject() : base() { + _Logarray = new ArrayList(); + _ThreadNesting = new Hashtable(); + _AddCount = 0; + _IamAlive = 0; + _LastIamAlive = -1; + + if (GlobalLog.s_UsePerfCounter) { + long ticksPerSecond; + SafeNativeMethods.QueryPerformanceFrequency(out ticksPerSecond); + _NanosecondsPerTick = 10000000.0/(double)ticksPerSecond; + SafeNativeMethods.QueryPerformanceCounter(out _StartTicks); + } else { + _StartMilliseconds = Environment.TickCount; + } + } + + // + // LoggingMonitorTick - this function is run from the monitor thread, + // and used to check to see if there any hangs, ie no logging + // activitity + // + + internal override void LoggingMonitorTick() { + if ( _LastIamAlive == _IamAlive ) { + PrintLine("================= Error TIMEOUT - HANG DETECTED ================="); + DumpArray(true); + } + _LastIamAlive = _IamAlive; + } + + internal override void EnterFunc(string funcname) { + if (_Finalized) { + return; + } + IncNestingCount(); + ValidatePush(funcname); + PrintLine(funcname); + } + + internal override void LeaveFunc(string funcname) { + if (_Finalized) { + return; + } + PrintLine(funcname); + DecNestingCount(); + ValidatePop(funcname); + } + + internal override void DumpArrayToConsole() { + for (int i=0; i < _Logarray.Count; i++) { + Console.WriteLine((string) _Logarray[i]); + } + } + + internal override void PrintLine(string msg) { + if (_Finalized) { + return; + } + string spc = ""; + + _IamAlive++; + + spc = GetNestingString(); + + string tickString = ""; + + if (GlobalLog.s_UsePerfCounter) { + long nowTicks; + SafeNativeMethods.QueryPerformanceCounter(out nowTicks); + if (_StartTicks>nowTicks) { // counter reset, restart from 0 + _StartTicks = nowTicks; + } + nowTicks -= _StartTicks; + if (GlobalLog.s_UseTimeSpan) { + tickString = new TimeSpan((long)(nowTicks*_NanosecondsPerTick)).ToString(); + // note: TimeSpan().ToString() doesn't return the uSec part + // if its 0. .ToString() returns [H*]HH:MM:SS:uuuuuuu, hence 16 + if (tickString.Length < 16) { + tickString += ".0000000"; + } + } + else { + tickString = ((double)nowTicks*_NanosecondsPerTick/10000).ToString("f3"); + } + } + else { + int nowMilliseconds = Environment.TickCount; + if (_StartMilliseconds>nowMilliseconds) { + _StartMilliseconds = nowMilliseconds; + } + nowMilliseconds -= _StartMilliseconds; + if (GlobalLog.s_UseTimeSpan) { + tickString = new TimeSpan(nowMilliseconds*10000).ToString(); + // note: TimeSpan().ToString() doesn't return the uSec part + // if its 0. .ToString() returns [H*]HH:MM:SS:uuuuuuu, hence 16 + if (tickString.Length < 16) { + tickString += ".0000000"; + } + } + else { + tickString = nowMilliseconds.ToString(); + } + } + + uint threadId = 0; + + if (GlobalLog.s_UseThreadId) { + try { + object threadData = Thread.GetData(GlobalLog.s_ThreadIdSlot); + if (threadData!= null) { + threadId = (uint)threadData; + } + + } + catch(Exception exception) { + if (exception is ThreadAbortException || exception is StackOverflowException || exception is OutOfMemoryException) { + throw; + } + } + if (threadId == 0) { + threadId = UnsafeNclNativeMethods.GetCurrentThreadId(); + Thread.SetData(GlobalLog.s_ThreadIdSlot, threadId); + } + } + if (threadId == 0) { + threadId = (uint)Thread.CurrentThread.GetHashCode(); + } + + string str = "[" + threadId.ToString("x8") + "]" + " (" +tickString+ ") " + spc + msg; + + lock(this) { + _AddCount++; + _Logarray.Add(str); + int MaxLines = GlobalLog.s_DumpToConsole ? 0 : GlobalLog.MaxLinesBeforeSave; + if (_AddCount > MaxLines) { + _AddCount = 0; + DumpArray(false); + _Logarray = new ArrayList(); + } + } + } + + internal override void DumpArray(bool shouldClose) { + if ( GlobalLog.s_DumpToConsole ) { + DumpArrayToConsole(); + } else { + DumpArrayToFile(shouldClose); + } + } + + internal unsafe override void Dump(byte[] buffer, int offset, int length) { + //if (!GlobalLog.s_DumpWebData) { + // return; + //} + if (buffer==null) { + PrintLine("(null)"); + return; + } + if (offset > buffer.Length) { + PrintLine("(offset out of range)"); + return; + } + if (length > GlobalLog.s_MaxDumpSize) { + PrintLine("(printing " + GlobalLog.s_MaxDumpSize.ToString() + " out of " + length.ToString() + ")"); + length = GlobalLog.s_MaxDumpSize; + } + if ((length < 0) || (length > buffer.Length - offset)) { + length = buffer.Length - offset; + } + fixed (byte* pBuffer = buffer) { + Dump((IntPtr)pBuffer, offset, length); + } + } + + internal unsafe override void Dump(IntPtr pBuffer, int offset, int length) { + //if (!GlobalLog.s_DumpWebData) { + // return; + //} + if (pBuffer==IntPtr.Zero || length<0) { + PrintLine("(null)"); + return; + } + if (length > GlobalLog.s_MaxDumpSize) { + PrintLine("(printing " + GlobalLog.s_MaxDumpSize.ToString() + " out of " + length.ToString() + ")"); + length = GlobalLog.s_MaxDumpSize; + } + byte* buffer = (byte*)pBuffer + offset; + Dump(buffer, length); + } + + unsafe void Dump(byte* buffer, int length) { + do { + int offset = 0; + int n = Math.Min(length, 16); + string disp = ((IntPtr)buffer).ToString("X8") + " : " + offset.ToString("X8") + " : "; + byte current; + for (int i = 0; i < n; ++i) { + current = *(buffer + i); + disp += current.ToString("X2") + ((i == 7) ? '-' : ' '); + } + for (int i = n; i < 16; ++i) { + disp += " "; + } + disp += ": "; + for (int i = 0; i < n; ++i) { + current = *(buffer + i); + disp += ((current < 0x20) || (current > 0x7e)) ? '.' : (char)current; + } + PrintLine(disp); + offset += n; + buffer += n; + length -= n; + } while (length > 0); + } + + // SECURITY: This is dev-debugging class and we need some permissions + // to use it under trust-restricted environment as well. + [PermissionSet(SecurityAction.Assert, Name="FullTrust")] + internal override void DumpArrayToFile(bool shouldClose) { + lock (this) { + if (!shouldClose) { + if (_Stream==null) { + string mainLogFileRoot = GlobalLog.s_RootDirectory + "System.Net"; + string mainLogFile = mainLogFileRoot; + for (int k=0; k<20; k++) { + if (k>0) { + mainLogFile = mainLogFileRoot + "." + k.ToString(); + } + string fileName = mainLogFile + ".log"; + if (!File.Exists(fileName)) { + try { + _Stream = new StreamWriter(fileName); + break; + } + catch (Exception exception) { + if (exception is ThreadAbortException || exception is StackOverflowException || exception is OutOfMemoryException) { + throw; + } + if (exception is SecurityException || exception is UnauthorizedAccessException) { + // can't be CAS (we assert) this is an ACL issue + break; + } + } + } + } + if (_Stream==null) { + _Stream = StreamWriter.Null; + } + // write a header with information about the Process and the AppDomain + _Stream.Write("# MachineName: " + Environment.MachineName + "\r\n"); + _Stream.Write("# ProcessName: " + Process.GetCurrentProcess().ProcessName + " (pid: " + Process.GetCurrentProcess().Id + ")\r\n"); + _Stream.Write("# AppDomainId: " + AppDomain.CurrentDomain.Id + "\r\n"); + _Stream.Write("# CurrentIdentity: " + WindowsIdentity.GetCurrent().Name + "\r\n"); + _Stream.Write("# CommandLine: " + Environment.CommandLine + "\r\n"); + _Stream.Write("# ClrVersion: " + Environment.Version + "\r\n"); + _Stream.Write("# CreationDate: " + DateTime.Now.ToString("g") + "\r\n"); + } + } + try { + if (_Logarray!=null) { + for (int i=0; i<_Logarray.Count; i++) { + _Stream.Write((string)_Logarray[i]); + _Stream.Write("\r\n"); + } + + if (_Logarray.Count > 0 && _Stream != null) + _Stream.Flush(); + } + } + catch (Exception exception) { + if (exception is ThreadAbortException || exception is StackOverflowException || exception is OutOfMemoryException) { + throw; + } + } + if (shouldClose && _Stream!=null) { + try { + _Stream.Close(); + } + catch (ObjectDisposedException) { } + _Stream = null; + } + } + } + + internal override void Flush() { + Flush(false); + } + + internal override void Flush(bool close) { + lock (this) { + if (!GlobalLog.s_DumpToConsole) { + DumpArrayToFile(close); + _AddCount = 0; + } + } + } + + private class ThreadInfoData { + public ThreadInfoData(string indent) { + Indent = indent; + NestingStack = new Stack(); + } + public string Indent; + public Stack NestingStack; + }; + + string IndentString { + get { + string indent = " "; + Object obj = _ThreadNesting[Thread.CurrentThread.GetHashCode()]; + if (!GlobalLog.s_DebugCallNesting) { + if (obj == null) { + _ThreadNesting[Thread.CurrentThread.GetHashCode()] = indent; + } else { + indent = (String) obj; + } + } else { + ThreadInfoData threadInfo = obj as ThreadInfoData; + if (threadInfo == null) { + threadInfo = new ThreadInfoData(indent); + _ThreadNesting[Thread.CurrentThread.GetHashCode()] = threadInfo; + } + indent = threadInfo.Indent; + } + return indent; + } + set { + Object obj = _ThreadNesting[Thread.CurrentThread.GetHashCode()]; + if (obj == null) { + return; + } + if (!GlobalLog.s_DebugCallNesting) { + _ThreadNesting[Thread.CurrentThread.GetHashCode()] = value; + } else { + ThreadInfoData threadInfo = obj as ThreadInfoData; + if (threadInfo == null) { + threadInfo = new ThreadInfoData(value); + _ThreadNesting[Thread.CurrentThread.GetHashCode()] = threadInfo; + } + threadInfo.Indent = value; + } + } + } + + [System.Diagnostics.Conditional("TRAVE")] + private void IncNestingCount() { + IndentString = IndentString + " "; + } + + [System.Diagnostics.Conditional("TRAVE")] + private void DecNestingCount() { + string indent = IndentString; + if (indent.Length>1) { + try { + indent = indent.Substring(1); + } + catch { + indent = string.Empty; + } + } + if (indent.Length==0) { + indent = "< "; + } + IndentString = indent; + } + + private string GetNestingString() { + return IndentString; + } + + [System.Diagnostics.Conditional("TRAVE")] + private void ValidatePush(string name) { + if (GlobalLog.s_DebugCallNesting) { + Object obj = _ThreadNesting[Thread.CurrentThread.GetHashCode()]; + ThreadInfoData threadInfo = obj as ThreadInfoData; + if (threadInfo == null) { + return; + } + threadInfo.NestingStack.Push(name); + } + } + + [System.Diagnostics.Conditional("TRAVE")] + private void ValidatePop(string name) { + if (GlobalLog.s_DebugCallNesting) { + try { + Object obj = _ThreadNesting[Thread.CurrentThread.GetHashCode()]; + ThreadInfoData threadInfo = obj as ThreadInfoData; + if (threadInfo == null) { + return; + } + if (threadInfo.NestingStack.Count == 0) { + PrintLine("++++====" + "Poped Empty Stack for :"+name); + } + string popedName = (string) threadInfo.NestingStack.Pop(); + string [] parsedList = popedName.Split(new char [] {'(',')',' ','.',':',',','#'}); + foreach (string element in parsedList) { + if (element != null && element.Length > 1 && name.IndexOf(element) != -1) { + return; + } + } + PrintLine("++++====" + "Expected:" + popedName + ": got :" + name + ": StackSize:"+threadInfo.NestingStack.Count); + // relevel the stack + while(threadInfo.NestingStack.Count>0) { + string popedName2 = (string) threadInfo.NestingStack.Pop(); + string [] parsedList2 = popedName2.Split(new char [] {'(',')',' ','.',':',',','#'}); + foreach (string element2 in parsedList2) { + if (element2 != null && element2.Length > 1 && name.IndexOf(element2) != -1) { + return; + } + } + } + } + catch { + PrintLine("++++====" + "ValidatePop failed for: "+name); + } + } + } + + + ~LoggingObject() { + if(!_Finalized) { + _Finalized = true; + lock(this) { + DumpArray(true); + } + } + } + + + } // class LoggingObject + + internal static class TraveHelper { + private static readonly string Hexizer = "0x{0:x}"; + internal static string ToHex(object value) { + return String.Format(Hexizer, value); + } + } +#endif // TRAVE + +#if TRAVE + internal class IntegerSwitch : BooleanSwitch { + public IntegerSwitch(string switchName, string switchDescription) : base(switchName, switchDescription) { + } + public new int Value { + get { + return base.SwitchSetting; + } + } + } + +#endif + + // class GlobalLog +} // namespace System.Net diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/SR.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/SR.cs new file mode 100644 index 0000000000..5486b33602 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/SR.cs @@ -0,0 +1,630 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace System +{ + using System; + using System.Reflection; + using System.Globalization; + using System.Resources; + using System.Text; + using System.Threading; + using System.Security.Permissions; + using System.ComponentModel; + + /// + /// AutoGenerated resource class. Usage: + /// string s = SR.GetString(SR.MyIdenfitier); + /// + internal sealed class SR + { + internal const string security_ExtendedProtection_NoOSSupport = "security_ExtendedProtection_NoOSSupport"; + internal const string net_nonClsCompliantException = "net_nonClsCompliantException"; + internal const string net_illegalConfigWith = "net_illegalConfigWith"; + internal const string net_illegalConfigWithout = "net_illegalConfigWithout"; + internal const string net_baddate = "net_baddate"; + internal const string net_writestarted = "net_writestarted"; + internal const string net_clsmall = "net_clsmall"; + internal const string net_reqsubmitted = "net_reqsubmitted"; + internal const string net_rspsubmitted = "net_rspsubmitted"; + internal const string net_ftp_no_http_cmd = "net_ftp_no_http_cmd"; + internal const string net_ftp_invalid_method_name = "net_ftp_invalid_method_name"; + internal const string net_ftp_invalid_renameto = "net_ftp_invalid_renameto"; + internal const string net_ftp_no_defaultcreds = "net_ftp_no_defaultcreds"; + internal const string net_ftpnoresponse = "net_ftpnoresponse"; + internal const string net_ftp_response_invalid_format = "net_ftp_response_invalid_format"; + internal const string net_ftp_no_offsetforhttp = "net_ftp_no_offsetforhttp"; + internal const string net_ftp_invalid_uri = "net_ftp_invalid_uri"; + internal const string net_ftp_invalid_status_response = "net_ftp_invalid_status_response"; + internal const string net_ftp_server_failed_passive = "net_ftp_server_failed_passive"; + internal const string net_ftp_active_address_different = "net_ftp_active_address_different"; + internal const string net_ftp_proxy_does_not_support_ssl = "net_ftp_proxy_does_not_support_ssl"; + internal const string net_ftp_invalid_response_filename = "net_ftp_invalid_response_filename"; + internal const string net_ftp_unsupported_method = "net_ftp_unsupported_method"; + internal const string net_resubmitcanceled = "net_resubmitcanceled"; + internal const string net_redirect_perm = "net_redirect_perm"; + internal const string net_resubmitprotofailed = "net_resubmitprotofailed"; + internal const string net_needchunked = "net_needchunked"; + internal const string net_nochunked = "net_nochunked"; + internal const string net_nochunkuploadonhttp10 = "net_nochunkuploadonhttp10"; + internal const string net_connarg = "net_connarg"; + internal const string net_no100 = "net_no100"; + internal const string net_fromto = "net_fromto"; + internal const string net_rangetoosmall = "net_rangetoosmall"; + internal const string net_entitytoobig = "net_entitytoobig"; + internal const string net_invalidversion = "net_invalidversion"; + internal const string net_invalidstatus = "net_invalidstatus"; + internal const string net_toosmall = "net_toosmall"; + internal const string net_toolong = "net_toolong"; + internal const string net_connclosed = "net_connclosed"; + internal const string net_noseek = "net_noseek"; + internal const string net_servererror = "net_servererror"; + internal const string net_nouploadonget = "net_nouploadonget"; + internal const string net_mutualauthfailed = "net_mutualauthfailed"; + internal const string net_invasync = "net_invasync"; + internal const string net_inasync = "net_inasync"; + internal const string net_mustbeuri = "net_mustbeuri"; + internal const string net_format_shexp = "net_format_shexp"; + internal const string net_cannot_load_proxy_helper = "net_cannot_load_proxy_helper"; + internal const string net_invalid_host = "net_invalid_host"; + internal const string net_repcall = "net_repcall"; + internal const string net_wrongversion = "net_wrongversion"; + internal const string net_badmethod = "net_badmethod"; + internal const string net_io_notenoughbyteswritten = "net_io_notenoughbyteswritten"; + internal const string net_io_timeout_use_ge_zero = "net_io_timeout_use_ge_zero"; + internal const string net_io_timeout_use_gt_zero = "net_io_timeout_use_gt_zero"; + internal const string net_io_no_0timeouts = "net_io_no_0timeouts"; + internal const string net_requestaborted = "net_requestaborted"; + internal const string net_tooManyRedirections = "net_tooManyRedirections"; + internal const string net_authmodulenotregistered = "net_authmodulenotregistered"; + internal const string net_authschemenotregistered = "net_authschemenotregistered"; + internal const string net_proxyschemenotsupported = "net_proxyschemenotsupported"; + internal const string net_maxsrvpoints = "net_maxsrvpoints"; + internal const string net_unknown_prefix = "net_unknown_prefix"; + internal const string net_notconnected = "net_notconnected"; + internal const string net_notstream = "net_notstream"; + internal const string net_timeout = "net_timeout"; + internal const string net_nocontentlengthonget = "net_nocontentlengthonget"; + internal const string net_contentlengthmissing = "net_contentlengthmissing"; + internal const string net_nonhttpproxynotallowed = "net_nonhttpproxynotallowed"; + internal const string net_nottoken = "net_nottoken"; + internal const string net_rangetype = "net_rangetype"; + internal const string net_need_writebuffering = "net_need_writebuffering"; + internal const string net_securitypackagesupport = "net_securitypackagesupport"; + internal const string net_securityprotocolnotsupported = "net_securityprotocolnotsupported"; + internal const string net_nodefaultcreds = "net_nodefaultcreds"; + internal const string net_stopped = "net_stopped"; + internal const string net_udpconnected = "net_udpconnected"; + internal const string net_readonlystream = "net_readonlystream"; + internal const string net_writeonlystream = "net_writeonlystream"; + internal const string net_no_concurrent_io_allowed = "net_no_concurrent_io_allowed"; + internal const string net_needmorethreads = "net_needmorethreads"; + internal const string net_MethodNotImplementedException = "net_MethodNotImplementedException"; + internal const string net_PropertyNotImplementedException = "net_PropertyNotImplementedException"; + internal const string net_MethodNotSupportedException = "net_MethodNotSupportedException"; + internal const string net_PropertyNotSupportedException = "net_PropertyNotSupportedException"; + internal const string net_ProtocolNotSupportedException = "net_ProtocolNotSupportedException"; + internal const string net_SelectModeNotSupportedException = "net_SelectModeNotSupportedException"; + internal const string net_InvalidSocketHandle = "net_InvalidSocketHandle"; + internal const string net_InvalidAddressFamily = "net_InvalidAddressFamily"; + internal const string net_InvalidEndPointAddressFamily = "net_InvalidEndPointAddressFamily"; + internal const string net_InvalidSocketAddressSize = "net_InvalidSocketAddressSize"; + internal const string net_invalidAddressList = "net_invalidAddressList"; + internal const string net_invalidPingBufferSize = "net_invalidPingBufferSize"; + internal const string net_cant_perform_during_shutdown = "net_cant_perform_during_shutdown"; + internal const string net_cant_create_environment = "net_cant_create_environment"; + internal const string net_completed_result = "net_completed_result"; + internal const string net_protocol_invalid_family = "net_protocol_invalid_family"; + internal const string net_protocol_invalid_multicast_family = "net_protocol_invalid_multicast_family"; + internal const string net_empty_osinstalltype = "net_empty_osinstalltype"; + internal const string net_unknown_osinstalltype = "net_unknown_osinstalltype"; + internal const string net_cant_determine_osinstalltype = "net_cant_determine_osinstalltype"; + internal const string net_osinstalltype = "net_osinstalltype"; + internal const string net_entire_body_not_written = "net_entire_body_not_written"; + internal const string net_must_provide_request_body = "net_must_provide_request_body"; + internal const string net_ssp_dont_support_cbt = "net_ssp_dont_support_cbt"; + internal const string net_sockets_zerolist = "net_sockets_zerolist"; + internal const string net_sockets_blocking = "net_sockets_blocking"; + internal const string net_sockets_useblocking = "net_sockets_useblocking"; + internal const string net_sockets_select = "net_sockets_select"; + internal const string net_sockets_toolarge_select = "net_sockets_toolarge_select"; + internal const string net_sockets_empty_select = "net_sockets_empty_select"; + internal const string net_sockets_mustbind = "net_sockets_mustbind"; + internal const string net_sockets_mustlisten = "net_sockets_mustlisten"; + internal const string net_sockets_mustnotlisten = "net_sockets_mustnotlisten"; + internal const string net_sockets_mustnotbebound = "net_sockets_mustnotbebound"; + internal const string net_sockets_namedmustnotbebound = "net_sockets_namedmustnotbebound"; + internal const string net_sockets_invalid_socketinformation = "net_sockets_invalid_socketinformation"; + internal const string net_sockets_invalid_ipaddress_length = "net_sockets_invalid_ipaddress_length"; + internal const string net_sockets_invalid_optionValue = "net_sockets_invalid_optionValue"; + internal const string net_sockets_invalid_optionValue_all = "net_sockets_invalid_optionValue_all"; + internal const string net_sockets_invalid_dnsendpoint = "net_sockets_invalid_dnsendpoint"; + internal const string net_sockets_disconnectedConnect = "net_sockets_disconnectedConnect"; + internal const string net_sockets_disconnectedAccept = "net_sockets_disconnectedAccept"; + internal const string net_tcplistener_mustbestopped = "net_tcplistener_mustbestopped"; + internal const string net_sockets_no_duplicate_async = "net_sockets_no_duplicate_async"; + internal const string net_socketopinprogress = "net_socketopinprogress"; + internal const string net_buffercounttoosmall = "net_buffercounttoosmall"; + internal const string net_multibuffernotsupported = "net_multibuffernotsupported"; + internal const string net_ambiguousbuffers = "net_ambiguousbuffers"; + internal const string net_sockets_ipv6only = "net_sockets_ipv6only"; + internal const string net_perfcounter_initialized_success = "net_perfcounter_initialized_success"; + internal const string net_perfcounter_initialized_error = "net_perfcounter_initialized_error"; + internal const string net_perfcounter_nocategory = "net_perfcounter_nocategory"; + internal const string net_perfcounter_initialization_started = "net_perfcounter_initialization_started"; + internal const string net_perfcounter_cant_queue_workitem = "net_perfcounter_cant_queue_workitem"; + internal const string net_config_proxy = "net_config_proxy"; + internal const string net_config_proxy_module_not_public = "net_config_proxy_module_not_public"; + internal const string net_config_authenticationmodules = "net_config_authenticationmodules"; + internal const string net_config_webrequestmodules = "net_config_webrequestmodules"; + internal const string net_config_requestcaching = "net_config_requestcaching"; + internal const string net_config_section_permission = "net_config_section_permission"; + internal const string net_config_element_permission = "net_config_element_permission"; + internal const string net_config_property_permission = "net_config_property_permission"; + internal const string net_WebResponseParseError_InvalidHeaderName = "net_WebResponseParseError_InvalidHeaderName"; + internal const string net_WebResponseParseError_InvalidContentLength = "net_WebResponseParseError_InvalidContentLength"; + internal const string net_WebResponseParseError_IncompleteHeaderLine = "net_WebResponseParseError_IncompleteHeaderLine"; + internal const string net_WebResponseParseError_CrLfError = "net_WebResponseParseError_CrLfError"; + internal const string net_WebResponseParseError_InvalidChunkFormat = "net_WebResponseParseError_InvalidChunkFormat"; + internal const string net_WebResponseParseError_UnexpectedServerResponse = "net_WebResponseParseError_UnexpectedServerResponse"; + internal const string net_webstatus_Success = "net_webstatus_Success"; + internal const string net_webstatus_NameResolutionFailure = "net_webstatus_NameResolutionFailure"; + internal const string net_webstatus_ConnectFailure = "net_webstatus_ConnectFailure"; + internal const string net_webstatus_ReceiveFailure = "net_webstatus_ReceiveFailure"; + internal const string net_webstatus_SendFailure = "net_webstatus_SendFailure"; + internal const string net_webstatus_PipelineFailure = "net_webstatus_PipelineFailure"; + internal const string net_webstatus_RequestCanceled = "net_webstatus_RequestCanceled"; + internal const string net_webstatus_ConnectionClosed = "net_webstatus_ConnectionClosed"; + internal const string net_webstatus_TrustFailure = "net_webstatus_TrustFailure"; + internal const string net_webstatus_SecureChannelFailure = "net_webstatus_SecureChannelFailure"; + internal const string net_webstatus_ServerProtocolViolation = "net_webstatus_ServerProtocolViolation"; + internal const string net_webstatus_KeepAliveFailure = "net_webstatus_KeepAliveFailure"; + internal const string net_webstatus_ProxyNameResolutionFailure = "net_webstatus_ProxyNameResolutionFailure"; + internal const string net_webstatus_MessageLengthLimitExceeded = "net_webstatus_MessageLengthLimitExceeded"; + internal const string net_webstatus_CacheEntryNotFound = "net_webstatus_CacheEntryNotFound"; + internal const string net_webstatus_RequestProhibitedByCachePolicy = "net_webstatus_RequestProhibitedByCachePolicy"; + internal const string net_webstatus_Timeout = "net_webstatus_Timeout"; + internal const string net_webstatus_RequestProhibitedByProxy = "net_webstatus_RequestProhibitedByProxy"; + internal const string net_InvalidStatusCode = "net_InvalidStatusCode"; + internal const string net_ftpstatuscode_ServiceNotAvailable = "net_ftpstatuscode_ServiceNotAvailable"; + internal const string net_ftpstatuscode_CantOpenData = "net_ftpstatuscode_CantOpenData"; + internal const string net_ftpstatuscode_ConnectionClosed = "net_ftpstatuscode_ConnectionClosed"; + internal const string net_ftpstatuscode_ActionNotTakenFileUnavailableOrBusy = "net_ftpstatuscode_ActionNotTakenFileUnavailableOrBusy"; + internal const string net_ftpstatuscode_ActionAbortedLocalProcessingError = "net_ftpstatuscode_ActionAbortedLocalProcessingError"; + internal const string net_ftpstatuscode_ActionNotTakenInsufficentSpace = "net_ftpstatuscode_ActionNotTakenInsufficentSpace"; + internal const string net_ftpstatuscode_CommandSyntaxError = "net_ftpstatuscode_CommandSyntaxError"; + internal const string net_ftpstatuscode_ArgumentSyntaxError = "net_ftpstatuscode_ArgumentSyntaxError"; + internal const string net_ftpstatuscode_CommandNotImplemented = "net_ftpstatuscode_CommandNotImplemented"; + internal const string net_ftpstatuscode_BadCommandSequence = "net_ftpstatuscode_BadCommandSequence"; + internal const string net_ftpstatuscode_NotLoggedIn = "net_ftpstatuscode_NotLoggedIn"; + internal const string net_ftpstatuscode_AccountNeeded = "net_ftpstatuscode_AccountNeeded"; + internal const string net_ftpstatuscode_ActionNotTakenFileUnavailable = "net_ftpstatuscode_ActionNotTakenFileUnavailable"; + internal const string net_ftpstatuscode_ActionAbortedUnknownPageType = "net_ftpstatuscode_ActionAbortedUnknownPageType"; + internal const string net_ftpstatuscode_FileActionAborted = "net_ftpstatuscode_FileActionAborted"; + internal const string net_ftpstatuscode_ActionNotTakenFilenameNotAllowed = "net_ftpstatuscode_ActionNotTakenFilenameNotAllowed"; + internal const string net_httpstatuscode_NoContent = "net_httpstatuscode_NoContent"; + internal const string net_httpstatuscode_NonAuthoritativeInformation = "net_httpstatuscode_NonAuthoritativeInformation"; + internal const string net_httpstatuscode_ResetContent = "net_httpstatuscode_ResetContent"; + internal const string net_httpstatuscode_PartialContent = "net_httpstatuscode_PartialContent"; + internal const string net_httpstatuscode_MultipleChoices = "net_httpstatuscode_MultipleChoices"; + internal const string net_httpstatuscode_Ambiguous = "net_httpstatuscode_Ambiguous"; + internal const string net_httpstatuscode_MovedPermanently = "net_httpstatuscode_MovedPermanently"; + internal const string net_httpstatuscode_Moved = "net_httpstatuscode_Moved"; + internal const string net_httpstatuscode_Found = "net_httpstatuscode_Found"; + internal const string net_httpstatuscode_Redirect = "net_httpstatuscode_Redirect"; + internal const string net_httpstatuscode_SeeOther = "net_httpstatuscode_SeeOther"; + internal const string net_httpstatuscode_RedirectMethod = "net_httpstatuscode_RedirectMethod"; + internal const string net_httpstatuscode_NotModified = "net_httpstatuscode_NotModified"; + internal const string net_httpstatuscode_UseProxy = "net_httpstatuscode_UseProxy"; + internal const string net_httpstatuscode_TemporaryRedirect = "net_httpstatuscode_TemporaryRedirect"; + internal const string net_httpstatuscode_RedirectKeepVerb = "net_httpstatuscode_RedirectKeepVerb"; + internal const string net_httpstatuscode_BadRequest = "net_httpstatuscode_BadRequest"; + internal const string net_httpstatuscode_Unauthorized = "net_httpstatuscode_Unauthorized"; + internal const string net_httpstatuscode_PaymentRequired = "net_httpstatuscode_PaymentRequired"; + internal const string net_httpstatuscode_Forbidden = "net_httpstatuscode_Forbidden"; + internal const string net_httpstatuscode_NotFound = "net_httpstatuscode_NotFound"; + internal const string net_httpstatuscode_MethodNotAllowed = "net_httpstatuscode_MethodNotAllowed"; + internal const string net_httpstatuscode_NotAcceptable = "net_httpstatuscode_NotAcceptable"; + internal const string net_httpstatuscode_ProxyAuthenticationRequired = "net_httpstatuscode_ProxyAuthenticationRequired"; + internal const string net_httpstatuscode_RequestTimeout = "net_httpstatuscode_RequestTimeout"; + internal const string net_httpstatuscode_Conflict = "net_httpstatuscode_Conflict"; + internal const string net_httpstatuscode_Gone = "net_httpstatuscode_Gone"; + internal const string net_httpstatuscode_LengthRequired = "net_httpstatuscode_LengthRequired"; + internal const string net_httpstatuscode_InternalServerError = "net_httpstatuscode_InternalServerError"; + internal const string net_httpstatuscode_NotImplemented = "net_httpstatuscode_NotImplemented"; + internal const string net_httpstatuscode_BadGateway = "net_httpstatuscode_BadGateway"; + internal const string net_httpstatuscode_ServiceUnavailable = "net_httpstatuscode_ServiceUnavailable"; + internal const string net_httpstatuscode_GatewayTimeout = "net_httpstatuscode_GatewayTimeout"; + internal const string net_httpstatuscode_HttpVersionNotSupported = "net_httpstatuscode_HttpVersionNotSupported"; + internal const string net_uri_BadScheme = "net_uri_BadScheme"; + internal const string net_uri_BadFormat = "net_uri_BadFormat"; + internal const string net_uri_BadUserPassword = "net_uri_BadUserPassword"; + internal const string net_uri_BadHostName = "net_uri_BadHostName"; + internal const string net_uri_BadAuthority = "net_uri_BadAuthority"; + internal const string net_uri_BadAuthorityTerminator = "net_uri_BadAuthorityTerminator"; + internal const string net_uri_EmptyUri = "net_uri_EmptyUri"; + internal const string net_uri_BadString = "net_uri_BadString"; + internal const string net_uri_MustRootedPath = "net_uri_MustRootedPath"; + internal const string net_uri_BadPort = "net_uri_BadPort"; + internal const string net_uri_SizeLimit = "net_uri_SizeLimit"; + internal const string net_uri_SchemeLimit = "net_uri_SchemeLimit"; + internal const string net_uri_NotAbsolute = "net_uri_NotAbsolute"; + internal const string net_uri_PortOutOfRange = "net_uri_PortOutOfRange"; + internal const string net_uri_UserDrivenParsing = "net_uri_UserDrivenParsing"; + internal const string net_uri_AlreadyRegistered = "net_uri_AlreadyRegistered"; + internal const string net_uri_NeedFreshParser = "net_uri_NeedFreshParser"; + internal const string net_uri_CannotCreateRelative = "net_uri_CannotCreateRelative"; + internal const string net_uri_InvalidUriKind = "net_uri_InvalidUriKind"; + internal const string net_uri_BadUnicodeHostForIdn = "net_uri_BadUnicodeHostForIdn"; + internal const string net_uri_GenericAuthorityNotDnsSafe = "net_uri_GenericAuthorityNotDnsSafe"; + internal const string net_uri_NotJustSerialization = "net_uri_NotJustSerialization"; + internal const string net_emptystringset = "net_emptystringset"; + internal const string net_emptystringcall = "net_emptystringcall"; + internal const string net_headers_req = "net_headers_req"; + internal const string net_headers_rsp = "net_headers_rsp"; + internal const string net_headers_toolong = "net_headers_toolong"; + internal const string net_WebHeaderInvalidControlChars = "net_WebHeaderInvalidControlChars"; + internal const string net_WebHeaderInvalidCRLFChars = "net_WebHeaderInvalidCRLFChars"; + internal const string net_WebHeaderInvalidHeaderChars = "net_WebHeaderInvalidHeaderChars"; + internal const string net_WebHeaderInvalidNonAsciiChars = "net_WebHeaderInvalidNonAsciiChars"; + internal const string net_WebHeaderMissingColon = "net_WebHeaderMissingColon"; + internal const string net_headerrestrict = "net_headerrestrict"; + internal const string net_io_completionportwasbound = "net_io_completionportwasbound"; + internal const string net_io_writefailure = "net_io_writefailure"; + internal const string net_io_readfailure = "net_io_readfailure"; + internal const string net_io_connectionclosed = "net_io_connectionclosed"; + internal const string net_io_transportfailure = "net_io_transportfailure"; + internal const string net_io_internal_bind = "net_io_internal_bind"; + internal const string net_io_invalidasyncresult = "net_io_invalidasyncresult"; + internal const string net_io_invalidnestedcall = "net_io_invalidnestedcall"; + internal const string net_io_invalidendcall = "net_io_invalidendcall"; + internal const string net_io_must_be_rw_stream = "net_io_must_be_rw_stream"; + internal const string net_io_header_id = "net_io_header_id"; + internal const string net_io_out_range = "net_io_out_range"; + internal const string net_io_encrypt = "net_io_encrypt"; + internal const string net_io_decrypt = "net_io_decrypt"; + internal const string net_io_read = "net_io_read"; + internal const string net_io_write = "net_io_write"; + internal const string net_io_eof = "net_io_eof"; + internal const string net_io_async_result = "net_io_async_result"; + internal const string net_listener_mustcall = "net_listener_mustcall"; + internal const string net_listener_mustcompletecall = "net_listener_mustcompletecall"; + internal const string net_listener_callinprogress = "net_listener_callinprogress"; + internal const string net_listener_scheme = "net_listener_scheme"; + internal const string net_listener_host = "net_listener_host"; + internal const string net_listener_slash = "net_listener_slash"; + internal const string net_listener_repcall = "net_listener_repcall"; + internal const string net_listener_invalid_cbt_type = "net_listener_invalid_cbt_type"; + internal const string net_listener_no_spns = "net_listener_no_spns"; + internal const string net_listener_cannot_set_custom_cbt = "net_listener_cannot_set_custom_cbt"; + internal const string net_listener_cbt_not_supported = "net_listener_cbt_not_supported"; + internal const string net_listener_detach_error = "net_listener_detach_error"; + internal const string net_listener_close_urlgroup_error = "net_listener_close_urlgroup_error"; + internal const string net_tls_version = "net_tls_version"; + internal const string net_perm_target = "net_perm_target"; + internal const string net_perm_both_regex = "net_perm_both_regex"; + internal const string net_perm_none = "net_perm_none"; + internal const string net_perm_attrib_count = "net_perm_attrib_count"; + internal const string net_perm_invalid_val = "net_perm_invalid_val"; + internal const string net_perm_attrib_multi = "net_perm_attrib_multi"; + internal const string net_perm_epname = "net_perm_epname"; + internal const string net_perm_invalid_val_in_element = "net_perm_invalid_val_in_element"; + internal const string net_invalid_ip_addr = "net_invalid_ip_addr"; + internal const string dns_bad_ip_address = "dns_bad_ip_address"; + internal const string net_bad_mac_address = "net_bad_mac_address"; + internal const string net_ping = "net_ping"; + internal const string net_bad_ip_address_prefix = "net_bad_ip_address_prefix"; + internal const string net_max_ip_address_list_length_exceeded = "net_max_ip_address_list_length_exceeded"; + internal const string net_ipv4_not_installed = "net_ipv4_not_installed"; + internal const string net_ipv6_not_installed = "net_ipv6_not_installed"; + internal const string net_webclient = "net_webclient"; + internal const string net_webclient_ContentType = "net_webclient_ContentType"; + internal const string net_webclient_Multipart = "net_webclient_Multipart"; + internal const string net_webclient_no_concurrent_io_allowed = "net_webclient_no_concurrent_io_allowed"; + internal const string net_webclient_invalid_baseaddress = "net_webclient_invalid_baseaddress"; + internal const string net_container_add_cookie = "net_container_add_cookie"; + internal const string net_cookie_invalid = "net_cookie_invalid"; + internal const string net_cookie_size = "net_cookie_size"; + internal const string net_cookie_parse_header = "net_cookie_parse_header"; + internal const string net_cookie_attribute = "net_cookie_attribute"; + internal const string net_cookie_format = "net_cookie_format"; + internal const string net_cookie_exists = "net_cookie_exists"; + internal const string net_cookie_capacity_range = "net_cookie_capacity_range"; + internal const string net_set_token = "net_set_token"; + internal const string net_revert_token = "net_revert_token"; + internal const string net_ssl_io_async_context = "net_ssl_io_async_context"; + internal const string net_ssl_io_encrypt = "net_ssl_io_encrypt"; + internal const string net_ssl_io_decrypt = "net_ssl_io_decrypt"; + internal const string net_ssl_io_context_expired = "net_ssl_io_context_expired"; + internal const string net_ssl_io_handshake_start = "net_ssl_io_handshake_start"; + internal const string net_ssl_io_handshake = "net_ssl_io_handshake"; + internal const string net_ssl_io_frame = "net_ssl_io_frame"; + internal const string net_ssl_io_corrupted = "net_ssl_io_corrupted"; + internal const string net_ssl_io_cert_validation = "net_ssl_io_cert_validation"; + internal const string net_ssl_io_invalid_end_call = "net_ssl_io_invalid_end_call"; + internal const string net_ssl_io_invalid_begin_call = "net_ssl_io_invalid_begin_call"; + internal const string net_ssl_io_no_server_cert = "net_ssl_io_no_server_cert"; + internal const string net_auth_bad_client_creds = "net_auth_bad_client_creds"; + internal const string net_auth_bad_client_creds_or_target_mismatch = "net_auth_bad_client_creds_or_target_mismatch"; + internal const string net_auth_context_expectation = "net_auth_context_expectation"; + internal const string net_auth_context_expectation_remote = "net_auth_context_expectation_remote"; + internal const string net_auth_supported_impl_levels = "net_auth_supported_impl_levels"; + internal const string net_auth_no_anonymous_support = "net_auth_no_anonymous_support"; + internal const string net_auth_reauth = "net_auth_reauth"; + internal const string net_auth_noauth = "net_auth_noauth"; + internal const string net_auth_client_server = "net_auth_client_server"; + internal const string net_auth_noencryption = "net_auth_noencryption"; + internal const string net_auth_SSPI = "net_auth_SSPI"; + internal const string net_auth_failure = "net_auth_failure"; + internal const string net_auth_eof = "net_auth_eof"; + internal const string net_auth_alert = "net_auth_alert"; + internal const string net_auth_ignored_reauth = "net_auth_ignored_reauth"; + internal const string net_auth_empty_read = "net_auth_empty_read"; + internal const string net_auth_message_not_encrypted = "net_auth_message_not_encrypted"; + internal const string net_auth_must_specify_extended_protection_scheme = "net_auth_must_specify_extended_protection_scheme"; + internal const string net_frame_size = "net_frame_size"; + internal const string net_frame_read_io = "net_frame_read_io"; + internal const string net_frame_read_size = "net_frame_read_size"; + internal const string net_frame_max_size = "net_frame_max_size"; + internal const string net_jscript_load = "net_jscript_load"; + internal const string net_proxy_not_gmt = "net_proxy_not_gmt"; + internal const string net_proxy_invalid_dayofweek = "net_proxy_invalid_dayofweek"; + internal const string net_proxy_invalid_url_format = "net_proxy_invalid_url_format"; + internal const string net_param_not_string = "net_param_not_string"; + internal const string net_value_cannot_be_negative = "net_value_cannot_be_negative"; + internal const string net_invalid_offset = "net_invalid_offset"; + internal const string net_offset_plus_count = "net_offset_plus_count"; + internal const string net_cannot_be_false = "net_cannot_be_false"; + internal const string net_invalid_enum = "net_invalid_enum"; + internal const string net_listener_already = "net_listener_already"; + internal const string net_cache_shadowstream_not_writable = "net_cache_shadowstream_not_writable"; + internal const string net_cache_validator_fail = "net_cache_validator_fail"; + internal const string net_cache_access_denied = "net_cache_access_denied"; + internal const string net_cache_validator_result = "net_cache_validator_result"; + internal const string net_cache_retrieve_failure = "net_cache_retrieve_failure"; + internal const string net_cache_not_supported_body = "net_cache_not_supported_body"; + internal const string net_cache_not_supported_command = "net_cache_not_supported_command"; + internal const string net_cache_not_accept_response = "net_cache_not_accept_response"; + internal const string net_cache_method_failed = "net_cache_method_failed"; + internal const string net_cache_key_failed = "net_cache_key_failed"; + internal const string net_cache_no_stream = "net_cache_no_stream"; + internal const string net_cache_unsupported_partial_stream = "net_cache_unsupported_partial_stream"; + internal const string net_cache_not_configured = "net_cache_not_configured"; + internal const string net_cache_non_seekable_stream_not_supported = "net_cache_non_seekable_stream_not_supported"; + internal const string net_invalid_cast = "net_invalid_cast"; + internal const string net_collection_readonly = "net_collection_readonly"; + internal const string net_not_ipermission = "net_not_ipermission"; + internal const string net_no_classname = "net_no_classname"; + internal const string net_no_typename = "net_no_typename"; + internal const string net_array_too_small = "net_array_too_small"; + internal const string net_servicePointAddressNotSupportedInHostMode = "net_servicePointAddressNotSupportedInHostMode"; + internal const string net_Websockets_AlreadyOneOutstandingOperation = "net_Websockets_AlreadyOneOutstandingOperation"; + internal const string net_Websockets_WebSocketBaseFaulted = "net_Websockets_WebSocketBaseFaulted"; + internal const string net_WebSockets_NativeSendResponseHeaders = "net_WebSockets_NativeSendResponseHeaders"; + internal const string net_WebSockets_Generic = "net_WebSockets_Generic"; + internal const string net_WebSockets_NotAWebSocket_Generic = "net_WebSockets_NotAWebSocket_Generic"; + internal const string net_WebSockets_UnsupportedWebSocketVersion_Generic = "net_WebSockets_UnsupportedWebSocketVersion_Generic"; + internal const string net_WebSockets_HeaderError_Generic = "net_WebSockets_HeaderError_Generic"; + internal const string net_WebSockets_UnsupportedProtocol_Generic = "net_WebSockets_UnsupportedProtocol_Generic"; + internal const string net_WebSockets_UnsupportedPlatform = "net_WebSockets_UnsupportedPlatform"; + internal const string net_WebSockets_AcceptNotAWebSocket = "net_WebSockets_AcceptNotAWebSocket"; + internal const string net_WebSockets_AcceptUnsupportedWebSocketVersion = "net_WebSockets_AcceptUnsupportedWebSocketVersion"; + internal const string net_WebSockets_AcceptHeaderNotFound = "net_WebSockets_AcceptHeaderNotFound"; + internal const string net_WebSockets_AcceptUnsupportedProtocol = "net_WebSockets_AcceptUnsupportedProtocol"; + internal const string net_WebSockets_ClientAcceptingNoProtocols = "net_WebSockets_ClientAcceptingNoProtocols"; + internal const string net_WebSockets_ClientSecWebSocketProtocolsBlank = "net_WebSockets_ClientSecWebSocketProtocolsBlank"; + internal const string net_WebSockets_ArgumentOutOfRange_TooSmall = "net_WebSockets_ArgumentOutOfRange_TooSmall"; + internal const string net_WebSockets_ArgumentOutOfRange_InternalBuffer = "net_WebSockets_ArgumentOutOfRange_InternalBuffer"; + internal const string net_WebSockets_ArgumentOutOfRange_TooBig = "net_WebSockets_ArgumentOutOfRange_TooBig"; + internal const string net_WebSockets_InvalidState_Generic = "net_WebSockets_InvalidState_Generic"; + internal const string net_WebSockets_InvalidState_ClosedOrAborted = "net_WebSockets_InvalidState_ClosedOrAborted"; + internal const string net_WebSockets_InvalidState = "net_WebSockets_InvalidState"; + internal const string net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync = "net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync"; + internal const string net_WebSockets_InvalidMessageType = "net_WebSockets_InvalidMessageType"; + internal const string net_WebSockets_InvalidBufferType = "net_WebSockets_InvalidBufferType"; + internal const string net_WebSockets_InvalidMessageType_Generic = "net_WebSockets_InvalidMessageType_Generic"; + internal const string net_WebSockets_Argument_InvalidMessageType = "net_WebSockets_Argument_InvalidMessageType"; + internal const string net_WebSockets_ConnectionClosedPrematurely_Generic = "net_WebSockets_ConnectionClosedPrematurely_Generic"; + internal const string net_WebSockets_InvalidCharInProtocolString = "net_WebSockets_InvalidCharInProtocolString"; + internal const string net_WebSockets_InvalidEmptySubProtocol = "net_WebSockets_InvalidEmptySubProtocol"; + internal const string net_WebSockets_ReasonNotNull = "net_WebSockets_ReasonNotNull"; + internal const string net_WebSockets_InvalidCloseStatusCode = "net_WebSockets_InvalidCloseStatusCode"; + internal const string net_WebSockets_InvalidCloseStatusDescription = "net_WebSockets_InvalidCloseStatusDescription"; + internal const string net_WebSockets_Scheme = "net_WebSockets_Scheme"; + internal const string net_WebSockets_AlreadyStarted = "net_WebSockets_AlreadyStarted"; + internal const string net_WebSockets_Connect101Expected = "net_WebSockets_Connect101Expected"; + internal const string net_WebSockets_InvalidResponseHeader = "net_WebSockets_InvalidResponseHeader"; + internal const string net_WebSockets_NotConnected = "net_WebSockets_NotConnected"; + internal const string net_WebSockets_InvalidRegistration = "net_WebSockets_InvalidRegistration"; + internal const string net_WebSockets_NoDuplicateProtocol = "net_WebSockets_NoDuplicateProtocol"; + internal const string net_log_exception = "net_log_exception"; + internal const string net_log_listener_delegate_exception = "net_log_listener_delegate_exception"; + internal const string net_log_listener_unsupported_authentication_scheme = "net_log_listener_unsupported_authentication_scheme"; + internal const string net_log_listener_unmatched_authentication_scheme = "net_log_listener_unmatched_authentication_scheme"; + internal const string net_log_listener_create_valid_identity_failed = "net_log_listener_create_valid_identity_failed"; + internal const string net_log_listener_httpsys_registry_null = "net_log_listener_httpsys_registry_null"; + internal const string net_log_listener_httpsys_registry_error = "net_log_listener_httpsys_registry_error"; + internal const string net_log_listener_cant_convert_raw_path = "net_log_listener_cant_convert_raw_path"; + internal const string net_log_listener_cant_convert_percent_value = "net_log_listener_cant_convert_percent_value"; + internal const string net_log_listener_cant_convert_bytes = "net_log_listener_cant_convert_bytes"; + internal const string net_log_listener_cant_convert_to_utf8 = "net_log_listener_cant_convert_to_utf8"; + internal const string net_log_listener_cant_create_uri = "net_log_listener_cant_create_uri"; + internal const string net_log_listener_no_cbt_disabled = "net_log_listener_no_cbt_disabled"; + internal const string net_log_listener_no_cbt_http = "net_log_listener_no_cbt_http"; + internal const string net_log_listener_no_cbt_platform = "net_log_listener_no_cbt_platform"; + internal const string net_log_listener_no_cbt_trustedproxy = "net_log_listener_no_cbt_trustedproxy"; + internal const string net_log_listener_cbt = "net_log_listener_cbt"; + internal const string net_log_listener_no_spn_kerberos = "net_log_listener_no_spn_kerberos"; + internal const string net_log_listener_no_spn_disabled = "net_log_listener_no_spn_disabled"; + internal const string net_log_listener_no_spn_cbt = "net_log_listener_no_spn_cbt"; + internal const string net_log_listener_no_spn_platform = "net_log_listener_no_spn_platform"; + internal const string net_log_listener_no_spn_whensupported = "net_log_listener_no_spn_whensupported"; + internal const string net_log_listener_no_spn_loopback = "net_log_listener_no_spn_loopback"; + internal const string net_log_listener_spn = "net_log_listener_spn"; + internal const string net_log_listener_spn_passed = "net_log_listener_spn_passed"; + internal const string net_log_listener_spn_failed = "net_log_listener_spn_failed"; + internal const string net_log_listener_spn_failed_always = "net_log_listener_spn_failed_always"; + internal const string net_log_listener_spn_failed_empty = "net_log_listener_spn_failed_empty"; + internal const string net_log_listener_spn_failed_dump = "net_log_listener_spn_failed_dump"; + internal const string net_log_listener_spn_add = "net_log_listener_spn_add"; + internal const string net_log_listener_spn_not_add = "net_log_listener_spn_not_add"; + internal const string net_log_listener_spn_remove = "net_log_listener_spn_remove"; + internal const string net_log_listener_spn_not_remove = "net_log_listener_spn_not_remove"; + internal const string net_log_sspi_enumerating_security_packages = "net_log_sspi_enumerating_security_packages"; + internal const string net_log_sspi_security_package_not_found = "net_log_sspi_security_package_not_found"; + internal const string net_log_sspi_security_context_input_buffer = "net_log_sspi_security_context_input_buffer"; + internal const string net_log_sspi_security_context_input_buffers = "net_log_sspi_security_context_input_buffers"; + internal const string net_log_sspi_selected_cipher_suite = "net_log_sspi_selected_cipher_suite"; + internal const string net_log_remote_certificate = "net_log_remote_certificate"; + internal const string net_log_locating_private_key_for_certificate = "net_log_locating_private_key_for_certificate"; + internal const string net_log_cert_is_of_type_2 = "net_log_cert_is_of_type_2"; + internal const string net_log_found_cert_in_store = "net_log_found_cert_in_store"; + internal const string net_log_did_not_find_cert_in_store = "net_log_did_not_find_cert_in_store"; + internal const string net_log_open_store_failed = "net_log_open_store_failed"; + internal const string net_log_got_certificate_from_delegate = "net_log_got_certificate_from_delegate"; + internal const string net_log_no_delegate_and_have_no_client_cert = "net_log_no_delegate_and_have_no_client_cert"; + internal const string net_log_no_delegate_but_have_client_cert = "net_log_no_delegate_but_have_client_cert"; + internal const string net_log_attempting_restart_using_cert = "net_log_attempting_restart_using_cert"; + internal const string net_log_no_issuers_try_all_certs = "net_log_no_issuers_try_all_certs"; + internal const string net_log_server_issuers_look_for_matching_certs = "net_log_server_issuers_look_for_matching_certs"; + internal const string net_log_selected_cert = "net_log_selected_cert"; + internal const string net_log_n_certs_after_filtering = "net_log_n_certs_after_filtering"; + internal const string net_log_finding_matching_certs = "net_log_finding_matching_certs"; + internal const string net_log_using_cached_credential = "net_log_using_cached_credential"; + internal const string net_log_remote_cert_user_declared_valid = "net_log_remote_cert_user_declared_valid"; + internal const string net_log_remote_cert_user_declared_invalid = "net_log_remote_cert_user_declared_invalid"; + internal const string net_log_remote_cert_has_no_errors = "net_log_remote_cert_has_no_errors"; + internal const string net_log_remote_cert_has_errors = "net_log_remote_cert_has_errors"; + internal const string net_log_remote_cert_not_available = "net_log_remote_cert_not_available"; + internal const string net_log_remote_cert_name_mismatch = "net_log_remote_cert_name_mismatch"; + internal const string net_log_proxy_autodetect_script_location_parse_error = "net_log_proxy_autodetect_script_location_parse_error"; + internal const string net_log_proxy_autodetect_failed = "net_log_proxy_autodetect_failed"; + internal const string net_log_proxy_script_execution_error = "net_log_proxy_script_execution_error"; + internal const string net_log_proxy_script_download_compile_error = "net_log_proxy_script_download_compile_error"; + internal const string net_log_proxy_system_setting_update = "net_log_proxy_system_setting_update"; + internal const string net_log_proxy_update_due_to_ip_config_change = "net_log_proxy_update_due_to_ip_config_change"; + internal const string net_log_proxy_called_with_null_parameter = "net_log_proxy_called_with_null_parameter"; + internal const string net_log_proxy_called_with_invalid_parameter = "net_log_proxy_called_with_invalid_parameter"; + internal const string net_log_proxy_ras_supported = "net_log_proxy_ras_supported"; + internal const string net_log_proxy_ras_notsupported_exception = "net_log_proxy_ras_notsupported_exception"; + internal const string net_log_proxy_winhttp_cant_open_session = "net_log_proxy_winhttp_cant_open_session"; + internal const string net_log_proxy_winhttp_getproxy_failed = "net_log_proxy_winhttp_getproxy_failed"; + internal const string net_log_proxy_winhttp_timeout_error = "net_log_proxy_winhttp_timeout_error"; + internal const string net_log_digest_hash_algorithm_not_supported = "net_log_digest_hash_algorithm_not_supported"; + internal const string net_log_digest_qop_not_supported = "net_log_digest_qop_not_supported"; + internal const string net_log_digest_requires_nonce = "net_log_digest_requires_nonce"; + internal const string net_log_auth_invalid_challenge = "net_log_auth_invalid_challenge"; + internal const string net_log_unknown = "net_log_unknown"; + internal const string net_log_operation_returned_something = "net_log_operation_returned_something"; + internal const string net_log_operation_failed_with_error = "net_log_operation_failed_with_error"; + internal const string net_log_buffered_n_bytes = "net_log_buffered_n_bytes"; + internal const string net_log_method_equal = "net_log_method_equal"; + internal const string net_log_releasing_connection = "net_log_releasing_connection"; + internal const string net_log_unexpected_exception = "net_log_unexpected_exception"; + internal const string net_log_server_response_error_code = "net_log_server_response_error_code"; + internal const string net_log_resubmitting_request = "net_log_resubmitting_request"; + internal const string net_log_retrieving_localhost_exception = "net_log_retrieving_localhost_exception"; + internal const string net_log_resolved_servicepoint_may_not_be_remote_server = "net_log_resolved_servicepoint_may_not_be_remote_server"; + internal const string net_log_closed_idle = "net_log_closed_idle"; + internal const string net_log_received_status_line = "net_log_received_status_line"; + internal const string net_log_sending_headers = "net_log_sending_headers"; + internal const string net_log_received_headers = "net_log_received_headers"; + internal const string net_log_shell_expression_pattern_format_warning = "net_log_shell_expression_pattern_format_warning"; + internal const string net_log_exception_in_callback = "net_log_exception_in_callback"; + internal const string net_log_sending_command = "net_log_sending_command"; + internal const string net_log_received_response = "net_log_received_response"; + internal const string net_log_socket_connected = "net_log_socket_connected"; + internal const string net_log_socket_accepted = "net_log_socket_accepted"; + internal const string net_log_socket_not_logged_file = "net_log_socket_not_logged_file"; + internal const string net_log_socket_connect_dnsendpoint = "net_log_socket_connect_dnsendpoint"; + internal const string SSPIInvalidHandleType = "SSPIInvalidHandleType"; + + private static SR loader = null; + private ResourceManager resources; + + internal SR() + { + resources = new System.Resources.ResourceManager("System", this.GetType().Assembly); + } + + private static SR GetLoader() + { + if (loader == null) + { + SR sr = new SR(); + Interlocked.CompareExchange(ref loader, sr, null); + } + return loader; + } + + private static CultureInfo Culture + { + get { return null/*use ResourceManager default, CultureInfo.CurrentUICulture*/; } + } + + public static ResourceManager Resources + { + get + { + return GetLoader().resources; + } + } + + public static string GetString(string name, params object[] args) + { + SR sys = GetLoader(); + if (sys == null) + { + return null; + } + string res = sys.resources.GetString(name, SR.Culture); + + if (args != null && args.Length > 0) + { + for (int i = 0; i < args.Length; i++) + { + String value = args[i] as String; + if (value != null && value.Length > 1024) + { + args[i] = value.Substring(0, 1024 - 3) + "..."; + } + } + return String.Format(CultureInfo.CurrentCulture, res, args); + } + else + { + return res; + } + } + + public static string GetString(string name) + { + SR sys = GetLoader(); + if (sys == null) + { + return null; + } + return sys.resources.GetString(name, SR.Culture); + } + + public static string GetString(string name, out bool usedFallback) + { + // always false for this version of gensr + usedFallback = false; + return GetString(name); + } + + public static object GetObject(string name) + { + SR sys = GetLoader(); + if (sys == null) + { + return null; + } + return sys.resources.GetObject(name, SR.Culture); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Legacy/ValidationHelper.cs b/src/Microsoft.AspNet.Security.Windows/Legacy/ValidationHelper.cs new file mode 100644 index 0000000000..ff7a09000b --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Legacy/ValidationHelper.cs @@ -0,0 +1,74 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.Diagnostics; + using System.Diagnostics.CodeAnalysis; + using System.Globalization; + using System.Net.Security; + using System.Runtime.InteropServices; + using System.Runtime.Versioning; + using System.Security.Authentication.ExtendedProtection; + using System.Security.Cryptography.X509Certificates; + using System.Security.Permissions; + + internal static class ValidationHelper + { + public static string ExceptionMessage(Exception exception) + { + if (exception == null) + { + return string.Empty; + } + if (exception.InnerException == null) + { + return exception.Message; + } + return exception.Message + " (" + ExceptionMessage(exception.InnerException) + ")"; + } + + public static string ToString(object objectValue) + { + if (objectValue == null) + { + return "(null)"; + } + else if (objectValue is string && ((string)objectValue).Length == 0) + { + return "(string.empty)"; + } + else if (objectValue is Exception) + { + return ExceptionMessage(objectValue as Exception); + } + else if (objectValue is IntPtr) + { + return "0x" + ((IntPtr)objectValue).ToString("x"); + } + else + { + return objectValue.ToString(); + } + } + public static string HashString(object objectValue) + { + if (objectValue == null) + { + return "(null)"; + } + else if (objectValue is string && ((string)objectValue).Length == 0) + { + return "(string.empty)"; + } + else + { + return objectValue.GetHashCode().ToString(NumberFormatInfo.InvariantInfo); + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NTAuthentication.cs b/src/Microsoft.AspNet.Security.Windows/NTAuthentication.cs new file mode 100644 index 0000000000..25cae8a500 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NTAuthentication.cs @@ -0,0 +1,721 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Globalization; +using System.Net; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Permissions; +using System.Security.Principal; +using System.Text; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class NTAuthentication + { + private static readonly ContextCallback Callback = new ContextCallback(InitializeCallback); + private static ISSPIInterface SSPIAuth = new SSPIAuthType(); + + private bool _isServer; + + private SafeFreeCredentials _credentialsHandle; + private SafeDeleteContext _securityContext; + private string _spn; + private string _clientSpecifiedSpn; + + private int _tokenSize; + private ContextFlags _requestedContextFlags; + private ContextFlags _contextFlags; + private string _uniqueUserId; + + private bool _isCompleted; + private string _protocolName; + private SecSizes _sizes; + private string _lastProtocolName; + private string _package; + + private ChannelBinding _channelBinding; + + // This overload does not attmept to impersonate because the caller either did it already or the original thread context is still preserved + + internal NTAuthentication(bool isServer, string package, NetworkCredential credential, string spn, ContextFlags requestedContextFlags, ChannelBinding channelBinding) + { + Initialize(isServer, package, credential, spn, requestedContextFlags, channelBinding); + } + + // This overload always uses the default credentials for the process. + + [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.ControlPrincipal)] + internal NTAuthentication(bool isServer, string package, string spn, ContextFlags requestedContextFlags, ChannelBinding channelBinding) + { + try + { + using (WindowsIdentity.Impersonate(IntPtr.Zero)) + { + Initialize(isServer, package, CredentialCache.DefaultNetworkCredentials, spn, requestedContextFlags, channelBinding); + } + } + catch + { + // Avoid exception filter attacks. + throw; + } + } + + // The semantic of this propoerty is "Don't call me again". + // It can be completed either with success or error + // The latest case is signalled by IsValidContext==false + internal bool IsCompleted + { + get + { + return _isCompleted; + } + } + + internal bool IsValidContext + { + get + { + return !(_securityContext == null || _securityContext.IsInvalid); + } + } + + internal string AssociatedName + { + get + { + if (!(IsValidContext && IsCompleted)) + { + throw new Win32Exception((int)SecurityStatus.InvalidHandle); + } + + string name = SSPIWrapper.QueryContextAttributes(SSPIAuth, _securityContext, ContextAttribute.Names) as string; + GlobalLog.Print("NTAuthentication: The context is associated with [" + name + "]"); + return name; + } + } + + internal bool IsConfidentialityFlag + { + get + { + return (_contextFlags & ContextFlags.Confidentiality) != 0; + } + } + + internal bool IsIntegrityFlag + { + get + { + return (_contextFlags & (_isServer ? ContextFlags.AcceptIntegrity : ContextFlags.InitIntegrity)) != 0; + } + } + + internal bool IsMutualAuthFlag + { + get + { + return (_contextFlags & ContextFlags.MutualAuth) != 0; + } + } + + internal bool IsDelegationFlag + { + get + { + return (_contextFlags & ContextFlags.Delegate) != 0; + } + } + + internal bool IsIdentifyFlag + { + get + { + return (_contextFlags & (_isServer ? ContextFlags.AcceptIdentify : ContextFlags.InitIdentify)) != 0; + } + } + + internal string Spn + { + get + { + return _spn; + } + } + + internal string ClientSpecifiedSpn + { + get + { + if (_clientSpecifiedSpn == null) + { + _clientSpecifiedSpn = GetClientSpecifiedSpn(); + } + return _clientSpecifiedSpn; + } + } + + // True indicates this instance is for Server and will use AcceptSecurityContext SSPI API + + internal bool IsServer + { + get + { + return _isServer; + } + } + + internal bool IsKerberos + { + get + { + if (_lastProtocolName == null) + { + _lastProtocolName = ProtocolName; + } + + return (object)_lastProtocolName == (object)NegotiationInfoClass.Kerberos; + } + } + internal bool IsNTLM + { + get + { + if (_lastProtocolName == null) + { + _lastProtocolName = ProtocolName; + } + + return (object)_lastProtocolName == (object)NegotiationInfoClass.NTLM; + } + } + + internal string Package + { + get + { + return _package; + } + } + + internal string ProtocolName + { + get + { + // NB: May return string.Empty if the auth is not done yet or failed + if (_protocolName == null) + { + NegotiationInfoClass negotiationInfo = null; + + if (IsValidContext) + { + negotiationInfo = SSPIWrapper.QueryContextAttributes(SSPIAuth, _securityContext, ContextAttribute.NegotiationInfo) as NegotiationInfoClass; + if (IsCompleted) + { + if (negotiationInfo != null) + { + // cache it only when it's completed + _protocolName = negotiationInfo.AuthenticationPackage; + } + } + } + return negotiationInfo == null ? string.Empty : negotiationInfo.AuthenticationPackage; + } + return _protocolName; + } + } + + internal SecSizes Sizes + { + get + { + GlobalLog.Assert(IsCompleted && IsValidContext, "NTAuthentication#{0}::MaxDataSize|The context is not completed or invalid.", ValidationHelper.HashString(this)); + if (_sizes == null) + { + _sizes = SSPIWrapper.QueryContextAttributes( + SSPIAuth, + _securityContext, + ContextAttribute.Sizes) as SecSizes; + } + return _sizes; + } + } + + internal ChannelBinding ChannelBinding + { + get { return _channelBinding; } + } + + private static void InitializeCallback(object state) + { + InitializeCallbackContext context = (InitializeCallbackContext)state; + context.ThisPtr.Initialize(context.IsServer, context.Package, context.Credential, context.Spn, context.RequestedContextFlags, context.ChannelBinding); + } + + private void Initialize(bool isServer, string package, NetworkCredential credential, string spn, ContextFlags requestedContextFlags, ChannelBinding channelBinding) + { + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::.ctor() package:" + ValidationHelper.ToString(package) + " spn:" + ValidationHelper.ToString(spn) + " flags :" + requestedContextFlags.ToString()); + _tokenSize = SSPIWrapper.GetVerifyPackageInfo(SSPIAuth, package, true).MaxToken; + _isServer = isServer; + _spn = spn; + _securityContext = null; + _requestedContextFlags = requestedContextFlags; + _package = package; + _channelBinding = channelBinding; + + GlobalLog.Print("Peer SPN-> '" + _spn + "'"); + + // check if we're using DefaultCredentials + + if (credential == CredentialCache.DefaultNetworkCredentials) + { + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::.ctor(): using DefaultCredentials"); + _credentialsHandle = SSPIWrapper.AcquireDefaultCredential( + SSPIAuth, + package, + (_isServer ? CredentialUse.Inbound : CredentialUse.Outbound)); + _uniqueUserId = "/S"; // save off for unique connection marking ONLY used by HTTP client + } + else if (ComNetOS.IsWin7orLater) + { + unsafe + { + SafeSspiAuthDataHandle authData = null; + try + { + SecurityStatus result = UnsafeNclNativeMethods.SspiHelper.SspiEncodeStringsAsAuthIdentity( + credential.UserName/*InternalGetUserName()*/, credential.Domain/*InternalGetDomain()*/, + credential.Password/*InternalGetPassword()*/, out authData); + + if (result != SecurityStatus.OK) + { + if (Logging.On) + { + Logging.PrintError(Logging.Web, SR.GetString(SR.net_log_operation_failed_with_error, "SspiEncodeStringsAsAuthIdentity()", String.Format(CultureInfo.CurrentCulture, "0x{0:X}", (int)result))); + } + throw new Win32Exception((int)result); + } + + _credentialsHandle = SSPIWrapper.AcquireCredentialsHandle(SSPIAuth, + package, (_isServer ? CredentialUse.Inbound : CredentialUse.Outbound), ref authData); + } + finally + { + if (authData != null) + { + authData.Dispose(); + } + } + } + } + else + { + // we're not using DefaultCredentials, we need a + // AuthIdentity struct to contain credentials + // SECREVIEW: + // we'll save username/domain in temp strings, to avoid decrypting multiple times. + // password is only used once + + string username = credential.UserName; // InternalGetUserName(); + + string domain = credential.Domain; // InternalGetDomain(); + // ATTN: + // NetworkCredential class does not differentiate between null and "" but SSPI packages treat these cases differently + // For NTLM we want to keep "" for Wdigest.Dll we should use null. + AuthIdentity authIdentity = new AuthIdentity(username, credential.Password/*InternalGetPassword()*/, (object)package == (object)NegotiationInfoClass.WDigest && (domain == null || domain.Length == 0) ? null : domain); + + _uniqueUserId = domain + "/" + username + "/U"; // save off for unique connection marking ONLY used by HTTP client + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::.ctor(): using authIdentity:" + authIdentity.ToString()); + + _credentialsHandle = SSPIWrapper.AcquireCredentialsHandle( + SSPIAuth, + package, + (_isServer ? CredentialUse.Inbound : CredentialUse.Outbound), + ref authIdentity); + } + } + + // This will return an client token when conducted authentication on server side' + // This token can be used ofr impersanation + // We use it to create a WindowsIdentity and hand it out to the server app. + internal SafeCloseHandle GetContextToken(out SecurityStatus status) + { + GlobalLog.Assert(IsCompleted && IsValidContext, "NTAuthentication#{0}::GetContextToken|Should be called only when completed with success, currently is not!", ValidationHelper.HashString(this)); + GlobalLog.Assert(IsServer, "NTAuthentication#{0}::GetContextToken|The method must not be called by the client side!", ValidationHelper.HashString(this)); + + if (!IsValidContext) + { + throw new Win32Exception((int)SecurityStatus.InvalidHandle); + } + + SafeCloseHandle token = null; + status = (SecurityStatus)SSPIWrapper.QuerySecurityContextToken( + SSPIAuth, + _securityContext, + out token); + + return token; + } + + internal SafeCloseHandle GetContextToken() + { + SecurityStatus status; + SafeCloseHandle token = GetContextToken(out status); + if (status != SecurityStatus.OK) + { + throw new Win32Exception((int)status); + } + return token; + } + + internal void CloseContext() + { + if (_securityContext != null && !_securityContext.IsClosed) + { + _securityContext.Dispose(); + } + } + + // NTAuth::GetOutgoingBlob() + // Created: 12-01-1999: L.M. + // Description: + // Accepts an incoming binary security blob and returns + // an outgoing binary security blob + internal byte[] GetOutgoingBlob(byte[] incomingBlob, bool throwOnError, out SecurityStatus statusCode) + { + GlobalLog.Enter("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob", ((incomingBlob == null) ? "0" : incomingBlob.Length.ToString(NumberFormatInfo.InvariantInfo)) + " bytes"); + + List list = new List(2); + + if (incomingBlob != null) + { + list.Add(new SecurityBuffer(incomingBlob, BufferType.Token)); + } + if (_channelBinding != null) + { + list.Add(new SecurityBuffer(_channelBinding)); + } + + SecurityBuffer[] inSecurityBufferArray = null; + if (list.Count > 0) + { + inSecurityBufferArray = list.ToArray(); + } + + SecurityBuffer outSecurityBuffer = new SecurityBuffer(_tokenSize, BufferType.Token); + + bool firstTime = _securityContext == null; + try + { + if (!_isServer) + { + // client session + statusCode = (SecurityStatus)SSPIWrapper.InitializeSecurityContext( + SSPIAuth, + _credentialsHandle, + ref _securityContext, + _spn, + _requestedContextFlags, + Endianness.Network, + inSecurityBufferArray, + outSecurityBuffer, + ref _contextFlags); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob() SSPIWrapper.InitializeSecurityContext() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + + if (statusCode == SecurityStatus.CompleteNeeded) + { + SecurityBuffer[] inSecurityBuffers = new SecurityBuffer[1]; + inSecurityBuffers[0] = outSecurityBuffer; + + statusCode = (SecurityStatus)SSPIWrapper.CompleteAuthToken( + SSPIAuth, + ref _securityContext, + inSecurityBuffers); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.CompleteAuthToken() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + outSecurityBuffer.token = null; + } + } + else + { + // server session + statusCode = (SecurityStatus)SSPIWrapper.AcceptSecurityContext( + SSPIAuth, + _credentialsHandle, + ref _securityContext, + _requestedContextFlags, + Endianness.Network, + inSecurityBufferArray, + outSecurityBuffer, + ref _contextFlags); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob() SSPIWrapper.AcceptSecurityContext() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + } + } + finally + { + // Assuming the ISC or ASC has referenced the credential on the first successful call, + // we want to decrement the effective ref count by "disposing" it. + // The real dispose will happen when the security context is closed. + // Note if the first call was not successfull the handle is physically destroyed here + + if (firstTime && _credentialsHandle != null) + { + _credentialsHandle.Dispose(); + } + } + + if (((int)statusCode & unchecked((int)0x80000000)) != 0) + { + CloseContext(); + _isCompleted = true; + if (throwOnError) + { + Win32Exception exception = new Win32Exception((int)statusCode); + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob", "Win32Exception:" + exception); + throw exception; + } + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob", "null statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + return null; + } + else if (firstTime && _credentialsHandle != null) + { + // cache until it is pushed out by newly incoming handles + SSPIHandleCache.CacheCredential(_credentialsHandle); + } + + // the return value from SSPI will tell us correctly if the + // handshake is over or not: http://msdn.microsoft.com/library/psdk/secspi/sspiref_67p0.htm + // we also have to consider the case in which SSPI formed a new context, in this case we're done as well. + if (statusCode == SecurityStatus.OK) + { + // we're sucessfully done + GlobalLog.Assert(statusCode == SecurityStatus.OK, "NTAuthentication#{0}::GetOutgoingBlob()|statusCode:[0x{1:x8}] ({2}) m_SecurityContext#{3}::Handle:[{4}] [STATUS != OK]", ValidationHelper.HashString(this), (int)statusCode, statusCode, ValidationHelper.HashString(_securityContext), ValidationHelper.ToString(_securityContext)); + _isCompleted = true; + } + else + { + // we need to continue + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob() need continue statusCode:[0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + "] (" + statusCode.ToString() + ") m_SecurityContext#" + ValidationHelper.HashString(_securityContext) + "::Handle:" + ValidationHelper.ToString(_securityContext) + "]"); + } + // GlobalLog.Print("out token = " + outSecurityBuffer.ToString()); + // GlobalLog.Dump(outSecurityBuffer.token); + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingBlob", "IsCompleted:" + IsCompleted.ToString()); + return outSecurityBuffer.token; + } + + // for Server side (IIS 6.0) see: \\netindex\Sources\inetsrv\iis\iisrearc\iisplus\ulw3\digestprovider.cxx + // for Client side (HTTP.SYS) see: \\netindex\Sources\net\http\sys\ucauth.c + internal string GetOutgoingDigestBlob(string incomingBlob, string requestMethod, string requestedUri, string realm, bool isClientPreAuth, bool throwOnError, out SecurityStatus statusCode) + { + GlobalLog.Enter("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob", incomingBlob); + + // second time call with 3 incoming buffers to select HTTP client. + // we should get back a SecurityStatus.OK and a non null outgoingBlob. + SecurityBuffer[] inSecurityBuffers = null; + SecurityBuffer outSecurityBuffer = new SecurityBuffer(_tokenSize, isClientPreAuth ? BufferType.Parameters : BufferType.Token); + + bool firstTime = _securityContext == null; + try + { + if (!_isServer) + { + // client session + + if (!isClientPreAuth) + { + if (incomingBlob != null) + { + List list = new List(5); + + list.Add(new SecurityBuffer(HeaderEncoding.GetBytes(incomingBlob), BufferType.Token)); + list.Add(new SecurityBuffer(HeaderEncoding.GetBytes(requestMethod), BufferType.Parameters)); + list.Add(new SecurityBuffer(null, BufferType.Parameters)); + list.Add(new SecurityBuffer(Encoding.Unicode.GetBytes(_spn), BufferType.TargetHost)); + + if (_channelBinding != null) + { + list.Add(new SecurityBuffer(_channelBinding)); + } + + inSecurityBuffers = list.ToArray(); + } + + statusCode = (SecurityStatus)SSPIWrapper.InitializeSecurityContext( + SSPIAuth, + _credentialsHandle, + ref _securityContext, + requestedUri, // this must match the Uri in the HTTP status line for the current request + _requestedContextFlags, + Endianness.Network, + inSecurityBuffers, + outSecurityBuffer, + ref _contextFlags); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.InitializeSecurityContext() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + } + else + { +#if WDIGEST_PREAUTH + inSecurityBuffers = new SecurityBuffer[] { + new SecurityBuffer(null, BufferType.Token), + new SecurityBuffer(WebHeaderCollection.HeaderEncoding.GetBytes(requestMethod), BufferType.Parameters), + new SecurityBuffer(WebHeaderCollection.HeaderEncoding.GetBytes(requestedUri), BufferType.Parameters), + new SecurityBuffer(null, BufferType.Parameters), + outSecurityBuffer, + }; + + statusCode = (SecurityStatus) SSPIWrapper.MakeSignature(GlobalSSPI.SSPIAuth, m_SecurityContext, inSecurityBuffers, 0); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.MakeSignature() returns statusCode:0x" + ((int) statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); +#else + statusCode = SecurityStatus.OK; + GlobalLog.Assert("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob()", "Invalid code path."); +#endif + } + } + else + { + // server session + List list = new List(6); + + list.Add(incomingBlob == null ? new SecurityBuffer(0, BufferType.Token) : new SecurityBuffer(HeaderEncoding.GetBytes(incomingBlob), BufferType.Token)); + list.Add(requestMethod == null ? new SecurityBuffer(0, BufferType.Parameters) : new SecurityBuffer(HeaderEncoding.GetBytes(requestMethod), BufferType.Parameters)); + list.Add(requestedUri == null ? new SecurityBuffer(0, BufferType.Parameters) : new SecurityBuffer(HeaderEncoding.GetBytes(requestedUri), BufferType.Parameters)); + list.Add(new SecurityBuffer(0, BufferType.Parameters)); + list.Add(realm == null ? new SecurityBuffer(0, BufferType.Parameters) : new SecurityBuffer(Encoding.Unicode.GetBytes(realm), BufferType.Parameters)); + + if (_channelBinding != null) + { + list.Add(new SecurityBuffer(_channelBinding)); + } + + inSecurityBuffers = list.ToArray(); + + statusCode = (SecurityStatus)SSPIWrapper.AcceptSecurityContext( + SSPIAuth, + _credentialsHandle, + ref _securityContext, + _requestedContextFlags, + Endianness.Network, + inSecurityBuffers, + outSecurityBuffer, + ref _contextFlags); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.AcceptSecurityContext() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + + if (statusCode == SecurityStatus.CompleteNeeded) + { + inSecurityBuffers[4] = outSecurityBuffer; + + statusCode = (SecurityStatus)SSPIWrapper.CompleteAuthToken( + SSPIAuth, + ref _securityContext, + inSecurityBuffers); + + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() SSPIWrapper.CompleteAuthToken() returns statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + + outSecurityBuffer.token = null; + } + } + } + finally + { + // Assuming the ISC or ASC has referenced the credential on the first successful call, + // we want to decrement the effective ref count by "disposing" it. + // The real dispose will happen when the security context is closed. + // Note if the first call was not successfull the handle is physically destroyed here + + if (firstTime && _credentialsHandle != null) + { + _credentialsHandle.Dispose(); + } + } + + if (((int)statusCode & unchecked((int)0x80000000)) != 0) + { + CloseContext(); + if (throwOnError) + { + Win32Exception exception = new Win32Exception((int)statusCode); + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob", "Win32Exception:" + exception); + throw exception; + } + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob", "null statusCode:0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + " (" + statusCode.ToString() + ")"); + return null; + } + else if (firstTime && _credentialsHandle != null) + { + // cache until it is pushed out by newly incoming handles + SSPIHandleCache.CacheCredential(_credentialsHandle); + } + + // the return value from SSPI will tell us correctly if the + // handshake is over or not: http://msdn.microsoft.com/library/psdk/secspi/sspiref_67p0.htm + if (statusCode == SecurityStatus.OK) + { + // we're done, cleanup + _isCompleted = true; + } + else + { + // we need to continue + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() need continue statusCode:[0x" + ((int)statusCode).ToString("x8", NumberFormatInfo.InvariantInfo) + "] (" + statusCode.ToString() + ") m_SecurityContext#" + ValidationHelper.HashString(_securityContext) + "::Handle:" + ValidationHelper.ToString(_securityContext) + "]"); + } + GlobalLog.Print("out token = " + outSecurityBuffer.ToString()); + GlobalLog.Dump(outSecurityBuffer.token); + GlobalLog.Print("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob() IsCompleted:" + IsCompleted.ToString()); + + byte[] decodedOutgoingBlob = outSecurityBuffer.token; + string outgoingBlob = null; + if (decodedOutgoingBlob != null && decodedOutgoingBlob.Length > 0) + { + outgoingBlob = HeaderEncoding.GetString(decodedOutgoingBlob, 0, outSecurityBuffer.size); + } + GlobalLog.Leave("NTAuthentication#" + ValidationHelper.HashString(this) + "::GetOutgoingDigestBlob", outgoingBlob); + return outgoingBlob; + } + + private string GetClientSpecifiedSpn() + { + GlobalLog.Assert(IsValidContext && IsCompleted, "NTAuthentication: Trying to get the client SPN before handshaking is done!"); + + string spn = SSPIWrapper.QueryContextAttributes(SSPIAuth, _securityContext, + ContextAttribute.ClientSpecifiedSpn) as string; + + GlobalLog.Print("NTAuthentication: The client specified SPN is [" + spn + "]"); + return spn; + } + + private class InitializeCallbackContext + { + internal readonly NTAuthentication ThisPtr; + internal readonly bool IsServer; + internal readonly string Package; + internal readonly NetworkCredential Credential; + internal readonly string Spn; + internal readonly ContextFlags RequestedContextFlags; + internal readonly ChannelBinding ChannelBinding; + + internal InitializeCallbackContext(NTAuthentication thisPtr, bool isServer, string package, NetworkCredential credential, string spn, ContextFlags requestedContextFlags, ChannelBinding channelBinding) + { + this.ThisPtr = thisPtr; + this.IsServer = isServer; + this.Package = package; + this.Credential = credential; + this.Spn = spn; + this.RequestedContextFlags = requestedContextFlags; + this.ChannelBinding = channelBinding; + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/AuthIdentity.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/AuthIdentity.cs new file mode 100644 index 0000000000..dcd73516a9 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/AuthIdentity.cs @@ -0,0 +1,40 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Auto)] + internal struct AuthIdentity + { + // see SEC_WINNT_AUTH_IDENTITY_W + internal string UserName; + internal int UserNameLength; + internal string Domain; + internal int DomainLength; + internal string Password; + internal int PasswordLength; + internal int Flags; + + internal AuthIdentity(string userName, string password, string domain) + { + UserName = userName; + UserNameLength = userName == null ? 0 : userName.Length; + Password = password; + PasswordLength = password == null ? 0 : password.Length; + Domain = domain; + DomainLength = domain == null ? 0 : domain.Length; + // Flags are 2 for Unicode and 1 for ANSI. We use 2 on NT and 1 on Win9x. + Flags = 2; + } + + public override string ToString() + { + return ValidationHelper.ToString(Domain) + "\\" + ValidationHelper.ToString(UserName); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/ContextFlags.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/ContextFlags.cs new file mode 100644 index 0000000000..21ef3b735d --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/ContextFlags.cs @@ -0,0 +1,124 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Security.Windows +{ + // #define ISC_REQ_DELEGATE 0x00000001 + // #define ISC_REQ_MUTUAL_AUTH 0x00000002 + // #define ISC_REQ_REPLAY_DETECT 0x00000004 + // #define ISC_REQ_SEQUENCE_DETECT 0x00000008 + // #define ISC_REQ_CONFIDENTIALITY 0x00000010 + // #define ISC_REQ_USE_SESSION_KEY 0x00000020 + // #define ISC_REQ_PROMPT_FOR_CREDS 0x00000040 + // #define ISC_REQ_USE_SUPPLIED_CREDS 0x00000080 + // #define ISC_REQ_ALLOCATE_MEMORY 0x00000100 + // #define ISC_REQ_USE_DCE_STYLE 0x00000200 + // #define ISC_REQ_DATAGRAM 0x00000400 + // #define ISC_REQ_CONNECTION 0x00000800 + // #define ISC_REQ_CALL_LEVEL 0x00001000 + // #define ISC_REQ_FRAGMENT_SUPPLIED 0x00002000 + // #define ISC_REQ_EXTENDED_ERROR 0x00004000 + // #define ISC_REQ_STREAM 0x00008000 + // #define ISC_REQ_INTEGRITY 0x00010000 + // #define ISC_REQ_IDENTIFY 0x00020000 + // #define ISC_REQ_NULL_SESSION 0x00040000 + // #define ISC_REQ_MANUAL_CRED_VALIDATION 0x00080000 + // #define ISC_REQ_RESERVED1 0x00100000 + // #define ISC_REQ_FRAGMENT_TO_FIT 0x00200000 + // #define ISC_REQ_HTTP 0x10000000 + // Win7 SP1 + + // #define ISC_REQ_UNVERIFIED_TARGET_NAME 0x20000000 + + // #define ASC_REQ_DELEGATE 0x00000001 + // #define ASC_REQ_MUTUAL_AUTH 0x00000002 + // #define ASC_REQ_REPLAY_DETECT 0x00000004 + // #define ASC_REQ_SEQUENCE_DETECT 0x00000008 + // #define ASC_REQ_CONFIDENTIALITY 0x00000010 + // #define ASC_REQ_USE_SESSION_KEY 0x00000020 + // #define ASC_REQ_ALLOCATE_MEMORY 0x00000100 + // #define ASC_REQ_USE_DCE_STYLE 0x00000200 + // #define ASC_REQ_DATAGRAM 0x00000400 + // #define ASC_REQ_CONNECTION 0x00000800 + // #define ASC_REQ_CALL_LEVEL 0x00001000 + // #define ASC_REQ_EXTENDED_ERROR 0x00008000 + // #define ASC_REQ_STREAM 0x00010000 + // #define ASC_REQ_INTEGRITY 0x00020000 + // #define ASC_REQ_LICENSING 0x00040000 + // #define ASC_REQ_IDENTIFY 0x00080000 + // #define ASC_REQ_ALLOW_NULL_SESSION 0x00100000 + // #define ASC_REQ_ALLOW_NON_USER_LOGONS 0x00200000 + // #define ASC_REQ_ALLOW_CONTEXT_REPLAY 0x00400000 + // #define ASC_REQ_FRAGMENT_TO_FIT 0x00800000 + // #define ASC_REQ_FRAGMENT_SUPPLIED 0x00002000 + // #define ASC_REQ_NO_TOKEN 0x01000000 + // #define ASC_REQ_HTTP 0x10000000 + + [Flags] + internal enum ContextFlags + { + Zero = 0, + // The server in the transport application can + // build new security contexts impersonating the + // client that will be accepted by other servers + // as the client's contexts. + Delegate = 0x00000001, + // The communicating parties must authenticate + // their identities to each other. Without MutualAuth, + // the client authenticates its identity to the server. + // With MutualAuth, the server also must authenticate + // its identity to the client. + MutualAuth = 0x00000002, + // The security package detects replayed packets and + // notifies the caller if a packet has been replayed. + // The use of this flag implies all of the conditions + // specified by the Integrity flag. + ReplayDetect = 0x00000004, + // The context must be allowed to detect out-of-order + // delivery of packets later through the message support + // functions. Use of this flag implies all of the + // conditions specified by the Integrity flag. + SequenceDetect = 0x00000008, + // The context must protect data while in transit. + // Confidentiality is supported for NTLM with Microsoft + // Windows NT version 4.0, SP4 and later and with the + // Kerberos protocol in Microsoft Windows 2000 and later. + Confidentiality = 0x00000010, + UseSessionKey = 0x00000020, + AllocateMemory = 0x00000100, + + // Connection semantics must be used. + Connection = 0x00000800, + + // Client applications requiring extended error messages specify the + // ISC_REQ_EXTENDED_ERROR flag when calling the InitializeSecurityContext + // Server applications requiring extended error messages set + // the ASC_REQ_EXTENDED_ERROR flag when calling AcceptSecurityContext. + InitExtendedError = 0x00004000, + AcceptExtendedError = 0x00008000, + // A transport application requests stream semantics + // by setting the ISC_REQ_STREAM and ASC_REQ_STREAM + // flags in the calls to the InitializeSecurityContext + // and AcceptSecurityContext functions + InitStream = 0x00008000, + AcceptStream = 0x00010000, + // Buffer integrity can be verified; however, replayed + // and out-of-sequence messages will not be detected + InitIntegrity = 0x00010000, // ISC_REQ_INTEGRITY + AcceptIntegrity = 0x00020000, // ASC_REQ_INTEGRITY + + InitManualCredValidation = 0x00080000, // ISC_REQ_MANUAL_CRED_VALIDATION + InitUseSuppliedCreds = 0x00000080, // ISC_REQ_USE_SUPPLIED_CREDS + InitIdentify = 0x00020000, // ISC_REQ_IDENTIFY + AcceptIdentify = 0x00080000, // ASC_REQ_IDENTIFY + + ProxyBindings = 0x04000000, // ASC_REQ_PROXY_BINDINGS + AllowMissingBindings = 0x10000000, // ASC_REQ_ALLOW_MISSING_BINDINGS + + UnverifiedTargetName = 0x20000000, // ISC_REQ_UNVERIFIED_TARGET_NAME + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/NativeSSPI.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/NativeSSPI.cs new file mode 100644 index 0000000000..1ac0a7d055 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/NativeSSPI.cs @@ -0,0 +1,42 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + // used to define the interface for security to use. + internal interface ISSPIInterface + { + SecurityPackageInfoClass[] SecurityPackages { get; set; } + int EnumerateSecurityPackages(out int pkgnum, out SafeFreeContextBuffer pkgArray); + int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref AuthIdentity authdata, out SafeFreeCredentials outCredential); + int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref SafeSspiAuthDataHandle authdata, out SafeFreeCredentials outCredential); + int AcquireDefaultCredential(string moduleName, CredentialUse usage, out SafeFreeCredentials outCredential); + int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref SecureCredential authdata, out SafeFreeCredentials outCredential); + int AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer inputBuffer, ContextFlags inFlags, + Endianness endianness, SecurityBuffer outputBuffer, ref ContextFlags outFlags); + int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer[] inputBuffers, ContextFlags inFlags, + Endianness endianness, SecurityBuffer outputBuffer, ref ContextFlags outFlags); + int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, + Endianness endianness, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, ref ContextFlags outFlags); + int InitializeSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, + Endianness endianness, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, ref ContextFlags outFlags); + int EncryptMessage(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber); + int DecryptMessage(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber); + int MakeSignature(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber); + int VerifySignature(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber); + + int QueryContextChannelBinding(SafeDeleteContext phContext, ContextAttribute attribute, out SafeFreeContextBufferChannelBinding refHandle); + int QueryContextAttributes(SafeDeleteContext phContext, ContextAttribute attribute, byte[] buffer, Type handleType, out SafeHandle refHandle); + int SetContextAttributes(SafeDeleteContext phContext, ContextAttribute attribute, byte[] buffer); + int QuerySecurityContextToken(SafeDeleteContext phContext, out SafeCloseHandle phToken); + int CompleteAuthToken(ref SafeDeleteContext refContext, SecurityBuffer[] inputBuffers); + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIAuthType.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIAuthType.cs new file mode 100644 index 0000000000..e7951e2969 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIAuthType.cs @@ -0,0 +1,302 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class SSPIAuthType : ISSPIInterface + { + private static volatile SecurityPackageInfoClass[] _securityPackages; + + public SecurityPackageInfoClass[] SecurityPackages + { + get + { + return _securityPackages; + } + set + { + _securityPackages = value; + } + } + + public int EnumerateSecurityPackages(out int pkgnum, out SafeFreeContextBuffer pkgArray) + { + GlobalLog.Print("SSPIAuthType::EnumerateSecurityPackages()"); + return SafeFreeContextBuffer.EnumeratePackages(out pkgnum, out pkgArray); + } + + public int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref AuthIdentity authdata, out SafeFreeCredentials outCredential) + { + return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential); + } + + public int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref SafeSspiAuthDataHandle authdata, out SafeFreeCredentials outCredential) + { + return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential); + } + + public int AcquireDefaultCredential(string moduleName, CredentialUse usage, out SafeFreeCredentials outCredential) + { + return SafeFreeCredentials.AcquireDefaultCredential(moduleName, usage, out outCredential); + } + + public int AcquireCredentialsHandle(string moduleName, CredentialUse usage, ref SecureCredential authdata, out SafeFreeCredentials outCredential) + { + return SafeFreeCredentials.AcquireCredentialsHandle(moduleName, usage, ref authdata, out outCredential); + } + + public int AcceptSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer inputBuffer, ContextFlags inFlags, Endianness endianness, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, inputBuffer, null, outputBuffer, ref outFlags); + } + + public int AcceptSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, SecurityBuffer[] inputBuffers, ContextFlags inFlags, Endianness endianness, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + return SafeDeleteContext.AcceptSecurityContext(ref credential, ref context, inFlags, endianness, null, inputBuffers, outputBuffer, ref outFlags); + } + + public int InitializeSecurityContext(ref SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, Endianness endianness, SecurityBuffer inputBuffer, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, inputBuffer, null, outputBuffer, ref outFlags); + } + + public int InitializeSecurityContext(SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, Endianness endianness, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + return SafeDeleteContext.InitializeSecurityContext(ref credential, ref context, targetName, inFlags, endianness, null, inputBuffers, outputBuffer, ref outFlags); + } + + public int EncryptMessage(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + context.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + context.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.NativeNTSSPI.EncryptMessage(ref context._handle, 0, inputOutput, sequenceNumber); + context.DangerousRelease(); + } + } + return status; + } + + public unsafe int DecryptMessage(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + uint qop = 0; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + context.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + context.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.NativeNTSSPI.DecryptMessage(ref context._handle, inputOutput, sequenceNumber, &qop); + context.DangerousRelease(); + } + } + + const uint SECQOP_WRAP_NO_ENCRYPT = 0x80000001; + if (status == 0 && qop == SECQOP_WRAP_NO_ENCRYPT) + { + GlobalLog.Assert("NativeNTSSPI.DecryptMessage", "Expected qop = 0, returned value = " + qop.ToString("x", CultureInfo.InvariantCulture)); + throw new InvalidOperationException(SR.GetString(SR.net_auth_message_not_encrypted)); + } + + return status; + } + + public int MakeSignature(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + context.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + context.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + const uint SECQOP_WRAP_NO_ENCRYPT = 0x80000001; + status = UnsafeNclNativeMethods.NativeNTSSPI.EncryptMessage(ref context._handle, SECQOP_WRAP_NO_ENCRYPT, inputOutput, sequenceNumber); + context.DangerousRelease(); + } + } + return status; + } + + public unsafe int VerifySignature(SafeDeleteContext context, SecurityBufferDescriptor inputOutput, uint sequenceNumber) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + uint qop = 0; + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + context.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + context.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.NativeNTSSPI.DecryptMessage(ref context._handle, inputOutput, sequenceNumber, &qop); + context.DangerousRelease(); + } + } + + return status; + } + + public int QueryContextChannelBinding(SafeDeleteContext context, ContextAttribute attribute, out SafeFreeContextBufferChannelBinding binding) + { + // Querying an auth SSP for a CBT doesn't make sense + binding = null; + throw new NotSupportedException(); + } + + public unsafe int QueryContextAttributes(SafeDeleteContext context, ContextAttribute attribute, byte[] buffer, Type handleType, out SafeHandle refHandle) + { + refHandle = null; + if (handleType != null) + { + if (handleType == typeof(SafeFreeContextBuffer)) + { + refHandle = SafeFreeContextBuffer.CreateEmptyHandle(); + } + else if (handleType == typeof(SafeFreeCertContext)) + { + refHandle = new SafeFreeCertContext(); + } + else + { + throw new ArgumentException(SR.GetString(SR.SSPIInvalidHandleType, handleType.FullName), "handleType"); + } + } + + fixed (byte* bufferPtr = buffer) + { + return SafeFreeContextBuffer.QueryContextAttributes(context, attribute, bufferPtr, refHandle); + } + } + + public int SetContextAttributes(SafeDeleteContext context, ContextAttribute attribute, byte[] buffer) + { + throw new NotImplementedException(); + } + + public int QuerySecurityContextToken(SafeDeleteContext phContext, out SafeCloseHandle phToken) + { + return GetSecurityContextToken(phContext, out phToken); + } + + public int CompleteAuthToken(ref SafeDeleteContext refContext, SecurityBuffer[] inputBuffers) + { + return SafeDeleteContext.CompleteAuthToken(ref refContext, inputBuffers); + } + + private static int GetSecurityContextToken(SafeDeleteContext phContext, out SafeCloseHandle safeHandle) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + safeHandle = null; + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + phContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + phContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.SafeNetHandles.QuerySecurityContextToken(ref phContext._handle, out safeHandle); + phContext.DangerousRelease(); + } + } + + return status; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIHandle.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIHandle.cs new file mode 100644 index 0000000000..52728d3825 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIHandle.cs @@ -0,0 +1,36 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential, Pack = 1)] + internal struct SSPIHandle + { + private IntPtr HandleHi; + private IntPtr HandleLo; + + public bool IsZero + { + get { return HandleHi == IntPtr.Zero && HandleLo == IntPtr.Zero; } + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal void SetToInvalid() + { + HandleHi = IntPtr.Zero; + HandleLo = IntPtr.Zero; + } + + public override string ToString() + { + return HandleHi.ToString("x") + ":" + HandleLo.ToString("x"); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPISessionCache.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPISessionCache.cs new file mode 100644 index 0000000000..e30c1ab434 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPISessionCache.cs @@ -0,0 +1,49 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +/* +Abstract: + The file implements trivial SSPI credential caching mechanism based on lru list +*/ +using System; +using System.Threading; + +namespace Microsoft.AspNet.Security.Windows +{ + // Implements delayed SSPI handle release, like a finalizable object though the handles are kept alive until being pushed out + // by the newly incoming ones. + internal static class SSPIHandleCache + { + private const int MaxCacheSize = 0x1F; // must a (power of 2) - 1 + private static SafeCredentialReference[] _cacheSlots = new SafeCredentialReference[MaxCacheSize + 1]; + private static int _current = -1; + + internal static void CacheCredential(SafeFreeCredentials newHandle) + { + try + { + SafeCredentialReference newRef = SafeCredentialReference.CreateReference(newHandle); + if (newRef == null) + { + return; + } + unchecked + { + int index = Interlocked.Increment(ref _current) & MaxCacheSize; + newRef = Interlocked.Exchange(ref _cacheSlots[index], newRef); + } + if (newRef != null) + { + newRef.Dispose(); + } + } + catch (Exception e) + { + GlobalLog.Assert("SSPIHandlCache", "Attempted to throw: " + e.ToString()); + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIWrapper.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIWrapper.cs new file mode 100644 index 0000000000..4c04e704b2 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SSPIWrapper.cs @@ -0,0 +1,462 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.ComponentModel; +using System.Diagnostics; +using System.Globalization; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + // From Schannel.h + + [StructLayout(LayoutKind.Sequential)] + internal struct NegotiationInfo + { + // see SecPkgContext_NegotiationInfoW in + + // [MarshalAs(UnmanagedType.LPStruct)] internal SecurityPackageInfo PackageInfo; + internal IntPtr PackageInfo; + internal uint NegotiationState; + internal static readonly int Size = Marshal.SizeOf(typeof(NegotiationInfo)); + internal static readonly int NegotiationStateOffest = (int)Marshal.OffsetOf(typeof(NegotiationInfo), "NegotiationState"); + } + + // we keep it simple since we use this only to know if NTLM or + // Kerberos are used in the context of a Negotiate handshake + + [StructLayout(LayoutKind.Sequential)] + internal struct SecurityPackageInfo + { + // see SecPkgInfoW in + internal int Capabilities; + internal short Version; + internal short RPCID; + internal int MaxToken; + internal IntPtr Name; + internal IntPtr Comment; + + internal static readonly int Size = Marshal.SizeOf(typeof(SecurityPackageInfo)); + internal static readonly int NameOffest = (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Name"); + } + + [StructLayout(LayoutKind.Sequential)] + internal struct Bindings + { + // see SecPkgContext_Bindings in + internal int BindingsLength; + internal IntPtr pBindings; + } + + internal static class SSPIWrapper + { + internal static SecurityPackageInfoClass[] EnumerateSecurityPackages(ISSPIInterface secModule) + { + GlobalLog.Enter("EnumerateSecurityPackages"); + if (secModule.SecurityPackages == null) + { + lock (secModule) + { + if (secModule.SecurityPackages == null) + { + int moduleCount = 0; + SafeFreeContextBuffer arrayBaseHandle = null; + try + { + int errorCode = secModule.EnumerateSecurityPackages(out moduleCount, out arrayBaseHandle); + GlobalLog.Print("SSPIWrapper::arrayBase: " + (arrayBaseHandle.DangerousGetHandle().ToString("x"))); + if (errorCode != 0) + { + throw new Win32Exception(errorCode); + } + SecurityPackageInfoClass[] securityPackages = new SecurityPackageInfoClass[moduleCount]; + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_sspi_enumerating_security_packages)); + } + int i; + for (i = 0; i < moduleCount; i++) + { + securityPackages[i] = new SecurityPackageInfoClass(arrayBaseHandle, i); + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, " " + securityPackages[i].Name); + } + } + secModule.SecurityPackages = securityPackages; + } + finally + { + if (arrayBaseHandle != null) + { + arrayBaseHandle.Dispose(); + } + } + } + } + } + GlobalLog.Leave("EnumerateSecurityPackages"); + return secModule.SecurityPackages; + } + + internal static SecurityPackageInfoClass GetVerifyPackageInfo(ISSPIInterface secModule, string packageName, bool throwIfMissing) + { + SecurityPackageInfoClass[] supportedSecurityPackages = EnumerateSecurityPackages(secModule); + if (supportedSecurityPackages != null) + { + for (int i = 0; i < supportedSecurityPackages.Length; i++) + { + if (string.Compare(supportedSecurityPackages[i].Name, packageName, StringComparison.OrdinalIgnoreCase) == 0) + { + return supportedSecurityPackages[i]; + } + } + } + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_sspi_security_package_not_found, packageName)); + } + + // error + if (throwIfMissing) + { + throw new NotSupportedException(SR.GetString(SR.net_securitypackagesupport)); + } + + return null; + } + + public static SafeFreeCredentials AcquireDefaultCredential(ISSPIInterface secModule, string package, CredentialUse intent) + { + GlobalLog.Print("SSPIWrapper::AcquireDefaultCredential(): using " + package); + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "AcquireDefaultCredential(" + + "package = " + package + ", " + + "intent = " + intent + ")"); + } + + SafeFreeCredentials outCredential = null; + int errorCode = secModule.AcquireDefaultCredential(package, intent, out outCredential); + + if (errorCode != 0) + { +#if TRAVE + GlobalLog.Print("SSPIWrapper::AcquireDefaultCredential(): error " + SecureChannel.MapSecurityStatus((uint)errorCode)); +#endif + if (Logging.On) + { + Logging.PrintError(Logging.Web, SR.GetString(SR.net_log_operation_failed_with_error, "AcquireDefaultCredential()", String.Format(CultureInfo.CurrentCulture, "0X{0:X}", errorCode))); + } + throw new Win32Exception(errorCode); + } + return outCredential; + } + + public static SafeFreeCredentials AcquireCredentialsHandle(ISSPIInterface secModule, string package, CredentialUse intent, ref AuthIdentity authdata) + { + GlobalLog.Print("SSPIWrapper::AcquireCredentialsHandle#2(): using " + package); + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "AcquireCredentialsHandle(" + + "package = " + package + ", " + + "intent = " + intent + ", " + + "authdata = " + authdata + ")"); + } + + SafeFreeCredentials credentialsHandle = null; + int errorCode = secModule.AcquireCredentialsHandle(package, + intent, + ref authdata, + out credentialsHandle); + + if (errorCode != 0) + { + if (Logging.On) + { + Logging.PrintError(Logging.Web, SR.GetString(SR.net_log_operation_failed_with_error, "AcquireCredentialsHandle()", String.Format(CultureInfo.CurrentCulture, "0X{0:X}", errorCode))); + } + throw new Win32Exception(errorCode); + } + return credentialsHandle; + } + + public static SafeFreeCredentials AcquireCredentialsHandle(ISSPIInterface secModule, string package, CredentialUse intent, ref SafeSspiAuthDataHandle authdata) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "AcquireCredentialsHandle(" + + "package = " + package + ", " + + "intent = " + intent + ", " + + "authdata = " + authdata + ")"); + } + + SafeFreeCredentials credentialsHandle = null; + int errorCode = secModule.AcquireCredentialsHandle(package, intent, ref authdata, out credentialsHandle); + + if (errorCode != 0) + { + if (Logging.On) + { + Logging.PrintError(Logging.Web, SR.GetString(SR.net_log_operation_failed_with_error, "AcquireCredentialsHandle()", String.Format(CultureInfo.CurrentCulture, "0X{0:X}", errorCode))); + } + throw new Win32Exception(errorCode); + } + return credentialsHandle; + } + + internal static int InitializeSecurityContext(ISSPIInterface secModule, SafeFreeCredentials credential, ref SafeDeleteContext context, string targetName, ContextFlags inFlags, Endianness datarep, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "InitializeSecurityContext(" + + "credential = " + credential.ToString() + ", " + + "context = " + ValidationHelper.ToString(context) + ", " + + "targetName = " + targetName + ", " + + "inFlags = " + inFlags + ")"); + } + + int errorCode = secModule.InitializeSecurityContext(credential, ref context, targetName, inFlags, datarep, inputBuffers, outputBuffer, ref outFlags); + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_sspi_security_context_input_buffers, "InitializeSecurityContext", (inputBuffers == null ? 0 : inputBuffers.Length), outputBuffer.size, (SecurityStatus)errorCode)); + } + + return errorCode; + } + + internal static int AcceptSecurityContext(ISSPIInterface secModule, SafeFreeCredentials credential, ref SafeDeleteContext context, ContextFlags inFlags, Endianness datarep, SecurityBuffer[] inputBuffers, SecurityBuffer outputBuffer, ref ContextFlags outFlags) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, + "AcceptSecurityContext(" + + "credential = " + credential.ToString() + ", " + + "context = " + ValidationHelper.ToString(context) + ", " + + "inFlags = " + inFlags + ")"); + } + + int errorCode = secModule.AcceptSecurityContext(credential, ref context, inputBuffers, inFlags, datarep, outputBuffer, ref outFlags); + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_sspi_security_context_input_buffers, "AcceptSecurityContext", (inputBuffers == null ? 0 : inputBuffers.Length), outputBuffer.size, (SecurityStatus)errorCode)); + } + + return errorCode; + } + + internal static int CompleteAuthToken(ISSPIInterface secModule, ref SafeDeleteContext context, SecurityBuffer[] inputBuffers) + { + int errorCode = secModule.CompleteAuthToken(ref context, inputBuffers); + + if (Logging.On) + { + Logging.PrintInfo(Logging.Web, SR.GetString(SR.net_log_operation_returned_something, "CompleteAuthToken()", (SecurityStatus)errorCode)); + } + + return errorCode; + } + + public static int QuerySecurityContextToken(ISSPIInterface secModule, SafeDeleteContext context, out SafeCloseHandle token) + { + return secModule.QuerySecurityContextToken(context, out token); + } + + public static SafeFreeContextBufferChannelBinding QueryContextChannelBinding(ISSPIInterface secModule, SafeDeleteContext securityContext, ContextAttribute contextAttribute) + { + GlobalLog.Enter("QueryContextChannelBinding", contextAttribute.ToString()); + + SafeFreeContextBufferChannelBinding result; + int errorCode = secModule.QueryContextChannelBinding(securityContext, contextAttribute, out result); + if (errorCode != 0) + { + GlobalLog.Leave("QueryContextChannelBinding", "ERROR = " + ErrorDescription(errorCode)); + return null; + } + + GlobalLog.Leave("QueryContextChannelBinding", ValidationHelper.HashString(result)); + return result; + } + + public static object QueryContextAttributes(ISSPIInterface secModule, SafeDeleteContext securityContext, ContextAttribute contextAttribute) + { + int errorCode; + return QueryContextAttributes(secModule, securityContext, contextAttribute, out errorCode); + } + + public static object QueryContextAttributes(ISSPIInterface secModule, SafeDeleteContext securityContext, ContextAttribute contextAttribute, out int errorCode) + { + GlobalLog.Enter("QueryContextAttributes", contextAttribute.ToString()); + + int nativeBlockSize = IntPtr.Size; + Type handleType = null; + + switch (contextAttribute) + { + case ContextAttribute.Sizes: + nativeBlockSize = SecSizes.SizeOf; + break; + case ContextAttribute.StreamSizes: + nativeBlockSize = StreamSizes.SizeOf; + break; + + case ContextAttribute.Names: + handleType = typeof(SafeFreeContextBuffer); + break; + + case ContextAttribute.PackageInfo: + handleType = typeof(SafeFreeContextBuffer); + break; + + case ContextAttribute.NegotiationInfo: + handleType = typeof(SafeFreeContextBuffer); + nativeBlockSize = Marshal.SizeOf(typeof(NegotiationInfo)); + break; + + case ContextAttribute.ClientSpecifiedSpn: + handleType = typeof(SafeFreeContextBuffer); + break; + + case ContextAttribute.RemoteCertificate: + handleType = typeof(SafeFreeCertContext); + break; + + case ContextAttribute.LocalCertificate: + handleType = typeof(SafeFreeCertContext); + break; + + case ContextAttribute.IssuerListInfoEx: + nativeBlockSize = Marshal.SizeOf(typeof(IssuerListInfoEx)); + handleType = typeof(SafeFreeContextBuffer); + break; + + case ContextAttribute.ConnectionInfo: + nativeBlockSize = Marshal.SizeOf(typeof(SslConnectionInfo)); + break; + + default: + throw new ArgumentException(SR.GetString(SR.net_invalid_enum, "ContextAttribute"), "contextAttribute"); + } + + SafeHandle SspiHandle = null; + object attribute = null; + + try + { + byte[] nativeBuffer = new byte[nativeBlockSize]; + errorCode = secModule.QueryContextAttributes(securityContext, contextAttribute, nativeBuffer, handleType, out SspiHandle); + if (errorCode != 0) + { + GlobalLog.Leave("Win32:QueryContextAttributes", "ERROR = " + ErrorDescription(errorCode)); + return null; + } + + switch (contextAttribute) + { + case ContextAttribute.Sizes: + attribute = new SecSizes(nativeBuffer); + break; + + case ContextAttribute.StreamSizes: + attribute = new StreamSizes(nativeBuffer); + break; + + case ContextAttribute.Names: + attribute = Marshal.PtrToStringUni(SspiHandle.DangerousGetHandle()); + break; + + case ContextAttribute.PackageInfo: + attribute = new SecurityPackageInfoClass(SspiHandle, 0); + break; + + case ContextAttribute.NegotiationInfo: + unsafe + { + fixed (void* ptr = nativeBuffer) + { + attribute = new NegotiationInfoClass(SspiHandle, Marshal.ReadInt32(new IntPtr(ptr), NegotiationInfo.NegotiationStateOffest)); + } + } + break; + + case ContextAttribute.ClientSpecifiedSpn: + attribute = Marshal.PtrToStringUni(SspiHandle.DangerousGetHandle()); + break; + + case ContextAttribute.LocalCertificate: + goto case ContextAttribute.RemoteCertificate; + case ContextAttribute.RemoteCertificate: + attribute = SspiHandle; + SspiHandle = null; + break; + + case ContextAttribute.IssuerListInfoEx: + attribute = new IssuerListInfoEx(SspiHandle, nativeBuffer); + SspiHandle = null; + break; + + case ContextAttribute.ConnectionInfo: + attribute = new SslConnectionInfo(nativeBuffer); + break; + default: + // will return null + break; + } + } + finally + { + if (SspiHandle != null) + { + SspiHandle.Dispose(); + } + } + GlobalLog.Leave("QueryContextAttributes", ValidationHelper.ToString(attribute)); + return attribute; + } + + public static string ErrorDescription(int errorCode) + { + if (errorCode == -1) + { + return "An exception when invoking Win32 API"; + } + switch ((SecurityStatus)errorCode) + { + case SecurityStatus.InvalidHandle: + return "Invalid handle"; + case SecurityStatus.InvalidToken: + return "Invalid token"; + case SecurityStatus.ContinueNeeded: + return "Continue needed"; + case SecurityStatus.IncompleteMessage: + return "Message incomplete"; + case SecurityStatus.WrongPrincipal: + return "Wrong principal"; + case SecurityStatus.TargetUnknown: + return "Target unknown"; + case SecurityStatus.PackageNotFound: + return "Package not found"; + case SecurityStatus.BufferNotEnough: + return "Buffer not enough"; + case SecurityStatus.MessageAltered: + return "Message altered"; + case SecurityStatus.UntrustedRoot: + return "Untrusted root"; + default: + return "0x" + errorCode.ToString("x", NumberFormatInfo.InvariantInfo); + } + } + } // class SSPIWrapper +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCloseHandle.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCloseHandle.cs new file mode 100644 index 0000000000..1b720b1ad6 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCloseHandle.cs @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.ConstrainedExecution; +using System.Security; +using System.Threading; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeCloseHandle : CriticalHandleZeroOrMinusOneIsInvalid + { + private int _disposed; + + private SafeCloseHandle() + : base() + { + } + + internal IntPtr DangerousGetHandle() + { + return handle; + } + + protected override bool ReleaseHandle() + { + if (!IsInvalid) + { + if (Interlocked.Increment(ref _disposed) == 1) + { + return UnsafeNclNativeMethods.SafeNetHandles.CloseHandle(handle); + } + } + return true; + } + + // This method will bypass refCount check done by VM + // Means it will force handle release if has a valid value + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal void Abort() + { + ReleaseHandle(); + SetHandleAsInvalid(); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCredentialReference.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCredentialReference.cs new file mode 100644 index 0000000000..a0d29a1b61 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeCredentialReference.cs @@ -0,0 +1,69 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + internal sealed class SafeCredentialReference : CriticalHandleMinusOneIsInvalid + { + // Static cache will return the target handle if found the reference in the table. + internal SafeFreeCredentials _Target; + + private SafeCredentialReference(SafeFreeCredentials target) + : base() + { + // Bumps up the refcount on Target to signify that target handle is statically cached so + // its dispose should be postponed + bool b = false; + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + target.DangerousAddRef(ref b); + } + catch + { + if (b) + { + target.DangerousRelease(); + b = false; + } + } + finally + { + if (b) + { + _Target = target; + SetHandle(new IntPtr(0)); // make this handle valid + } + } + } + + internal static SafeCredentialReference CreateReference(SafeFreeCredentials target) + { + SafeCredentialReference result = new SafeCredentialReference(target); + if (result.IsInvalid) + { + return null; + } + + return result; + } + + protected override bool ReleaseHandle() + { + SafeFreeCredentials target = _Target; + if (target != null) + { + target.DangerousRelease(); + } + _Target = null; + return true; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeDeleteContext.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeDeleteContext.cs new file mode 100644 index 0000000000..a356f2b3ef --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeDeleteContext.cs @@ -0,0 +1,707 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + internal sealed class SafeDeleteContext : SafeHandle + { + private const string DummyStr = " "; + private static readonly byte[] DummyBytes = new byte[] { 0 }; + + // ATN: _handle is internal since it is used on PInvokes by other wrapper methods. + // However all such wrappers MUST manually and reliably adjust refCounter of SafeDeleteContext handle. + + internal SSPIHandle _handle; + + private SafeFreeCredentials _effectiveCredential; + + private SafeDeleteContext() + : base(IntPtr.Zero, true) + { + _handle = new SSPIHandle(); + } + + public override bool IsInvalid + { + get + { + return IsClosed || _handle.IsZero; + } + } + + public override string ToString() + { + return _handle.ToString(); + } + + //------------------------------------------------------------------- + internal static unsafe int InitializeSecurityContext(ref SafeFreeCredentials inCredentials, ref SafeDeleteContext refContext, + string targetName, ContextFlags inFlags, Endianness endianness, SecurityBuffer inSecBuffer, SecurityBuffer[] inSecBuffers, + SecurityBuffer outSecBuffer, ref ContextFlags outFlags) + { + GlobalLog.Assert(outSecBuffer != null, "SafeDeleteContext::InitializeSecurityContext()|outSecBuffer != null"); + GlobalLog.Assert(inSecBuffer == null || inSecBuffers == null, "SafeDeleteContext::InitializeSecurityContext()|inSecBuffer == null || inSecBuffers == null"); + + if (inCredentials == null) + { + throw new ArgumentNullException("inCredentials"); + } + + SecurityBufferDescriptor inSecurityBufferDescriptor = null; + if (inSecBuffer != null) + { + inSecurityBufferDescriptor = new SecurityBufferDescriptor(1); + } + else if (inSecBuffers != null) + { + inSecurityBufferDescriptor = new SecurityBufferDescriptor(inSecBuffers.Length); + } + SecurityBufferDescriptor outSecurityBufferDescriptor = new SecurityBufferDescriptor(1); + + // actually this is returned in outFlags + bool isSspiAllocated = (inFlags & ContextFlags.AllocateMemory) != 0 ? true : false; + + int errorCode = -1; + + SSPIHandle contextHandle = new SSPIHandle(); + if (refContext != null) + { + contextHandle = refContext._handle; + } + + // these are pinned user byte arrays passed along with SecurityBuffers + GCHandle[] pinnedInBytes = null; + GCHandle pinnedOutBytes = new GCHandle(); + // optional output buffer that may need to be freed + SafeFreeContextBuffer outFreeContextBuffer = null; + try + { + pinnedOutBytes = GCHandle.Alloc(outSecBuffer.token, GCHandleType.Pinned); + SecurityBufferStruct[] inUnmanagedBuffer = new SecurityBufferStruct[inSecurityBufferDescriptor == null ? 1 : inSecurityBufferDescriptor.Count]; + fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer) + { + if (inSecurityBufferDescriptor != null) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + inSecurityBufferDescriptor.UnmanagedPointer = inUnmanagedBufferPtr; + pinnedInBytes = new GCHandle[inSecurityBufferDescriptor.Count]; + SecurityBuffer securityBuffer; + for (int index = 0; index < inSecurityBufferDescriptor.Count; ++index) + { + securityBuffer = inSecBuffer != null ? inSecBuffer : inSecBuffers[index]; + if (securityBuffer != null) + { + // Copy the SecurityBuffer content into unmanaged place holder + inUnmanagedBuffer[index].count = securityBuffer.size; + inUnmanagedBuffer[index].type = securityBuffer.type; + + // use the unmanaged token if it's not null; otherwise use the managed buffer + if (securityBuffer.unmanagedToken != null) + { + inUnmanagedBuffer[index].token = securityBuffer.unmanagedToken.DangerousGetHandle(); + } + else if (securityBuffer.token == null || securityBuffer.token.Length == 0) + { + inUnmanagedBuffer[index].token = IntPtr.Zero; + } + else + { + pinnedInBytes[index] = GCHandle.Alloc(securityBuffer.token, GCHandleType.Pinned); + inUnmanagedBuffer[index].token = Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset); + } + } + } + } + + SecurityBufferStruct[] outUnmanagedBuffer = new SecurityBufferStruct[1]; + fixed (void* outUnmanagedBufferPtr = outUnmanagedBuffer) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + outSecurityBufferDescriptor.UnmanagedPointer = outUnmanagedBufferPtr; + outUnmanagedBuffer[0].count = outSecBuffer.size; + outUnmanagedBuffer[0].type = outSecBuffer.type; + if (outSecBuffer.token == null || outSecBuffer.token.Length == 0) + { + outUnmanagedBuffer[0].token = IntPtr.Zero; + } + else + { + outUnmanagedBuffer[0].token = Marshal.UnsafeAddrOfPinnedArrayElement(outSecBuffer.token, outSecBuffer.offset); + } + if (isSspiAllocated) + { + outFreeContextBuffer = SafeFreeContextBuffer.CreateEmptyHandle(); + } + + if (refContext == null || refContext.IsInvalid) + { + refContext = new SafeDeleteContext(); + } + + if (targetName == null || targetName.Length == 0) + { + targetName = DummyStr; + } + + fixed (char* namePtr = targetName) + { + errorCode = MustRunInitializeSecurityContext( + ref inCredentials, + contextHandle.IsZero ? null : &contextHandle, + (byte*)(((object)targetName == (object)DummyStr) ? null : namePtr), + inFlags, + endianness, + inSecurityBufferDescriptor, + refContext, + outSecurityBufferDescriptor, + ref outFlags, + outFreeContextBuffer); + } + + GlobalLog.Print("SafeDeleteContext:InitializeSecurityContext Marshalling OUT buffer"); + // Get unmanaged buffer with index 0 as the only one passed into PInvoke + outSecBuffer.size = outUnmanagedBuffer[0].count; + outSecBuffer.type = outUnmanagedBuffer[0].type; + if (outSecBuffer.size > 0) + { + outSecBuffer.token = new byte[outSecBuffer.size]; + Marshal.Copy(outUnmanagedBuffer[0].token, outSecBuffer.token, 0, outSecBuffer.size); + } + else + { + outSecBuffer.token = null; + } + } + } + } + finally + { + if (pinnedInBytes != null) + { + for (int index = 0; index < pinnedInBytes.Length; index++) + { + if (pinnedInBytes[index].IsAllocated) + { + pinnedInBytes[index].Free(); + } + } + } + if (pinnedOutBytes.IsAllocated) + { + pinnedOutBytes.Free(); + } + + if (outFreeContextBuffer != null) + { + outFreeContextBuffer.Dispose(); + } + } + + GlobalLog.Leave("SafeDeleteContext::InitializeSecurityContext() unmanaged InitializeSecurityContext()", "errorCode:0x" + errorCode.ToString("x8") + " refContext:" + ValidationHelper.ToString(refContext)); + + return errorCode; + } + + // After PINvoke call the method will fix the handleTemplate.handle with the returned value. + // The caller is responsible for creating a correct SafeFreeContextBuffer_XXX flavour or null can be passed if no handle is returned. + // + // Since it has a CER, this method can't have any references to imports from DLLs that may not exist on the system. + + private static unsafe int MustRunInitializeSecurityContext( + ref SafeFreeCredentials inCredentials, + void* inContextPtr, + byte* targetName, + ContextFlags inFlags, + Endianness endianness, + SecurityBufferDescriptor inputBuffer, + SafeDeleteContext outContext, + SecurityBufferDescriptor outputBuffer, + ref ContextFlags attributes, + SafeFreeContextBuffer handleTemplate) + { + int errorCode = (int)SecurityStatus.InvalidHandle; + bool b1 = false; + bool b2 = false; + + // Run the body of this method as a non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + inCredentials.DangerousAddRef(ref b1); + outContext.DangerousAddRef(ref b2); + } + catch (Exception e) + { + if (b1) + { + inCredentials.DangerousRelease(); + b1 = false; + } + if (b2) + { + outContext.DangerousRelease(); + b2 = false; + } + + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + SSPIHandle credentialHandle = inCredentials._handle; + long timeStamp; + + if (!b1) + { + // caller should retry + inCredentials = null; + } + else if (b1 && b2) + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.InitializeSecurityContextW( + ref credentialHandle, + inContextPtr, + targetName, + inFlags, + 0, + endianness, + inputBuffer, + 0, + ref outContext._handle, + outputBuffer, + ref attributes, + out timeStamp); + + // When a credential handle is first associated with the context we keep credential + // ref count bumped up to ensure ordered finalization. + // If the credential handle has been changed we de-ref the old one and associate the + // context with the new cred handle but only if the call was successful. + if (outContext._effectiveCredential != inCredentials && (errorCode & 0x80000000) == 0) + { + // Disassociate the previous credential handle + if (outContext._effectiveCredential != null) + { + outContext._effectiveCredential.DangerousRelease(); + } + outContext._effectiveCredential = inCredentials; + } + else + { + inCredentials.DangerousRelease(); + } + + outContext.DangerousRelease(); + + // The idea is that SSPI has allocated a block and filled up outUnmanagedBuffer+8 slot with the pointer. + if (handleTemplate != null) + { + handleTemplate.Set(((SecurityBufferStruct*)outputBuffer.UnmanagedPointer)->token); // ATTN: on 64 BIT that is still +8 cause of 2* c++ unsigned long == 8 bytes + if (handleTemplate.IsInvalid) + { + handleTemplate.SetHandleAsInvalid(); + } + } + } + + if (inContextPtr == null && (errorCode & 0x80000000) != 0) + { + // an error on the first call, need to set the out handle to invalid value + outContext._handle.SetToInvalid(); + } + } + + return errorCode; + } + + //------------------------------------------------------------------- + internal static unsafe int AcceptSecurityContext(ref SafeFreeCredentials inCredentials, ref SafeDeleteContext refContext, + ContextFlags inFlags, Endianness endianness, SecurityBuffer inSecBuffer, SecurityBuffer[] inSecBuffers, SecurityBuffer outSecBuffer, + ref ContextFlags outFlags) + { + GlobalLog.Assert(outSecBuffer != null, "SafeDeleteContext::AcceptSecurityContext()|outSecBuffer != null"); + GlobalLog.Assert(inSecBuffer == null || inSecBuffers == null, "SafeDeleteContext::AcceptSecurityContext()|inSecBuffer == null || inSecBuffers == null"); + + if (inCredentials == null) + { + throw new ArgumentNullException("inCredentials"); + } + + SecurityBufferDescriptor inSecurityBufferDescriptor = null; + if (inSecBuffer != null) + { + inSecurityBufferDescriptor = new SecurityBufferDescriptor(1); + } + else if (inSecBuffers != null) + { + inSecurityBufferDescriptor = new SecurityBufferDescriptor(inSecBuffers.Length); + } + SecurityBufferDescriptor outSecurityBufferDescriptor = new SecurityBufferDescriptor(1); + + // actually this is returned in outFlags + bool isSspiAllocated = (inFlags & ContextFlags.AllocateMemory) != 0 ? true : false; + + int errorCode = -1; + + SSPIHandle contextHandle = new SSPIHandle(); + if (refContext != null) + { + contextHandle = refContext._handle; + } + + // these are pinned user byte arrays passed along with SecurityBuffers + GCHandle[] pinnedInBytes = null; + GCHandle pinnedOutBytes = new GCHandle(); + // optional output buffer that may need to be freed + SafeFreeContextBuffer outFreeContextBuffer = null; + try + { + pinnedOutBytes = GCHandle.Alloc(outSecBuffer.token, GCHandleType.Pinned); + SecurityBufferStruct[] inUnmanagedBuffer = new SecurityBufferStruct[inSecurityBufferDescriptor == null ? 1 : inSecurityBufferDescriptor.Count]; + fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer) + { + if (inSecurityBufferDescriptor != null) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + inSecurityBufferDescriptor.UnmanagedPointer = inUnmanagedBufferPtr; + pinnedInBytes = new GCHandle[inSecurityBufferDescriptor.Count]; + SecurityBuffer securityBuffer; + for (int index = 0; index < inSecurityBufferDescriptor.Count; ++index) + { + securityBuffer = inSecBuffer != null ? inSecBuffer : inSecBuffers[index]; + if (securityBuffer != null) + { + // Copy the SecurityBuffer content into unmanaged place holder + inUnmanagedBuffer[index].count = securityBuffer.size; + inUnmanagedBuffer[index].type = securityBuffer.type; + + // use the unmanaged token if it's not null; otherwise use the managed buffer + if (securityBuffer.unmanagedToken != null) + { + inUnmanagedBuffer[index].token = securityBuffer.unmanagedToken.DangerousGetHandle(); + } + else if (securityBuffer.token == null || securityBuffer.token.Length == 0) + { + inUnmanagedBuffer[index].token = IntPtr.Zero; + } + else + { + pinnedInBytes[index] = GCHandle.Alloc(securityBuffer.token, GCHandleType.Pinned); + inUnmanagedBuffer[index].token = Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset); + } + } + } + } + SecurityBufferStruct[] outUnmanagedBuffer = new SecurityBufferStruct[1]; + fixed (void* outUnmanagedBufferPtr = outUnmanagedBuffer) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + outSecurityBufferDescriptor.UnmanagedPointer = outUnmanagedBufferPtr; + // Copy the SecurityBuffer content into unmanaged place holder + outUnmanagedBuffer[0].count = outSecBuffer.size; + outUnmanagedBuffer[0].type = outSecBuffer.type; + + if (outSecBuffer.token == null || outSecBuffer.token.Length == 0) + { + outUnmanagedBuffer[0].token = IntPtr.Zero; + } + else + { + outUnmanagedBuffer[0].token = Marshal.UnsafeAddrOfPinnedArrayElement(outSecBuffer.token, outSecBuffer.offset); + } + if (isSspiAllocated) + { + outFreeContextBuffer = SafeFreeContextBuffer.CreateEmptyHandle(); + } + + if (refContext == null || refContext.IsInvalid) + { + refContext = new SafeDeleteContext(); + } + + errorCode = MustRunAcceptSecurityContext( + ref inCredentials, + contextHandle.IsZero ? null : &contextHandle, + inSecurityBufferDescriptor, + inFlags, + endianness, + refContext, + outSecurityBufferDescriptor, + ref outFlags, + outFreeContextBuffer); + + GlobalLog.Print("SafeDeleteContext:AcceptSecurityContext Marshalling OUT buffer"); + // Get unmanaged buffer with index 0 as the only one passed into PInvoke + outSecBuffer.size = outUnmanagedBuffer[0].count; + outSecBuffer.type = outUnmanagedBuffer[0].type; + if (outSecBuffer.size > 0) + { + outSecBuffer.token = new byte[outSecBuffer.size]; + Marshal.Copy(outUnmanagedBuffer[0].token, outSecBuffer.token, 0, outSecBuffer.size); + } + else + { + outSecBuffer.token = null; + } + } + } + } + finally + { + if (pinnedInBytes != null) + { + for (int index = 0; index < pinnedInBytes.Length; index++) + { + if (pinnedInBytes[index].IsAllocated) + { + pinnedInBytes[index].Free(); + } + } + } + + if (pinnedOutBytes.IsAllocated) + { + pinnedOutBytes.Free(); + } + + if (outFreeContextBuffer != null) + { + outFreeContextBuffer.Dispose(); + } + } + + GlobalLog.Leave("SafeDeleteContext::AcceptSecurityContex() unmanaged AcceptSecurityContex()", "errorCode:0x" + errorCode.ToString("x8") + " refContext:" + ValidationHelper.ToString(refContext)); + + return errorCode; + } + + // After PINvoke call the method will fix the handleTemplate.handle with the returned value. + // The caller is responsible for creating a correct SafeFreeContextBuffer_XXX flavour or null can be passed if no handle is returned. + // + // Since it has a CER, this method can't have any references to imports from DLLs that may not exist on the system. + + private static unsafe int MustRunAcceptSecurityContext( + ref SafeFreeCredentials inCredentials, + void* inContextPtr, + SecurityBufferDescriptor inputBuffer, + ContextFlags inFlags, + Endianness endianness, + SafeDeleteContext outContext, + SecurityBufferDescriptor outputBuffer, + ref ContextFlags outFlags, + SafeFreeContextBuffer handleTemplate) + { + int errorCode = (int)SecurityStatus.InvalidHandle; + bool b1 = false; + bool b2 = false; + + // Run the body of this method as a non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + inCredentials.DangerousAddRef(ref b1); + outContext.DangerousAddRef(ref b2); + } + catch (Exception e) + { + if (b1) + { + inCredentials.DangerousRelease(); + b1 = false; + } + if (b2) + { + outContext.DangerousRelease(); + b2 = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + SSPIHandle credentialHandle = inCredentials._handle; + long timeStamp; + + if (!b1) + { + // caller should retry + inCredentials = null; + } + else if (b1 && b2) + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcceptSecurityContext( + ref credentialHandle, + inContextPtr, + inputBuffer, + inFlags, + endianness, + ref outContext._handle, + outputBuffer, + ref outFlags, + out timeStamp); + + // When a credential handle is first associated with the context we keep credential + // ref count bumped up to ensure ordered finalization. + // If the credential handle has been changed we de-ref the old one and associate the + // context with the new cred handle but only if the call was successful. + if (outContext._effectiveCredential != inCredentials && (errorCode & 0x80000000) == 0) + { + // Disassociate the previous credential handle + if (outContext._effectiveCredential != null) + { + outContext._effectiveCredential.DangerousRelease(); + } + outContext._effectiveCredential = inCredentials; + } + else + { + inCredentials.DangerousRelease(); + } + + outContext.DangerousRelease(); + + // The idea is that SSPI has allocated a block and filled up outUnmanagedBuffer+8 slot with the pointer. + if (handleTemplate != null) + { + handleTemplate.Set(((SecurityBufferStruct*)outputBuffer.UnmanagedPointer)->token); // ATTN: on 64 BIT that is still +8 cause of 2* c++ unsigned long == 8 bytes + if (handleTemplate.IsInvalid) + { + handleTemplate.SetHandleAsInvalid(); + } + } + } + + if (inContextPtr == null && (errorCode & 0x80000000) != 0) + { + // an error on the first call, need to set the out handle to invalid value + outContext._handle.SetToInvalid(); + } + } + + return errorCode; + } + + internal static unsafe int CompleteAuthToken(ref SafeDeleteContext refContext, SecurityBuffer[] inSecBuffers) + { + GlobalLog.Enter("SafeDeleteContext::CompleteAuthToken"); + GlobalLog.Print(" refContext = " + ValidationHelper.ToString(refContext)); + GlobalLog.Assert(inSecBuffers != null, "SafeDeleteContext::CompleteAuthToken()|inSecBuffers == null"); + SecurityBufferDescriptor inSecurityBufferDescriptor = new SecurityBufferDescriptor(inSecBuffers.Length); + + int errorCode = (int)SecurityStatus.InvalidHandle; + + // these are pinned user byte arrays passed along with SecurityBuffers + GCHandle[] pinnedInBytes = null; + + SecurityBufferStruct[] inUnmanagedBuffer = new SecurityBufferStruct[inSecurityBufferDescriptor.Count]; + fixed (void* inUnmanagedBufferPtr = inUnmanagedBuffer) + { + // Fix Descriptor pointer that points to unmanaged SecurityBuffers + inSecurityBufferDescriptor.UnmanagedPointer = inUnmanagedBufferPtr; + pinnedInBytes = new GCHandle[inSecurityBufferDescriptor.Count]; + SecurityBuffer securityBuffer; + for (int index = 0; index < inSecurityBufferDescriptor.Count; ++index) + { + securityBuffer = inSecBuffers[index]; + if (securityBuffer != null) + { + inUnmanagedBuffer[index].count = securityBuffer.size; + inUnmanagedBuffer[index].type = securityBuffer.type; + + // use the unmanaged token if it's not null; otherwise use the managed buffer + if (securityBuffer.unmanagedToken != null) + { + inUnmanagedBuffer[index].token = securityBuffer.unmanagedToken.DangerousGetHandle(); + } + else if (securityBuffer.token == null || securityBuffer.token.Length == 0) + { + inUnmanagedBuffer[index].token = IntPtr.Zero; + } + else + { + pinnedInBytes[index] = GCHandle.Alloc(securityBuffer.token, GCHandleType.Pinned); + inUnmanagedBuffer[index].token = Marshal.UnsafeAddrOfPinnedArrayElement(securityBuffer.token, securityBuffer.offset); + } + } + } + + SSPIHandle contextHandle = new SSPIHandle(); + if (refContext != null) + { + contextHandle = refContext._handle; + } + try + { + if (refContext == null || refContext.IsInvalid) + { + refContext = new SafeDeleteContext(); + } + + bool b = false; + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + refContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + refContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.CompleteAuthToken(contextHandle.IsZero ? null : &contextHandle, inSecurityBufferDescriptor); + refContext.DangerousRelease(); + } + } + } + finally + { + if (pinnedInBytes != null) + { + for (int index = 0; index < pinnedInBytes.Length; index++) + { + if (pinnedInBytes[index].IsAllocated) + { + pinnedInBytes[index].Free(); + } + } + } + } + } + + GlobalLog.Leave("SafeDeleteContext::CompleteAuthToken() unmanaged CompleteAuthToken()", "errorCode:0x" + errorCode.ToString("x8") + " refContext:" + ValidationHelper.ToString(refContext)); + + return errorCode; + } + + protected override bool ReleaseHandle() + { + if (this._effectiveCredential != null) + { + this._effectiveCredential.DangerousRelease(); + } + + return UnsafeNclNativeMethods.SafeNetHandles.DeleteSecurityContext(ref _handle) == 0; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCertContext.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCertContext.cs new file mode 100644 index 0000000000..20f85f5010 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCertContext.cs @@ -0,0 +1,35 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.ConstrainedExecution; +using System.Security; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeFreeCertContext : SafeHandleZeroOrMinusOneIsInvalid + { + internal SafeFreeCertContext() + : base(true) + { + } + + // This must be ONLY called from this file within a CER. + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal unsafe void Set(IntPtr value) + { + this.handle = value; + } + + protected override bool ReleaseHandle() + { + UnsafeNclNativeMethods.SafeNetHandles.CertFreeCertificateContext(handle); + return true; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBuffer.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBuffer.cs new file mode 100644 index 0000000000..db31e3018a --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBuffer.cs @@ -0,0 +1,148 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeFreeContextBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + private SafeFreeContextBuffer() + : base(true) + { + } + + // This must be ONLY called from this file and in the context of a CER + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal unsafe void Set(IntPtr value) + { + this.handle = value; + } + + internal static int EnumeratePackages(out int pkgnum, out SafeFreeContextBuffer pkgArray) + { + int res = -1; + res = UnsafeNclNativeMethods.SafeNetHandles.EnumerateSecurityPackagesW(out pkgnum, out pkgArray); + if (res != 0 && pkgArray != null) + { + pkgArray.SetHandleAsInvalid(); + } + return res; + } + + internal static SafeFreeContextBuffer CreateEmptyHandle() + { + return new SafeFreeContextBuffer(); + } + + // After PINvoke call the method will fix the refHandle.handle with the returned value. + // The caller is responsible for creating a correct SafeHandle template or null can be passed if no handle is returned. + // + // This method switches between three non-interruptible helper methods. (This method can't be both non-interruptible and + // reference imports from all three DLLs - doing so would cause all three DLLs to try to be bound to.) + + public static unsafe int QueryContextAttributes(SafeDeleteContext phContext, ContextAttribute contextAttribute, byte* buffer, SafeHandle refHandle) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + // We don't want to be interrupted by thread abort exceptions or unexpected out-of-memory errors failing to jit + // one of the following methods. So run within a CER non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + phContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + phContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.SafeNetHandles.QueryContextAttributesW(ref phContext._handle, contextAttribute, buffer); + phContext.DangerousRelease(); + } + + if (status == 0 && refHandle != null) + { + if (refHandle is SafeFreeContextBuffer) + { + ((SafeFreeContextBuffer)refHandle).Set(*(IntPtr*)buffer); + } + else + { + ((SafeFreeCertContext)refHandle).Set(*(IntPtr*)buffer); + } + } + + if (status != 0 && refHandle != null) + { + refHandle.SetHandleAsInvalid(); + } + } + + return status; + } + + public static int SetContextAttributes(SafeDeleteContext phContext, ContextAttribute contextAttribute, byte[] buffer) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + // We don't want to be interrupted by thread abort exceptions or unexpected out-of-memory errors failing + // to jit one of the following methods. So run within a CER non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + phContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + phContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.SafeNetHandles.SetContextAttributesW( + ref phContext._handle, contextAttribute, buffer, buffer.Length); + phContext.DangerousRelease(); + } + } + + return status; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.FreeContextBuffer(handle) == 0; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBufferChannelBinding.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBufferChannelBinding.cs new file mode 100644 index 0000000000..344ef38a59 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeContextBufferChannelBinding.cs @@ -0,0 +1,89 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.ConstrainedExecution; +using System.Security; +using System.Security.Authentication.ExtendedProtection; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeFreeContextBufferChannelBinding : ChannelBinding + { + private int size; + + public override int Size + { + get { return size; } + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal unsafe void Set(IntPtr value) + { + this.handle = value; + } + + internal static SafeFreeContextBufferChannelBinding CreateEmptyHandle() + { + return new SafeFreeContextBufferChannelBinding(); + } + + public static unsafe int QueryContextChannelBinding(SafeDeleteContext phContext, ContextAttribute contextAttribute, Bindings* buffer, + SafeFreeContextBufferChannelBinding refHandle) + { + int status = (int)SecurityStatus.InvalidHandle; + bool b = false; + + // We don't want to be interrupted by thread abort exceptions or unexpected out-of-memory errors failing to jit + // one of the following methods. So run within a CER non-interruptible block. + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + phContext.DangerousAddRef(ref b); + } + catch (Exception e) + { + if (b) + { + phContext.DangerousRelease(); + b = false; + } + if (!(e is ObjectDisposedException)) + { + throw; + } + } + finally + { + if (b) + { + status = UnsafeNclNativeMethods.SafeNetHandles.QueryContextAttributesW(ref phContext._handle, contextAttribute, buffer); + phContext.DangerousRelease(); + } + + if (status == 0 && refHandle != null) + { + refHandle.Set((*buffer).pBindings); + refHandle.size = (*buffer).BindingsLength; + } + + if (status != 0 && refHandle != null) + { + refHandle.SetHandleAsInvalid(); + } + } + + return status; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.FreeContextBuffer(handle) == 0; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCredentials.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCredentials.cs new file mode 100644 index 0000000000..0b42913028 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeFreeCredentials.cs @@ -0,0 +1,199 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + internal sealed class SafeFreeCredentials : SafeHandle + { + internal SSPIHandle _handle; // should be always used as by ref in PINvokes parameters + + private SafeFreeCredentials() + : base(IntPtr.Zero, true) + { + _handle = new SSPIHandle(); + } + + public override bool IsInvalid + { + get { return IsClosed || _handle.IsZero; } + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.FreeCredentialsHandle(ref _handle) == 0; + } + + public static unsafe int AcquireCredentialsHandle(string package, CredentialUse intent, ref AuthIdentity authdata, + out SafeFreeCredentials outCredential) + { + GlobalLog.Print("SafeFreeCredentials::AcquireCredentialsHandle#1(" + + package + ", " + + intent + ", " + + authdata + ")"); + + int errorCode = -1; + long timeStamp; + + outCredential = new SafeFreeCredentials(); + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + } + finally + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcquireCredentialsHandleW( + null, + package, + (int)intent, + null, + ref authdata, + null, + null, + ref outCredential._handle, + out timeStamp); + } + + if (errorCode != 0) + { + outCredential.SetHandleAsInvalid(); + } + return errorCode; + } + + public static unsafe int AcquireDefaultCredential(string package, CredentialUse intent, out SafeFreeCredentials outCredential) + { + GlobalLog.Print("SafeFreeCredentials::AcquireDefaultCredential(" + + package + ", " + + intent + ")"); + + int errorCode = -1; + long timeStamp; + + outCredential = new SafeFreeCredentials(); + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + } + finally + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcquireCredentialsHandleW( + null, + package, + (int)intent, + null, + IntPtr.Zero, + null, + null, + ref outCredential._handle, + out timeStamp); + } + + if (errorCode != 0) + { + outCredential.SetHandleAsInvalid(); + } + return errorCode; + } + + // This overload is only called on Win7+ where SspiEncodeStringsAsAuthIdentity() was used to + // create the authData blob. + public static unsafe int AcquireCredentialsHandle( + string package, + CredentialUse intent, + ref SafeSspiAuthDataHandle authdata, + out SafeFreeCredentials outCredential) + { + int errorCode = -1; + long timeStamp; + + outCredential = new SafeFreeCredentials(); + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + } + finally + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcquireCredentialsHandleW( + null, + package, + (int)intent, + null, + authdata, + null, + null, + ref outCredential._handle, + out timeStamp); + } + + if (errorCode != 0) + { + outCredential.SetHandleAsInvalid(); + } + return errorCode; + } + + public static unsafe int AcquireCredentialsHandle(string package, CredentialUse intent, ref SecureCredential authdata, + out SafeFreeCredentials outCredential) + { + GlobalLog.Print("SafeFreeCredentials::AcquireCredentialsHandle#2(" + + package + ", " + + intent + ", " + + authdata + ")"); + + int errorCode = -1; + long timeStamp; + + // If there is a certificate, wrap it into an array. + // Not threadsafe. + IntPtr copiedPtr = authdata.certContextArray; + try + { + IntPtr certArrayPtr = new IntPtr(&copiedPtr); + if (copiedPtr != IntPtr.Zero) + { + authdata.certContextArray = certArrayPtr; + } + + outCredential = new SafeFreeCredentials(); + + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + } + finally + { + errorCode = UnsafeNclNativeMethods.SafeNetHandles.AcquireCredentialsHandleW( + null, + package, + (int)intent, + null, + ref authdata, + null, + null, + ref outCredential._handle, + out timeStamp); + } + } + finally + { + authdata.certContextArray = copiedPtr; + } + + if (errorCode != 0) + { + outCredential.SetHandleAsInvalid(); + } + return errorCode; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeLocalFree.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeLocalFree.cs new file mode 100644 index 0000000000..30437854d7 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeLocalFree.cs @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Security; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Security.Windows +{ + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeLocalFree : SafeHandleZeroOrMinusOneIsInvalid + { + private const int LmemFixed = 0; + private const int NULL = 0; + + // This returned handle cannot be modified by the application. + public static SafeLocalFree Zero = new SafeLocalFree(false); + + private SafeLocalFree() + : base(true) + { + } + + private SafeLocalFree(bool ownsHandle) + : base(ownsHandle) + { + } + + public static SafeLocalFree LocalAlloc(int cb) + { + SafeLocalFree result = UnsafeNclNativeMethods.SafeNetHandles.LocalAlloc(LmemFixed, (UIntPtr)cb); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + throw new OutOfMemoryException(); + } + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.LocalFree(handle) == IntPtr.Zero; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeSspiAuthDataHandle.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeSspiAuthDataHandle.cs new file mode 100644 index 0000000000..e02c75aaba --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SafeSspiAuthDataHandle.cs @@ -0,0 +1,32 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.Runtime.CompilerServices; + using System.Runtime.ConstrainedExecution; + using System.Runtime.InteropServices; + using System.Security; + using System.Security.Authentication.ExtendedProtection; + using System.Threading; + using Microsoft.Win32.SafeHandles; + + [SuppressUnmanagedCodeSecurity] + internal sealed class SafeSspiAuthDataHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeSspiAuthDataHandle() + : base(true) + { + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SspiHelper.SspiFreeAuthIdentity(handle) == SecurityStatus.OK; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SchProtocols.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SchProtocols.cs new file mode 100644 index 0000000000..eb48aa1d3f --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SchProtocols.cs @@ -0,0 +1,41 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Security.Windows +{ + // From Schannel.h + [Flags] + internal enum SchProtocols + { + Zero = 0, + PctClient = 0x00000002, + PctServer = 0x00000001, + Pct = (PctClient | PctServer), + Ssl2Client = 0x00000008, + Ssl2Server = 0x00000004, + Ssl2 = (Ssl2Client | Ssl2Server), + Ssl3Client = 0x00000020, + Ssl3Server = 0x00000010, + Ssl3 = (Ssl3Client | Ssl3Server), + Tls10Client = 0x00000080, + Tls10Server = 0x00000040, + Tls10 = (Tls10Client | Tls10Server), + Tls11Client = 0x00000200, + Tls11Server = 0x00000100, + Tls11 = (Tls11Client | Tls11Server), + Tls12Client = 0x00000800, + Tls12Server = 0x00000400, + Tls12 = (Tls12Client | Tls12Server), + Ssl3Tls = (Ssl3 | Tls10), + UniClient = unchecked((int)0x80000000), + UniServer = 0x40000000, + Unified = (UniClient | UniServer), + ClientMask = (PctClient | Ssl2Client | Ssl3Client | Tls10Client | Tls11Client | Tls12Client | UniClient), + ServerMask = (PctServer | Ssl2Server | Ssl3Server | Tls10Server | Tls11Server | Tls12Server | UniServer) + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecSizes.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecSizes.cs new file mode 100644 index 0000000000..0f1bb3e7a3 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecSizes.cs @@ -0,0 +1,44 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.ComponentModel; + using System.Diagnostics; + using System.Globalization; + using System.Runtime.InteropServices; + + [StructLayout(LayoutKind.Sequential)] + internal class SecSizes + { + public readonly int MaxToken; + public readonly int MaxSignature; + public readonly int BlockSize; + public readonly int SecurityTrailer; + + internal unsafe SecSizes(byte[] memory) + { + fixed (void* voidPtr = memory) + { + IntPtr unmanagedAddress = new IntPtr(voidPtr); + try + { + MaxToken = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress)); + MaxSignature = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 4)); + BlockSize = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 8)); + SecurityTrailer = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 12)); + } + catch (OverflowException) + { + GlobalLog.Assert(false, "SecSizes::.ctor", "Negative size."); + throw; + } + } + } + public static readonly int SizeOf = Marshal.SizeOf(typeof(SecSizes)); + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBuffer.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBuffer.cs new file mode 100644 index 0000000000..069223dd4f --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBuffer.cs @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Globalization; +using System.Runtime.InteropServices; +using System.Security.Authentication.ExtendedProtection; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class SecurityBuffer + { + public int size; + public BufferType type; + public byte[] token; + public SafeHandle unmanagedToken; + public int offset; + + public SecurityBuffer(byte[] data, int offset, int size, BufferType tokentype) + { + GlobalLog.Assert(offset >= 0 && offset <= (data == null ? 0 : data.Length), "SecurityBuffer::.ctor", "'offset' out of range. [" + offset + "]"); + GlobalLog.Assert(size >= 0 && size <= (data == null ? 0 : data.Length - offset), "SecurityBuffer::.ctor", "'size' out of range. [" + size + "]"); + + this.offset = data == null || offset < 0 ? 0 : Math.Min(offset, data.Length); + this.size = data == null || size < 0 ? 0 : Math.Min(size, data.Length - this.offset); + this.type = tokentype; + this.token = size == 0 ? null : data; + } + + public SecurityBuffer(byte[] data, BufferType tokentype) + { + this.size = data == null ? 0 : data.Length; + this.type = tokentype; + this.token = size == 0 ? null : data; + } + + public SecurityBuffer(int size, BufferType tokentype) + { + GlobalLog.Assert(size >= 0, "SecurityBuffer::.ctor", "'size' out of range. [" + size.ToString(NumberFormatInfo.InvariantInfo) + "]"); + + this.size = size; + this.type = tokentype; + this.token = size == 0 ? null : new byte[size]; + } + + public SecurityBuffer(ChannelBinding binding) + { + this.size = (binding == null ? 0 : binding.Size); + this.type = BufferType.ChannelBindings; + this.unmanagedToken = binding; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBufferDescriptor.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBufferDescriptor.cs new file mode 100644 index 0000000000..138f12b2bd --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityBufferDescriptor.cs @@ -0,0 +1,33 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential)] + internal unsafe class SecurityBufferDescriptor + { + /* + typedef struct _SecBufferDesc { + ULONG ulVersion; + ULONG cBuffers; + PSecBuffer pBuffers; + } SecBufferDesc, * PSecBufferDesc; + */ + public readonly int Version; + public readonly int Count; + public void* UnmanagedPointer; + + public SecurityBufferDescriptor(int count) + { + Version = 0; + Count = count; + UnmanagedPointer = null; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityPackageInfoClass.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityPackageInfoClass.cs new file mode 100644 index 0000000000..df19cf7e37 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SecurityPackageInfoClass.cs @@ -0,0 +1,80 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Globalization; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class SecurityPackageInfoClass + { + private int _capabilities = 0; + private short _version = 0; + private short _rpcid = 0; + private int _maxToken = 0; + private string _name = null; + private string _comment = null; + + /* + * This is to support SSL under semi trusted environment. + * Note that it is only for SSL with no client cert + */ + internal SecurityPackageInfoClass(SafeHandle safeHandle, int index) + { + if (safeHandle.IsInvalid) + { + GlobalLog.Print("SecurityPackageInfoClass::.ctor() the pointer is invalid: " + (safeHandle.DangerousGetHandle()).ToString("x")); + return; + } + IntPtr unmanagedAddress = IntPtrHelper.Add(safeHandle.DangerousGetHandle(), SecurityPackageInfo.Size * index); + GlobalLog.Print("SecurityPackageInfoClass::.ctor() unmanagedPointer: " + ((long)unmanagedAddress).ToString("x")); + + _capabilities = Marshal.ReadInt32(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Capabilities")); + _version = Marshal.ReadInt16(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Version")); + _rpcid = Marshal.ReadInt16(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "RPCID")); + _maxToken = Marshal.ReadInt32(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "MaxToken")); + + IntPtr unmanagedString; + + unmanagedString = Marshal.ReadIntPtr(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Name")); + if (unmanagedString != IntPtr.Zero) + { + _name = Marshal.PtrToStringUni(unmanagedString); + GlobalLog.Print("Name: " + Name); + } + + unmanagedString = Marshal.ReadIntPtr(unmanagedAddress, (int)Marshal.OffsetOf(typeof(SecurityPackageInfo), "Comment")); + if (unmanagedString != IntPtr.Zero) + { + _comment = Marshal.PtrToStringUni(unmanagedString); + GlobalLog.Print("Comment: " + _comment); + } + + GlobalLog.Print("SecurityPackageInfoClass::.ctor(): " + ToString()); + } + + internal int MaxToken + { + get { return _maxToken; } + } + + internal string Name + { + get { return _name; } + } + + public override string ToString() + { + return "Capabilities:" + String.Format(CultureInfo.InvariantCulture, "0x{0:x}", _capabilities) + + " Version:" + _version.ToString(NumberFormatInfo.InvariantInfo) + + " RPCID:" + _rpcid.ToString(NumberFormatInfo.InvariantInfo) + + " MaxToken:" + MaxToken.ToString(NumberFormatInfo.InvariantInfo) + + " Name:" + ((Name == null) ? "(null)" : Name) + + " Comment:" + ((_comment == null) ? "(null)" : _comment); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/SslConnectionInfo.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SslConnectionInfo.cs new file mode 100644 index 0000000000..2430b73b8d --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/SslConnectionInfo.cs @@ -0,0 +1,38 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential)] + internal class SslConnectionInfo + { + public readonly int Protocol; + public readonly int DataCipherAlg; + public readonly int DataKeySize; + public readonly int DataHashAlg; + public readonly int DataHashKeySize; + public readonly int KeyExchangeAlg; + public readonly int KeyExchKeySize; + + internal unsafe SslConnectionInfo(byte[] nativeBuffer) + { + fixed (void* voidPtr = nativeBuffer) + { + IntPtr unmanagedAddress = new IntPtr(voidPtr); + Protocol = Marshal.ReadInt32(unmanagedAddress); + DataCipherAlg = Marshal.ReadInt32(unmanagedAddress, 4); + DataKeySize = Marshal.ReadInt32(unmanagedAddress, 8); + DataHashAlg = Marshal.ReadInt32(unmanagedAddress, 12); + DataHashKeySize = Marshal.ReadInt32(unmanagedAddress, 16); + KeyExchangeAlg = Marshal.ReadInt32(unmanagedAddress, 20); + KeyExchKeySize = Marshal.ReadInt32(unmanagedAddress, 24); + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/StreamSizes.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/StreamSizes.cs new file mode 100644 index 0000000000..6be68aa96b --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/StreamSizes.cs @@ -0,0 +1,43 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Security.Windows +{ + [StructLayout(LayoutKind.Sequential)] + internal class StreamSizes + { + public int header; + public int trailer; + public int maximumMessage; + public int buffersCount; + public int blockSize; + + internal unsafe StreamSizes(byte[] memory) + { + fixed (void* voidPtr = memory) + { + IntPtr unmanagedAddress = new IntPtr(voidPtr); + try + { + header = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress)); + trailer = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 4)); + maximumMessage = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 8)); + buffersCount = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 12)); + blockSize = (int)checked((uint)Marshal.ReadInt32(unmanagedAddress, 16)); + } + catch (OverflowException) + { + GlobalLog.Assert(false, "StreamSizes::.ctor", "Negative size."); + throw; + } + } + } + public static readonly int SizeOf = Marshal.SizeOf(typeof(StreamSizes)); + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NativeInterop/UnsafeNativeMethods.cs b/src/Microsoft.AspNet.Security.Windows/NativeInterop/UnsafeNativeMethods.cs new file mode 100644 index 0000000000..3e8a40d344 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NativeInterop/UnsafeNativeMethods.cs @@ -0,0 +1,222 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security; + +namespace Microsoft.AspNet.Security.Windows +{ + [System.Security.SuppressUnmanagedCodeSecurityAttribute] + internal static class UnsafeNclNativeMethods + { + private const string KERNEL32 = "kernel32.dll"; + private const string SECUR32 = "secur32.dll"; + private const string CRYPT32 = "crypt32.dll"; + + [DllImport(KERNEL32, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint GetCurrentThreadId(); + + [System.Security.SuppressUnmanagedCodeSecurityAttribute] + internal static class SafeNetHandles + { + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern int FreeContextBuffer( + [In] IntPtr contextBuffer); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern int FreeCredentialsHandle( + ref SSPIHandle handlePtr); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern int DeleteSecurityContext( + ref SSPIHandle handlePtr); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int AcceptSecurityContext( + ref SSPIHandle credentialHandle, + [In] void* inContextPtr, + [In] SecurityBufferDescriptor inputBuffer, + [In] ContextFlags inFlags, + [In] Endianness endianness, + ref SSPIHandle outContextPtr, + [In, Out] SecurityBufferDescriptor outputBuffer, + [In, Out] ref ContextFlags attributes, + out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int QueryContextAttributesW( + ref SSPIHandle contextHandle, + [In] ContextAttribute attribute, + [In] void* buffer); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int SetContextAttributesW( + ref SSPIHandle contextHandle, + [In] ContextAttribute attribute, + [In] byte[] buffer, + [In] int bufferSize); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static extern int EnumerateSecurityPackagesW( + [Out] out int pkgnum, + [Out] out SafeFreeContextBuffer handle); + + [DllImport(SECUR32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static unsafe extern int AcquireCredentialsHandleW( + [In] string principal, + [In] string moduleName, + [In] int usage, + [In] void* logonID, + [In] ref AuthIdentity authdata, + [In] void* keyCallback, + [In] void* keyArgument, + ref SSPIHandle handlePtr, + [Out] out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static unsafe extern int AcquireCredentialsHandleW( + [In] string principal, + [In] string moduleName, + [In] int usage, + [In] void* logonID, + [In] IntPtr zero, + [In] void* keyCallback, + [In] void* keyArgument, + ref SSPIHandle handlePtr, + [Out] out long timeStamp); + + // Win7+ + [DllImport(SECUR32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static unsafe extern int AcquireCredentialsHandleW( + [In] string principal, + [In] string moduleName, + [In] int usage, + [In] void* logonID, + [In] SafeSspiAuthDataHandle authdata, + [In] void* keyCallback, + [In] void* keyArgument, + ref SSPIHandle handlePtr, + [Out] out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static unsafe extern int AcquireCredentialsHandleW( + [In] string principal, + [In] string moduleName, + [In] int usage, + [In] void* logonID, + [In] ref SecureCredential authData, + [In] void* keyCallback, + [In] void* keyArgument, + ref SSPIHandle handlePtr, + [Out] out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int InitializeSecurityContextW( + ref SSPIHandle credentialHandle, + [In] void* inContextPtr, + [In] byte* targetName, + [In] ContextFlags inFlags, + [In] int reservedI, + [In] Endianness endianness, + [In] SecurityBufferDescriptor inputBuffer, + [In] int reservedII, + ref SSPIHandle outContextPtr, + [In, Out] SecurityBufferDescriptor outputBuffer, + [In, Out] ref ContextFlags attributes, + out long timeStamp); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern int CompleteAuthToken( + [In] void* inContextPtr, + [In, Out] SecurityBufferDescriptor inputBuffers); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static extern int QuerySecurityContextToken(ref SSPIHandle phContext, [Out] out SafeCloseHandle handle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern bool CloseHandle(IntPtr handle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern SafeLocalFree LocalAlloc(int uFlags, UIntPtr sizetdwBytes); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern IntPtr LocalFree(IntPtr handle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern unsafe bool FreeLibrary([In] IntPtr hModule); + + [DllImport(CRYPT32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern void CertFreeCertificateChain( + [In] IntPtr pChainContext); + + [DllImport(CRYPT32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern void CertFreeCertificateChainList( + [In] IntPtr ppChainContext); + + [DllImport(CRYPT32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern bool CertFreeCertificateContext( // Suppressing returned status check, it's always==TRUE, + [In] IntPtr certContext); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)] + internal static extern IntPtr GlobalFree(IntPtr handle); + } + + [System.Security.SuppressUnmanagedCodeSecurityAttribute] + internal static class NativeNTSSPI + { + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static extern int EncryptMessage( + ref SSPIHandle contextHandle, + [In] uint qualityOfProtection, + [In, Out] SecurityBufferDescriptor inputOutput, + [In] uint sequenceNumber); + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static unsafe extern int DecryptMessage( + [In] ref SSPIHandle contextHandle, + [In, Out] SecurityBufferDescriptor inputOutput, + [In] uint sequenceNumber, + uint* qualityOfProtection); + } // class UnsafeNclNativeMethods.NativeNTSSPI + + [SuppressUnmanagedCodeSecurityAttribute] + internal static class SspiHelper + { + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + internal static unsafe extern SecurityStatus SspiFreeAuthIdentity( + [In] IntPtr authData); + + [SuppressMessage("Microsoft.Security", "CA2118:ReviewSuppressUnmanagedCodeSecurityUsage", Justification = "Implementation requires unmanaged code usage")] + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true, CharSet = CharSet.Unicode)] + internal static unsafe extern SecurityStatus SspiEncodeStringsAsAuthIdentity( + [In] string userName, + [In] string domainName, + [In] string password, + [Out] out SafeSspiAuthDataHandle authData); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/NegotiationInfoClass.cs b/src/Microsoft.AspNet.Security.Windows/NegotiationInfoClass.cs new file mode 100644 index 0000000000..ab7367c456 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/NegotiationInfoClass.cs @@ -0,0 +1,70 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Security.Windows +{ + using System; + using System.ComponentModel; + using System.Diagnostics; + using System.Globalization; + using System.Runtime.InteropServices; + + internal class NegotiationInfoClass + { + internal const string NTLM = "NTLM"; + internal const string Kerberos = "Kerberos"; + internal const string WDigest = "WDigest"; + internal const string Digest = "Digest"; + internal const string Negotiate = "Negotiate"; + internal string AuthenticationPackage; + + internal NegotiationInfoClass(SafeHandle safeHandle, int negotiationState) + { + if (safeHandle.IsInvalid) + { + GlobalLog.Print("NegotiationInfoClass::.ctor() the handle is invalid:" + (safeHandle.DangerousGetHandle()).ToString("x")); + return; + } + IntPtr packageInfo = safeHandle.DangerousGetHandle(); + GlobalLog.Print("NegotiationInfoClass::.ctor() packageInfo:" + packageInfo.ToString("x8") + " negotiationState:" + negotiationState.ToString("x8")); + + const int SECPKG_NEGOTIATION_COMPLETE = 0; + const int SECPKG_NEGOTIATION_OPTIMISTIC = 1; + // const int SECPKG_NEGOTIATION_IN_PROGRESS = 2; + // const int SECPKG_NEGOTIATION_DIRECT = 3; + // const int SECPKG_NEGOTIATION_TRY_MULTICRED = 4; + + if (negotiationState == SECPKG_NEGOTIATION_COMPLETE || negotiationState == SECPKG_NEGOTIATION_OPTIMISTIC) + { + IntPtr unmanagedString = Marshal.ReadIntPtr(packageInfo, SecurityPackageInfo.NameOffest); + string name = null; + if (unmanagedString != IntPtr.Zero) + { + name = Marshal.PtrToStringUni(unmanagedString); + } + GlobalLog.Print("NegotiationInfoClass::.ctor() packageInfo:" + packageInfo.ToString("x8") + " negotiationState:" + negotiationState.ToString("x8") + " name:" + ValidationHelper.ToString(name)); + + // an optimization for future string comparisons + if (string.Compare(name, Kerberos, StringComparison.OrdinalIgnoreCase) == 0) + { + AuthenticationPackage = Kerberos; + } + else if (string.Compare(name, NTLM, StringComparison.OrdinalIgnoreCase) == 0) + { + AuthenticationPackage = NTLM; + } + else if (string.Compare(name, WDigest, StringComparison.OrdinalIgnoreCase) == 0) + { + AuthenticationPackage = WDigest; + } + else + { + AuthenticationPackage = name; + } + } + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/PrefixCollection.cs b/src/Microsoft.AspNet.Security.Windows/PrefixCollection.cs new file mode 100644 index 0000000000..37a2caa049 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/PrefixCollection.cs @@ -0,0 +1,109 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class PrefixCollection : ICollection + { + private WindowsAuthMiddleware _winAuth; + + internal PrefixCollection(WindowsAuthMiddleware winAuth) + { + _winAuth = winAuth; + } + + public int Count + { + get + { + return _winAuth._uriPrefixes.Count; + } + } + + public bool IsSynchronized + { + get + { + return false; + } + } + + public bool IsReadOnly + { + get + { + return false; + } + } + + public void CopyTo(Array array, int offset) + { + if (Count > array.Length) + { + throw new ArgumentOutOfRangeException("array", SR.GetString(SR.net_array_too_small)); + } + if (offset + Count > array.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + int index = 0; + foreach (string uriPrefix in _winAuth._uriPrefixes.Keys) + { + array.SetValue(uriPrefix, offset + index++); + } + } + + public void CopyTo(string[] array, int offset) + { + if (Count > array.Length) + { + throw new ArgumentOutOfRangeException("array", SR.GetString(SR.net_array_too_small)); + } + if (offset + Count > array.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + int index = 0; + foreach (string uriPrefix in _winAuth._uriPrefixes.Keys) + { + array[offset + index++] = uriPrefix; + } + } + + public void Add(string uriPrefix) + { + _winAuth.AddPrefix(uriPrefix); + } + + public bool Contains(string uriPrefix) + { + return _winAuth._uriPrefixes.Contains(uriPrefix); + } + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + public IEnumerator GetEnumerator() + { + return new PrefixEnumerator(_winAuth._uriPrefixes.Keys.GetEnumerator()); + } + + public bool Remove(string uriPrefix) + { + return _winAuth.RemovePrefix(uriPrefix); + } + + public void Clear() + { + _winAuth.RemoveAll(true); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/PrefixEnumerator.cs b/src/Microsoft.AspNet.Security.Windows/PrefixEnumerator.cs new file mode 100644 index 0000000000..40de3afd6d --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/PrefixEnumerator.cs @@ -0,0 +1,51 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class PrefixEnumerator : IEnumerator + { + private IEnumerator enumerator; + + internal PrefixEnumerator(IEnumerator enumerator) + { + this.enumerator = enumerator; + } + + public string Current + { + get + { + return (string)enumerator.Current; + } + } + + object System.Collections.IEnumerator.Current + { + get + { + return enumerator.Current; + } + } + + public bool MoveNext() + { + return enumerator.MoveNext(); + } + + public void Dispose() + { + } + + void System.Collections.IEnumerator.Reset() + { + enumerator.Reset(); + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/Properties/AssemblyInfo.cs b/src/Microsoft.AspNet.Security.Windows/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..501727a48a --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.Security.Windows")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.Security.Windows")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("1f471909-581f-4060-a147-430891e9c3c1")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/src/Microsoft.AspNet.Security.Windows/ServiceNameStore.cs b/src/Microsoft.AspNet.Security.Windows/ServiceNameStore.cs new file mode 100644 index 0000000000..7d2d33ff00 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/ServiceNameStore.cs @@ -0,0 +1,368 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Net; +using System.Security.Authentication.ExtendedProtection; + +namespace Microsoft.AspNet.Security.Windows +{ + internal class ServiceNameStore + { + private List serviceNames; + private ServiceNameCollection serviceNameCollection; + + public ServiceNameStore() + { + serviceNames = new List(); + serviceNameCollection = null; // set only when needed (due to expensive item-by-item copy) + } + + public ServiceNameCollection ServiceNames + { + get + { + if (serviceNameCollection == null) + { + serviceNameCollection = new ServiceNameCollection(serviceNames); + } + return serviceNameCollection; + } + } + + private bool AddSingleServiceName(string spn) + { + spn = NormalizeServiceName(spn); + if (Contains(spn)) + { + return false; + } + else + { + serviceNames.Add(spn); + return true; + } + } + + public bool Add(string uriPrefix) + { + Debug.Assert(!String.IsNullOrEmpty(uriPrefix)); + + string[] newServiceNames = BuildServiceNames(uriPrefix); + + bool addedAny = false; + foreach (string spn in newServiceNames) + { + if (AddSingleServiceName(spn)) + { + addedAny = true; + + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, "ServiceNameStore#" + + ValidationHelper.HashString(this) + "::Add() " + + SR.GetString(SR.net_log_listener_spn_add, spn, uriPrefix)); + } + } + } + + if (addedAny) + { + serviceNameCollection = null; + } + else if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, "ServiceNameStore#" + + ValidationHelper.HashString(this) + "::Add() " + + SR.GetString(SR.net_log_listener_spn_not_add, uriPrefix)); + } + + return addedAny; + } + + public bool Remove(string uriPrefix) + { + Debug.Assert(!String.IsNullOrEmpty(uriPrefix)); + + string newServiceName = BuildSimpleServiceName(uriPrefix); + newServiceName = NormalizeServiceName(newServiceName); + bool needToRemove = Contains(newServiceName); + + if (needToRemove) + { + serviceNames.Remove(newServiceName); + serviceNameCollection = null; // invalidate (readonly) ServiceNameCollection + } + + if (Logging.On) + { + if (needToRemove) + { + Logging.PrintInfo(Logging.HttpListener, "ServiceNameStore#" + + ValidationHelper.HashString(this) + "::Remove() " + + SR.GetString(SR.net_log_listener_spn_remove, newServiceName, uriPrefix)); + } + else + { + Logging.PrintInfo(Logging.HttpListener, "ServiceNameStore#" + + ValidationHelper.HashString(this) + "::Remove() " + + SR.GetString(SR.net_log_listener_spn_not_remove, uriPrefix)); + } + } + + return needToRemove; + } + + // Assumes already normalized + private bool Contains(string newServiceName) + { + if (newServiceName == null) + { + return false; + } + + return Contains(newServiceName, serviceNames); + } + + // Assumes searchServiceName and serviceNames have already been normalized + internal static bool Contains(string searchServiceName, ICollection serviceNames) + { + Debug.Assert(serviceNames != null); + Debug.Assert(!String.IsNullOrEmpty(searchServiceName)); + + foreach (string serviceName in serviceNames) + { + if (Match(serviceName, searchServiceName)) + { + return true; + } + } + + return false; + } + + // Assumes already normalized + internal static bool Match(string serviceName1, string serviceName2) + { + return (String.Compare(serviceName1, serviceName2, StringComparison.OrdinalIgnoreCase) == 0); + } + + public void Clear() + { + serviceNames.Clear(); + serviceNameCollection = null; // invalidate (readonly) ServiceNameCollection + } + + private string ExtractHostname(string uriPrefix, bool allowInvalidUriStrings) + { + if (Uri.IsWellFormedUriString(uriPrefix, UriKind.Absolute)) + { + Uri hostUri = new Uri(uriPrefix); + return hostUri.Host; + } + else if (allowInvalidUriStrings) + { + int i = uriPrefix.IndexOf("://") + 3; + int j = i; + + bool inSquareBrackets = false; + while (j < uriPrefix.Length && uriPrefix[j] != '/' && (uriPrefix[j] != ':' || inSquareBrackets)) + { + if (uriPrefix[j] == '[') + { + if (inSquareBrackets) + { + j = i; + break; + } + inSquareBrackets = true; + } + if (inSquareBrackets && uriPrefix[j] == ']') + { + inSquareBrackets = false; + } + j++; + } + + return uriPrefix.Substring(i, j - i); + } + + return null; + } + + public string BuildSimpleServiceName(string uriPrefix) + { + string hostname = ExtractHostname(uriPrefix, false); + + if (hostname != null) + { + return "HTTP/" + hostname; + } + else + { + return null; + } + } + + public string[] BuildServiceNames(string uriPrefix) + { + string hostname = ExtractHostname(uriPrefix, true); + + IPAddress ipAddress = null; + if (String.Compare(hostname, "*", StringComparison.InvariantCultureIgnoreCase) == 0 || + String.Compare(hostname, "+", StringComparison.InvariantCultureIgnoreCase) == 0 || + IPAddress.TryParse(hostname, out ipAddress)) + { + // for a wildcard, register the machine name. If the caller doesn't have DNS permission + // or the query fails for some reason, don't add an SPN. + try + { + string machineName = Dns.GetHostEntry(String.Empty).HostName; + return new string[] { "HTTP/" + machineName }; + } + catch (System.Net.Sockets.SocketException) + { + return new string[0]; + } + catch (System.Security.SecurityException) + { + return new string[0]; + } + } + else if (!hostname.Contains(".")) + { + // for a dotless name, try to resolve the FQDN. If the caller doesn't have DNS permission + // or the query fails for some reason, add only the dotless name. + try + { + string fqdn = Dns.GetHostEntry(hostname).HostName; + return new string[] { "HTTP/" + hostname, "HTTP/" + fqdn }; + } + catch (System.Net.Sockets.SocketException) + { + return new string[] { "HTTP/" + hostname }; + } + catch (System.Security.SecurityException) + { + return new string[] { "HTTP/" + hostname }; + } + } + else + { + return new string[] { "HTTP/" + hostname }; + } + } + + // Normalizes any punycode to unicode in an Service Name (SPN) host. + // If the algorithm fails at any point then the original input is returned. + // ServiceName is in one of the following forms: + // prefix/host + // prefix/host:port + // prefix/host/DistinguishedName + // prefix/host:port/DistinguishedName + internal static string NormalizeServiceName(string inputServiceName) + { + if (string.IsNullOrWhiteSpace(inputServiceName)) + { + return inputServiceName; + } + + // Separate out the prefix + int shashIndex = inputServiceName.IndexOf('/'); + if (shashIndex < 0) + { + return inputServiceName; + } + string prefix = inputServiceName.Substring(0, shashIndex + 1); // Includes slash + string hostPortAndDistinguisher = inputServiceName.Substring(shashIndex + 1); // Excludes slash + + if (string.IsNullOrWhiteSpace(hostPortAndDistinguisher)) + { + return inputServiceName; + } + + string host = hostPortAndDistinguisher; + string port = string.Empty; + string distinguisher = string.Empty; + + // Check for the absence of a port or distinguisher. + UriHostNameType hostType = Uri.CheckHostName(hostPortAndDistinguisher); + if (hostType == UriHostNameType.Unknown) + { + string hostAndPort = hostPortAndDistinguisher; + + // Check for distinguisher + int nextSlashIndex = hostPortAndDistinguisher.IndexOf('/'); + if (nextSlashIndex >= 0) + { + // host:port/distinguisher or host/distinguisher + hostAndPort = hostPortAndDistinguisher.Substring(0, nextSlashIndex); // Excludes Slash + distinguisher = hostPortAndDistinguisher.Substring(nextSlashIndex); // Includes Slash + host = hostAndPort; // We don't know if there is a port yet. + + // No need to validate the distinguisher + } + + // Check for port + int colonIndex = hostAndPort.LastIndexOf(':'); // Allow IPv6 addresses + if (colonIndex >= 0) + { + // host:port + host = hostAndPort.Substring(0, colonIndex); // Excludes colon + port = hostAndPort.Substring(colonIndex + 1); // Excludes colon + + // Loosely validate the port just to make sure it was a port and not something else + UInt16 portValue; + if (!UInt16.TryParse(port, NumberStyles.Integer, CultureInfo.InvariantCulture, out portValue)) + { + return inputServiceName; + } + + // Re-include the colon for the final output. Do not change the port format. + port = hostAndPort.Substring(colonIndex); + } + + hostType = Uri.CheckHostName(host); // Revaidate the host + } + + if (hostType != UriHostNameType.Dns) + { + // UriHostNameType.IPv4, UriHostNameType.IPv6: Do not normalize IPv4/6 hosts. + // UriHostNameType.Basic: This is never returned by CheckHostName today + // UriHostNameType.Unknown: Nothing recognizable to normalize + // default Some new UriHostNameType? + return inputServiceName; + } + + // Now we have a valid DNS host, normalize it. + + Uri constructedUri; + // This shouldn't fail, but we need to avoid any unexpected exceptions on this code path. + if (!Uri.TryCreate(Uri.UriSchemeHttp + Uri.SchemeDelimiter + host, UriKind.Absolute, out constructedUri)) + { + return inputServiceName; + } + + string normalizedHost = constructedUri.GetComponents( + UriComponents.NormalizedHost, UriFormat.SafeUnescaped); + + string normalizedServiceName = string.Format(CultureInfo.InvariantCulture, + "{0}{1}{2}{3}", prefix, normalizedHost, port, distinguisher); + + // Don't return the new one unless we absolutely have to. It may have only changed casing. + if (String.Compare(inputServiceName, normalizedServiceName, StringComparison.OrdinalIgnoreCase) == 0) + { + return inputServiceName; + } + + return normalizedServiceName; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/WindowsAuthMiddleware.cs b/src/Microsoft.AspNet.Security.Windows/WindowsAuthMiddleware.cs new file mode 100644 index 0000000000..85f696728f --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/WindowsAuthMiddleware.cs @@ -0,0 +1,1281 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Net; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Permissions; +using System.Security.Principal; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Security.Windows +{ + using AppFunc = Func, Task>; + + /// + /// A middleware that performs Windows Authentication of the specified types. + /// + public sealed class WindowsAuthMiddleware + { + private Func, AuthTypes> _authenticationDelegate; + private AuthTypes _authenticationScheme = AuthTypes.Negotiate | AuthTypes.Ntlm | AuthTypes.Digest; + private string _realm; + private PrefixCollection _prefixes; + private bool _unsafeConnectionNtlmAuthentication; + private Func, ExtendedProtectionPolicy> _extendedProtectionSelectorDelegate; + private ExtendedProtectionPolicy _extendedProtectionPolicy; + private ServiceNameStore _defaultServiceNames; + + private Hashtable _disconnectResults; // ulong -> DisconnectAsyncResult + private object _internalLock; + + internal Hashtable _uriPrefixes; + private DigestCache _digestCache; + + private AppFunc _nextApp; + + // TODO: Support proxy auth + // private bool _doProxyAuth; + + /// + /// + /// + /// + public WindowsAuthMiddleware(AppFunc nextApp) + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "WindowsAuth", string.Empty); + } + + _internalLock = new object(); + _defaultServiceNames = new ServiceNameStore(); + + // default: no CBT checks on any platform (appcompat reasons); applies also to PolicyEnforcement + // config element + _extendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Never); + _uriPrefixes = new Hashtable(); + _digestCache = new DigestCache(); + + _nextApp = nextApp; + + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "WindowsAuth", string.Empty); + } + } + + /// + /// Dynamically select the type of authentication to apply per request. + /// + public Func, AuthTypes> AuthenticationSchemeSelectorDelegate + { + get + { + return _authenticationDelegate; + } + set + { + _authenticationDelegate = value; + } + } + + /// + /// Dynamically select the type of extended protection to apply per request. + /// + public Func, ExtendedProtectionPolicy> ExtendedProtectionSelectorDelegate + { + get + { + return _extendedProtectionSelectorDelegate; + } + set + { + if (value == null) + { + throw new ArgumentNullException(); + } + + if (!ExtendedProtectionPolicy.OSSupportsExtendedProtection) + { + throw new PlatformNotSupportedException(SR.GetString(SR.security_ExtendedProtection_NoOSSupport)); + } + + _extendedProtectionSelectorDelegate = value; + } + } + + /// + /// Specifies which types of Windows authentication are enabled. + /// + public AuthTypes AuthenticationSchemes + { + get + { + return _authenticationScheme; + } + set + { + _authenticationScheme = value; + } + } + + /// + /// Configures extended protection. + /// + public ExtendedProtectionPolicy ExtendedProtectionPolicy + { + get + { + return _extendedProtectionPolicy; + } + set + { + if (value == null) + { + throw new ArgumentNullException("value"); + } + if (!ExtendedProtectionPolicy.OSSupportsExtendedProtection && value.PolicyEnforcement == PolicyEnforcement.Always) + { + throw new PlatformNotSupportedException(SR.GetString(SR.security_ExtendedProtection_NoOSSupport)); + } + if (value.CustomChannelBinding != null) + { + throw new ArgumentException(SR.GetString(SR.net_listener_cannot_set_custom_cbt), "CustomChannelBinding"); + } + + _extendedProtectionPolicy = value; + } + } + + /// + /// Configures the service names for extended protection. + /// + public ServiceNameCollection DefaultServiceNames + { + get + { + return _defaultServiceNames.ServiceNames; + } + } + + /// + /// The Realm for use in digest authentication. + /// + public string Realm + { + get + { + return _realm; + } + set + { + _realm = value; + } + } + + /// + /// Enables authenticated connection sharing with NTLM. + /// + public bool UnsafeConnectionNtlmAuthentication + { + get + { + return _unsafeConnectionNtlmAuthentication; + } + + set + { + if (_unsafeConnectionNtlmAuthentication == value) + { + return; + } + lock (DisconnectResults.SyncRoot) + { + if (_unsafeConnectionNtlmAuthentication == value) + { + return; + } + _unsafeConnectionNtlmAuthentication = value; + if (!value) + { + foreach (DisconnectAsyncResult result in DisconnectResults.Values) + { + result.AuthenticatedUser = null; + } + } + } + } + } + + internal Hashtable DisconnectResults + { + get + { + if (_disconnectResults == null) + { + Interlocked.CompareExchange(ref _disconnectResults, Hashtable.Synchronized(new Hashtable()), null); + } + return _disconnectResults; + } + } + + internal unsafe void AddPrefix(string uriPrefix) + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "AddPrefix", "uriPrefix:" + uriPrefix); + } + string registeredPrefix = null; + try + { + if (uriPrefix == null) + { + throw new ArgumentNullException("uriPrefix"); + } + (new WebPermission(NetworkAccess.Accept, uriPrefix)).Demand(); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::AddPrefix() uriPrefix:" + uriPrefix); + int i; + if (string.Compare(uriPrefix, 0, "http://", 0, 7, StringComparison.OrdinalIgnoreCase) == 0) + { + i = 7; + } + else if (string.Compare(uriPrefix, 0, "https://", 0, 8, StringComparison.OrdinalIgnoreCase) == 0) + { + i = 8; + } + else + { + throw new ArgumentException(SR.GetString(SR.net_listener_scheme), "uriPrefix"); + } + bool inSquareBrakets = false; + int j = i; + while (j < uriPrefix.Length && uriPrefix[j] != '/' && (uriPrefix[j] != ':' || inSquareBrakets)) + { + if (uriPrefix[j] == '[') + { + if (inSquareBrakets) + { + j = i; + break; + } + inSquareBrakets = true; + } + if (inSquareBrakets && uriPrefix[j] == ']') + { + inSquareBrakets = false; + } + j++; + } + if (i == j) + { + throw new ArgumentException(SR.GetString(SR.net_listener_host), "uriPrefix"); + } + if (uriPrefix[uriPrefix.Length - 1] != '/') + { + throw new ArgumentException(SR.GetString(SR.net_listener_slash), "uriPrefix"); + } + registeredPrefix = uriPrefix[j] == ':' ? String.Copy(uriPrefix) : uriPrefix.Substring(0, j) + (i == 7 ? ":80" : ":443") + uriPrefix.Substring(j); + fixed (char* pChar = registeredPrefix) + { + i = 0; + while (pChar[i] != ':') + { + pChar[i] = (char)CaseInsensitiveAscii.AsciiToLower[(byte)pChar[i]]; + i++; + } + } + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::AddPrefix() mapped uriPrefix:" + uriPrefix + " to registeredPrefix:" + registeredPrefix); + + _uriPrefixes[uriPrefix] = registeredPrefix; + _defaultServiceNames.Add(uriPrefix); + } + catch (Exception exception) + { + if (Logging.On) + { + Logging.Exception(Logging.HttpListener, this, "AddPrefix", exception); + } + throw; + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "AddPrefix", "prefix:" + registeredPrefix); + } + } + } + + internal PrefixCollection Prefixes + { + get + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "Prefixes_get", string.Empty); + } + if (_prefixes == null) + { + _prefixes = new PrefixCollection(this); + } + return _prefixes; + } + } + + internal bool RemovePrefix(string uriPrefix) + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "RemovePrefix", "uriPrefix:" + uriPrefix); + } + try + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::RemovePrefix() uriPrefix:" + uriPrefix); + if (uriPrefix == null) + { + throw new ArgumentNullException("uriPrefix"); + } + + if (!_uriPrefixes.Contains(uriPrefix)) + { + return false; + } + + _uriPrefixes.Remove(uriPrefix); + _defaultServiceNames.Remove(uriPrefix); + } + catch (Exception exception) + { + if (Logging.On) + { + Logging.Exception(Logging.HttpListener, this, "RemovePrefix", exception); + } + throw; + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "RemovePrefix", "uriPrefix:" + uriPrefix); + } + } + return true; + } + + internal void RemoveAll(bool clear) + { + if (Logging.On) + { + Logging.Enter(Logging.HttpListener, this, "RemoveAll", string.Empty); + } + try + { + // go through the uri list and unregister for each one of them + if (_uriPrefixes.Count > 0) + { + if (clear) + { + _uriPrefixes.Clear(); + _defaultServiceNames.Clear(); + } + } + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "RemoveAll", string.Empty); + } + } + } + + // old API, now private, and helper methods + private void Dispose(bool disposing) + { + GlobalLog.Assert(disposing, "Dispose(bool) does nothing if called from the finalizer."); + + if (!disposing) + { + return; + } + + try + { + _digestCache.Dispose(); + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.HttpListener, this, "Dispose", string.Empty); + } + } + } + + /// + /// + /// + /// + /// + public Task Invoke(IDictionary env) + { + // Process the auth header, if any + if (!TryHandleAuthentication(env)) + { + // If failed and a 400/401/500 was sent. + return Task.FromResult(null); + } + + // If passing through, register for OnSendingHeaders. Add an auth header challenge on 401. + var registerOnSendingHeaders = env.Get, object>>(Constants.ServerOnSendingHeadersKey); + if (registerOnSendingHeaders == null) + { + // This module requires OnSendingHeaders support. + throw new PlatformNotSupportedException(); + } + registerOnSendingHeaders(Set401Challenges, env); + + // Invoke the next item in the app chain + return _nextApp(env); + } + + // Returns true if auth completed successfully (or anonymous), false if there was an auth header + // but processing it failed. + private bool TryHandleAuthentication(IDictionary env) + { + DisconnectAsyncResult disconnectResult; + object connectionId = env.Get(Constants.ServerConnectionIdKey, -1); + string authorizationHeader = null; + if (!TryGetIncomingAuthHeader(env, out authorizationHeader)) + { + if (UnsafeConnectionNtlmAuthentication) + { + disconnectResult = (DisconnectAsyncResult)DisconnectResults[connectionId]; + if (disconnectResult != null) + { + WindowsPrincipal principal = disconnectResult.AuthenticatedUser; + if (principal != null) + { + // This connection has already been authenticated; + SetIdentity(env, principal, null); + } + } + } + + return true; // Anonymous or UnsafeConnectionNtlmAuthentication + } + + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() authorizationHeader:" + ValidationHelper.ToString(authorizationHeader)); + + if (UnsafeConnectionNtlmAuthentication) + { + disconnectResult = (DisconnectAsyncResult)DisconnectResults[connectionId]; + // They sent an authorization header - destroy their previous credentials. + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() clearing principal cache"); + if (disconnectResult != null) + { + disconnectResult.AuthenticatedUser = null; + } + } + + try + { + AuthTypes headerScheme; + string inBlob; + if (!TryGetRecognizedAuthScheme(authorizationHeader, out headerScheme, out inBlob)) + { + return true; // Anonymous / pass through + } + Contract.Assert(headerScheme != AuthTypes.None); + Contract.Assert(inBlob != null); + + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() Performing Authentication headerScheme:" + ValidationHelper.ToString(headerScheme)); + switch (headerScheme) + { + case AuthTypes.Digest: + return TryAuthenticateWithDigest(env, inBlob); + + case AuthTypes.Negotiate: + case AuthTypes.Ntlm: + string package = headerScheme == AuthTypes.Ntlm ? NegotiationInfoClass.NTLM : NegotiationInfoClass.Negotiate; + return TryAuthenticateWithNegotiate(env, package, inBlob); + + default: + throw new NotImplementedException(headerScheme.ToString()); + } + } + catch (Exception) + { + SendError(env, HttpStatusCode.InternalServerError, null); + return false; + } + } + + // TODO: Support proxy auth + private bool TryGetIncomingAuthHeader(IDictionary env, out string authorizationHeader) + { + IDictionary headers = env.Get>(Constants.RequestHeadersKey); + authorizationHeader = headers.Get("Authorization"); + return !string.IsNullOrWhiteSpace(authorizationHeader); + } + + private bool TryGetRecognizedAuthScheme(string authorizationHeader, out AuthTypes headerScheme, out string inBlob) + { + headerScheme = AuthTypes.None; + + int index; + // Find the end of the scheme name. Trust that HTTP.SYS parsed out just our header ok. + for (index = 0; index < authorizationHeader.Length; index++) + { + if (authorizationHeader[index] == ' ' || authorizationHeader[index] == '\t' || + authorizationHeader[index] == '\r' || authorizationHeader[index] == '\n') + { + break; + } + } + + // Currently only allow one Authorization scheme/header per request. + if (index < authorizationHeader.Length) + { + if ((AuthenticationSchemes & AuthTypes.Negotiate) != AuthTypes.None && + string.Compare(authorizationHeader, 0, NegotiationInfoClass.Negotiate, 0, index, StringComparison.OrdinalIgnoreCase) == 0) + { + headerScheme = AuthTypes.Negotiate; + } + else if ((AuthenticationSchemes & AuthTypes.Ntlm) != AuthTypes.None && + string.Compare(authorizationHeader, 0, NegotiationInfoClass.NTLM, 0, index, StringComparison.OrdinalIgnoreCase) == 0) + { + headerScheme = AuthTypes.Ntlm; + } + else if ((AuthenticationSchemes & AuthTypes.Digest) != AuthTypes.None && + string.Compare(authorizationHeader, 0, NegotiationInfoClass.Digest, 0, index, StringComparison.OrdinalIgnoreCase) == 0) + { + headerScheme = AuthTypes.Digest; + } + } + + // Find the beginning of the blob. Trust that HTTP.SYS parsed out just our header ok. + for (index++; index < authorizationHeader.Length; index++) + { + if (authorizationHeader[index] != ' ' && authorizationHeader[index] != '\t' && + authorizationHeader[index] != '\r' && authorizationHeader[index] != '\n') + { + break; + } + } + inBlob = index < authorizationHeader.Length ? authorizationHeader.Substring(index) : string.Empty; + + return headerScheme != AuthTypes.None; + } + + // Returns true if successfully authenticated via Digest. Returns false if a 401 was sent. + private bool TryAuthenticateWithDigest(IDictionary env, string inBlob) + { + NTAuthentication context = null; + IPrincipal principal = null; + SecurityStatus statusCodeNew; + ChannelBinding binding; + string outBlob; + HttpStatusCode httpError = HttpStatusCode.OK; + string verb = env.Get(Constants.RequestMethodKey); + bool isSecureConnection = IsSecureConnection(env); + // GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() package:WDigest headerScheme:" + headerScheme); + + // WDigest had some weird behavior. This is what I have discovered: + // Local accounts don't work, only domain accounts. The domain (i.e. REDMOND) is implied. Not sure how it is chosen. + // If the domain is specified and the credentials are correct, it works. If they're not (domain, username or password): + // AcceptSecurityContext (GetOutgoingDigestBlob) returns success but with a bogus 4k challenge, and + // QuerySecurityContextToken (GetContextToken) fails with NoImpersonation. + // If the domain isn't specified, AcceptSecurityContext returns NoAuthenticatingAuthority for a bad username, + // and LogonDenied for a bad password. + + // Also interesting is that WDigest requires us to keep a reference to the previous context, but fails if we + // actually pass it in! (It't ok to pass it in for the first request, but not if nc > 1.) For Whidbey, + // we create a new context and associate it with the connection, just like NTLM, but instead of using it for + // the next request on the connection, we always create a new context and swap the old one out. As long + // as we keep the old one around until after we authenticate with the new one, it works. For this reason, + // we also keep these contexts around past the lifetime of the connection, so that KeepAlive=false works. + binding = GetChannelBinding(env, isSecureConnection, ExtendedProtectionPolicy); + + context = new NTAuthentication(true, NegotiationInfoClass.WDigest, null, + GetContextFlags(ExtendedProtectionPolicy, isSecureConnection), binding); + + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() verb:" + verb + " context.IsValidContext:" + context.IsValidContext.ToString()); + + outBlob = context.GetOutgoingDigestBlob(inBlob, verb, null, Realm, false, false, out statusCodeNew); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() GetOutgoingDigestBlob() returned IsCompleted:" + context.IsCompleted + " statusCodeNew:" + statusCodeNew + " outBlob:[" + outBlob + "]"); + + // WDigest bug: sometimes when AcceptSecurityContext returns success, it provides a bogus, empty 4k buffer. + // Ignore it. (Should find out what's going on here from WDigest people.) + if (statusCodeNew == SecurityStatus.OK) + { + outBlob = null; + } + + IList challenges = null; + if (outBlob != null) + { + string challenge = NegotiationInfoClass.Digest + " " + outBlob; + AddChallenge(ref challenges, challenge); + } + + if (context.IsValidContext) + { + SafeCloseHandle userContext = null; + try + { + if (!CheckSpn(context, isSecureConnection, ExtendedProtectionPolicy)) + { + httpError = HttpStatusCode.Unauthorized; + } + else + { + SetServiceName(env, context.ClientSpecifiedSpn); + + userContext = context.GetContextToken(out statusCodeNew); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() GetContextToken() returned:" + statusCodeNew.ToString()); + if (statusCodeNew != SecurityStatus.OK) + { + httpError = HttpStatusFromSecurityStatus(statusCodeNew); + } + else if (userContext == null) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() error: GetContextToken() returned:null statusCodeNew:" + statusCodeNew.ToString()); + httpError = HttpStatusCode.Unauthorized; + } + else + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() creating new WindowsIdentity() from userContext:" + userContext.DangerousGetHandle().ToString("x8")); + principal = new WindowsPrincipal(CreateWindowsIdentity(userContext.DangerousGetHandle(), "Digest"/*DigestClient.AuthType*/, WindowsAccountType.Normal, true)); + SetIdentity(env, principal, null); + _digestCache.SaveDigestContext(context); + } + } + } + finally + { + if (userContext != null) + { + userContext.Dispose(); + } + } + } + else + { + httpError = HttpStatusFromSecurityStatus(statusCodeNew); + } + + if (httpError != HttpStatusCode.OK) + { + SendError(env, httpError, challenges); + return false; + } + return true; + } + + // Negotiate or NTLM + private bool TryAuthenticateWithNegotiate(IDictionary env, string package, string inBlob) + { + object connectionId = env.Get(Constants.ServerConnectionIdKey, null); + if (connectionId == null) + { + // We need a connection ID from the server to correctly track in-progress auth. + throw new PlatformNotSupportedException(); + } + + NTAuthentication oldContext = null, context; + DisconnectAsyncResult disconnectResult = (DisconnectAsyncResult)DisconnectResults[connectionId]; + if (disconnectResult != null) + { + oldContext = disconnectResult.Session; + } + ChannelBinding binding; + bool isSecureConnection = IsSecureConnection(env); + byte[] bytes = null; + HttpStatusCode httpError = HttpStatusCode.OK; + bool error = false; + string outBlob = null; + + if (oldContext != null && oldContext.Package == package) + { + context = oldContext; + } + else + { + binding = GetChannelBinding(env, isSecureConnection, ExtendedProtectionPolicy); + + context = new NTAuthentication(true, package, null, + GetContextFlags(ExtendedProtectionPolicy, isSecureConnection), binding); + + // Clean up old context + if (oldContext != null) + { + oldContext.CloseContext(); + } + } + + try + { + bytes = Convert.FromBase64String(inBlob); + } + catch (FormatException) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() FromBase64String threw a FormatException."); + httpError = HttpStatusCode.BadRequest; + error = true; + } + + byte[] decodedOutgoingBlob = null; + SecurityStatus statusCodeNew; + if (!error) + { + decodedOutgoingBlob = context.GetOutgoingBlob(bytes, false, out statusCodeNew); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() GetOutgoingBlob() returned IsCompleted:" + context.IsCompleted + " statusCodeNew:" + statusCodeNew); + error = !context.IsValidContext; + if (error) + { + // Bug #474228: SSPI Workaround + // If a client sends up a blob on the initial request, Negotiate returns SEC_E_INVALID_HANDLE + // when it should return SEC_E_INVALID_TOKEN. + if (statusCodeNew == SecurityStatus.InvalidHandle && oldContext == null && bytes != null && bytes.Length > 0) + { + statusCodeNew = SecurityStatus.InvalidToken; + } + + httpError = HttpStatusFromSecurityStatus(statusCodeNew); + } + } + + if (decodedOutgoingBlob != null) + { + outBlob = Convert.ToBase64String(decodedOutgoingBlob); + } + + if (!error) + { + if (context.IsCompleted) + { + SafeCloseHandle userContext = null; + try + { + if (!CheckSpn(context, isSecureConnection, ExtendedProtectionPolicy)) + { + httpError = HttpStatusCode.Unauthorized; + } + else + { + SetServiceName(env, context.ClientSpecifiedSpn); + + userContext = context.GetContextToken(out statusCodeNew); + if (statusCodeNew != SecurityStatus.OK) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() GetContextToken() failed with statusCodeNew:" + statusCodeNew.ToString()); + httpError = HttpStatusFromSecurityStatus(statusCodeNew); + } + else + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::HandleAuthentication() creating new WindowsIdentity() from userContext:" + userContext.DangerousGetHandle().ToString("x8")); + WindowsPrincipal windowsPrincipal = new WindowsPrincipal(CreateWindowsIdentity(userContext.DangerousGetHandle(), context.ProtocolName, WindowsAccountType.Normal, true)); + SetIdentity(env, windowsPrincipal, outBlob); + + // if appropriate, cache this credential on this connection + if (UnsafeConnectionNtlmAuthentication + && context.ProtocolName.Equals(NegotiationInfoClass.NTLM, StringComparison.OrdinalIgnoreCase)) + { + // We may need to call WaitForDisconnect. + if (disconnectResult == null) + { + RegisterForDisconnectNotification(env, out disconnectResult); + } + + if (disconnectResult != null) + { + lock (DisconnectResults.SyncRoot) + { + if (UnsafeConnectionNtlmAuthentication) + { + disconnectResult.AuthenticatedUser = windowsPrincipal; + } + } + } + } + } + } + } + finally + { + if (userContext != null) + { + userContext.Dispose(); + } + } + return true; + } + else + { + // auth incomplete + if (disconnectResult == null) + { + RegisterForDisconnectNotification(env, out disconnectResult); + + // Failed - send 500. + if (disconnectResult == null) + { + context.CloseContext(); + SendError(env, HttpStatusCode.InternalServerError, null); + return false; + } + } + + disconnectResult.Session = context; + + string challenge = package; + if (!String.IsNullOrEmpty(outBlob)) + { + challenge += " " + outBlob; + } + IList challenges = null; + AddChallenge(ref challenges, challenge); + SendError(env, HttpStatusCode.Unauthorized, challenges); + return false; + } + } + + SendError(env, httpError, null); + return false; + } + + private void SetIdentity(IDictionary env, IPrincipal principal, string mutualAuth) + { + env[Constants.ServerUserKey] = principal; + if (!string.IsNullOrWhiteSpace(mutualAuth)) + { + var responseHeaders = env.Get>(Constants.ResponseHeadersKey); + responseHeaders.Append(HttpKnownHeaderNames.WWWAuthenticate, mutualAuth); + } + } + + // For user info only + private void SetServiceName(IDictionary env, string serviceName) + { + if (!string.IsNullOrWhiteSpace(serviceName)) + { + env[Constants.SslSpnKey] = serviceName; + } + } + + [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.UnmanagedCode)] + [SecurityPermission(SecurityAction.Assert, Flags = SecurityPermissionFlag.ControlPrincipal)] + internal static WindowsIdentity CreateWindowsIdentity(IntPtr userToken, string type, WindowsAccountType acctType, bool isAuthenticated) + { + return new WindowsIdentity(userToken, type, acctType, isAuthenticated); + } + + // On a 401 response, set any appropriate challenges + private void Set401Challenges(object state) + { + var env = (IDictionary)state; + var responseHeaders = env.Get>(Constants.ResponseHeadersKey); + + // We use the cached results from the delegates so that we don't have to call them again here. + NTAuthentication newContext; + IList challenges = BuildChallenge(env, AuthenticationSchemes, out newContext, ExtendedProtectionPolicy); + + // null == Anonymous + if (challenges != null) + { + // Digest challenge, keep it alive for 10s - 5min. + if (newContext != null) + { + _digestCache.SaveDigestContext(newContext); + } + + responseHeaders.Append(HttpKnownHeaderNames.WWWAuthenticate, challenges); + } + } + + private static bool IsSecureConnection(IDictionary env) + { + return "https".Equals(env.Get(Constants.RequestSchemeKey, "http"), StringComparison.OrdinalIgnoreCase); + } + + private static bool ScenarioChecksChannelBinding(bool isSecureConnection, ProtectionScenario scenario) + { + return (isSecureConnection && scenario == ProtectionScenario.TransportSelected); + } + + private ChannelBinding GetChannelBinding(IDictionary env, bool isSecureConnection, ExtendedProtectionPolicy policy) + { + if (policy.PolicyEnforcement == PolicyEnforcement.Never) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_cbt_disabled)); + } + return null; + } + + if (!isSecureConnection) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_cbt_http)); + } + return null; + } + + if (!ExtendedProtectionPolicy.OSSupportsExtendedProtection) + { + GlobalLog.Assert(policy.PolicyEnforcement != PolicyEnforcement.Always, "User managed to set PolicyEnforcement.Always when the OS does not support extended protection!"); + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_cbt_platform)); + } + return null; + } + + if (policy.ProtectionScenario == ProtectionScenario.TrustedProxy) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_cbt_trustedproxy)); + } + return null; + } + + ChannelBinding result = env.Get(Constants.SslChannelBindingKey); + if (result == null) + { + // A channel binding object is required. + throw new InvalidOperationException(); + } + + GlobalLog.Assert(result != null, "GetChannelBindingFromTls returned null even though OS supposedly supports Extended Protection"); + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_cbt)); + } + return result; + } + + private bool CheckSpn(NTAuthentication context, bool isSecureConnection, ExtendedProtectionPolicy policy) + { + // Kerberos does SPN check already in ASC + if (context.IsKerberos) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_kerberos)); + } + return true; + } + + // Don't check the SPN if Extended Protection is off or we already checked the CBT + if (policy.PolicyEnforcement == PolicyEnforcement.Never) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_disabled)); + } + return true; + } + + if (ScenarioChecksChannelBinding(isSecureConnection, policy.ProtectionScenario)) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_cbt)); + } + return true; + } + + if (!ExtendedProtectionPolicy.OSSupportsExtendedProtection) + { + GlobalLog.Assert(policy.PolicyEnforcement != PolicyEnforcement.Always, "User managed to set PolicyEnforcement.Always when the OS does not support extended protection!"); + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_platform)); + } + return true; + } + + string clientSpn = context.ClientSpecifiedSpn; + + // An empty SPN is only allowed in the WhenSupported case + if (String.IsNullOrEmpty(clientSpn)) + { + if (policy.PolicyEnforcement == PolicyEnforcement.WhenSupported) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, + SR.GetString(SR.net_log_listener_no_spn_whensupported)); + } + return true; + } + else + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, + SR.GetString(SR.net_log_listener_spn_failed_always)); + } + return false; + } + } + else if (String.Compare(clientSpn, "http/localhost", StringComparison.OrdinalIgnoreCase) == 0) + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_no_spn_loopback)); + } + + return true; + } + else + { + if (Logging.On) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_spn, clientSpn)); + } + + ServiceNameCollection serviceNames = GetServiceNames(policy); + + bool found = serviceNames.Contains(clientSpn); + + if (Logging.On) + { + if (found) + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_spn_passed)); + } + else + { + Logging.PrintInfo(Logging.HttpListener, this, SR.GetString(SR.net_log_listener_spn_failed)); + + if (serviceNames.Count == 0) + { + Logging.PrintWarning(Logging.HttpListener, this, "CheckSpn", + SR.GetString(SR.net_log_listener_spn_failed_empty)); + } + else + { + Logging.PrintInfo(Logging.HttpListener, this, + SR.GetString(SR.net_log_listener_spn_failed_dump)); + + foreach (string serviceName in serviceNames) + { + Logging.PrintInfo(Logging.HttpListener, this, "\t" + serviceName); + } + } + } + } + + return found; + } + } + + private ServiceNameCollection GetServiceNames(ExtendedProtectionPolicy policy) + { + ServiceNameCollection serviceNames; + + if (policy.CustomServiceNames == null) + { + if (_defaultServiceNames.ServiceNames.Count == 0) + { + throw new InvalidOperationException(SR.GetString(SR.net_listener_no_spns)); + } + serviceNames = _defaultServiceNames.ServiceNames; + } + else + { + serviceNames = policy.CustomServiceNames; + } + return serviceNames; + } + + private ContextFlags GetContextFlags(ExtendedProtectionPolicy policy, bool isSecureConnection) + { + ContextFlags result = ContextFlags.Connection; + + if (policy.PolicyEnforcement != PolicyEnforcement.Never) + { + if (policy.PolicyEnforcement == PolicyEnforcement.WhenSupported) + { + result |= ContextFlags.AllowMissingBindings; + } + + if (policy.ProtectionScenario == ProtectionScenario.TrustedProxy) + { + result |= ContextFlags.ProxyBindings; + } + } + + return result; + } + + private static void AddChallenge(ref IList challenges, string challenge) + { + if (challenge != null) + { + challenge = challenge.Trim(); + if (challenge.Length > 0) + { + GlobalLog.Print("HttpListener:AddChallenge() challenge:" + challenge); + if (challenges == null) + { + challenges = new List(4); + } + challenges.Add(challenge); + } + } + } + + private IList BuildChallenge(IDictionary env, AuthTypes authenticationScheme, out NTAuthentication digestContext, + ExtendedProtectionPolicy policy) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::BuildChallenge() authenticationScheme:" + authenticationScheme.ToString()); + IList challenges = null; + digestContext = null; + + if ((authenticationScheme & AuthTypes.Negotiate) != 0) + { + AddChallenge(ref challenges, NegotiationInfoClass.Negotiate); + } + + if ((authenticationScheme & AuthTypes.Ntlm) != 0) + { + AddChallenge(ref challenges, NegotiationInfoClass.NTLM); + } + + if ((authenticationScheme & AuthTypes.Digest) != 0) + { + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::BuildChallenge() package:WDigest"); + + NTAuthentication context = null; + try + { + bool isSecureConnection = IsSecureConnection(env); + string outBlob = null; + ChannelBinding binding = GetChannelBinding(env, isSecureConnection, policy); + + context = new NTAuthentication(true, NegotiationInfoClass.WDigest, null, + GetContextFlags(policy, isSecureConnection), binding); + + SecurityStatus statusCode; + outBlob = context.GetOutgoingDigestBlob(null, null, null, Realm, false, false, out statusCode); + GlobalLog.Print("HttpListener#" + ValidationHelper.HashString(this) + "::BuildChallenge() GetOutgoingDigestBlob() returned IsCompleted:" + context.IsCompleted + " statusCode:" + statusCode + " outBlob:[" + outBlob + "]"); + + if (context.IsValidContext) + { + digestContext = context; + _digestCache.SaveDigestContext(digestContext); + } + + AddChallenge(ref challenges, NegotiationInfoClass.Digest + (string.IsNullOrEmpty(outBlob) ? string.Empty : " " + outBlob)); + } + catch (InvalidOperationException) + { + // No CBT available, therefore no digest challenge can be issued. + } + finally + { + if (context != null && digestContext != context) + { + context.CloseContext(); + } + } + } + + return challenges; + } + + private void RegisterForDisconnectNotification(IDictionary env, out DisconnectAsyncResult disconnectResult) + { + object connectionId = env[Constants.ServerConnectionIdKey]; + CancellationToken connectionDisconnect = env.Get(Constants.ServerConnectionDisconnectKey); + if (!connectionDisconnect.CanBeCanceled || connectionDisconnect.IsCancellationRequested) + { + disconnectResult = null; + return; + } + try + { + disconnectResult = new DisconnectAsyncResult(this, connectionId, connectionDisconnect); + } + catch (ObjectDisposedException) + { + // Just disconnected + disconnectResult = null; + return; + } + } + + private void SendError(IDictionary env, HttpStatusCode httpStatusCode, IList challenges) + { + // Send an OWIN HTTP response with the given error status code. + env[Constants.ResponseStatusCodeKey] = (int)httpStatusCode; + + if (challenges != null) + { + var responseHeaders = env.Get>(Constants.ResponseHeadersKey); + responseHeaders.Append(HttpKnownHeaderNames.WWWAuthenticate, challenges); + } + } + + // This only works for context-destroying errors. + private HttpStatusCode HttpStatusFromSecurityStatus(SecurityStatus status) + { + if (IsCredentialFailure(status)) + { + return HttpStatusCode.Unauthorized; + } + if (IsClientFault(status)) + { + return HttpStatusCode.BadRequest; + } + return HttpStatusCode.InternalServerError; + } + + // This only works for context-destroying errors. + private static bool IsCredentialFailure(SecurityStatus error) + { + return error == SecurityStatus.LogonDenied || + error == SecurityStatus.UnknownCredentials || + error == SecurityStatus.NoImpersonation || + error == SecurityStatus.NoAuthenticatingAuthority || + error == SecurityStatus.UntrustedRoot || + error == SecurityStatus.CertExpired || + error == SecurityStatus.SmartcardLogonRequired || + error == SecurityStatus.BadBinding; + } + + // This only works for context-destroying errors. + private static bool IsClientFault(SecurityStatus error) + { + return error == SecurityStatus.InvalidToken || + error == SecurityStatus.CannotPack || + error == SecurityStatus.QopNotSupported || + error == SecurityStatus.NoCredentials || + error == SecurityStatus.MessageAltered || + error == SecurityStatus.OutOfSequence || + error == SecurityStatus.IncompleteMessage || + error == SecurityStatus.IncompleteCredentials || + error == SecurityStatus.WrongPrincipal || + error == SecurityStatus.TimeSkew || + error == SecurityStatus.IllegalMessage || + error == SecurityStatus.CertUnknown || + error == SecurityStatus.AlgorithmMismatch || + error == SecurityStatus.SecurityQosFailed || + error == SecurityStatus.UnsupportedPreauth; + } + } +} diff --git a/src/Microsoft.AspNet.Security.Windows/packages.config b/src/Microsoft.AspNet.Security.Windows/packages.config new file mode 100644 index 0000000000..7432196421 --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/packages.config @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/src/Microsoft.AspNet.Security.Windows/project.json b/src/Microsoft.AspNet.Security.Windows/project.json new file mode 100644 index 0000000000..3fe6a340df --- /dev/null +++ b/src/Microsoft.AspNet.Security.Windows/project.json @@ -0,0 +1,14 @@ +{ + "version": "0.1-alpha-*", + "dependencies": { + }, + "compilationOptions" : { "allowUnsafe": true }, + "configurations": + { + "net45" : { + "dependencies": { + "Owin": "1.0" + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/AsyncAcceptContext.cs b/src/Microsoft.AspNet.Server.WebListener/AsyncAcceptContext.cs new file mode 100644 index 0000000000..a929f8a854 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/AsyncAcceptContext.cs @@ -0,0 +1,224 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal unsafe class AsyncAcceptContext : IAsyncResult, IDisposable + { + internal static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(IOWaitCallback); + + private TaskCompletionSource _tcs; + private OwinWebListener _server; + private NativeRequestContext _nativeRequestContext; + + internal AsyncAcceptContext(OwinWebListener server) + { + _server = server; + _tcs = new TaskCompletionSource(); + _nativeRequestContext = new NativeRequestContext(this); + } + + internal Task Task + { + get + { + return _tcs.Task; + } + } + + private TaskCompletionSource Tcs + { + get + { + return _tcs; + } + } + + private OwinWebListener Server + { + get + { + return _server; + } + } + + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirecting to callback")] + [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Disposed by callback")] + private static void IOCompleted(AsyncAcceptContext asyncResult, uint errorCode, uint numBytes) + { + bool complete = false; + try + { + if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + asyncResult.Tcs.TrySetException(new WebListenerException((int)errorCode)); + complete = true; + } + else + { + OwinWebListener server = asyncResult.Server; + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + // at this point we have received an unmanaged HTTP_REQUEST and memoryBlob + // points to it we need to hook up our authentication handling code here. + bool stoleBlob = false; + try + { + if (server.ValidateRequest(asyncResult._nativeRequestContext)) + { + stoleBlob = true; + RequestContext requestContext = new RequestContext(server, asyncResult._nativeRequestContext); + asyncResult.Tcs.TrySetResult(requestContext); + complete = true; + } + } + finally + { + if (stoleBlob) + { + // The request has been handed to the user, which means this code can't reuse the blob. Reset it here. + asyncResult._nativeRequestContext = complete ? null : new NativeRequestContext(asyncResult); + } + else + { + asyncResult._nativeRequestContext.Reset(0, 0); + } + } + } + else + { + asyncResult._nativeRequestContext.Reset(asyncResult._nativeRequestContext.RequestBlob->RequestId, numBytes); + } + + // We need to issue a new request, either because auth failed, or because our buffer was too small the first time. + if (!complete) + { + uint statusCode = asyncResult.QueueBeginGetContext(); + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + // someother bad error, possible(?) return values are: + // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED + asyncResult.Tcs.TrySetException(new WebListenerException((int)statusCode)); + complete = true; + } + } + if (!complete) + { + return; + } + } + + if (complete) + { + asyncResult.Dispose(); + } + } + catch (Exception exception) + { + // Logged by caller + asyncResult.Tcs.TrySetException(exception); + asyncResult.Dispose(); + } + } + + private static unsafe void IOWaitCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + // take the ListenerAsyncResult object from the state + Overlapped callbackOverlapped = Overlapped.Unpack(nativeOverlapped); + AsyncAcceptContext asyncResult = (AsyncAcceptContext)callbackOverlapped.AsyncResult; + + IOCompleted(asyncResult, errorCode, numBytes); + } + + internal uint QueueBeginGetContext() + { + uint statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + bool retry; + do + { + retry = false; + uint bytesTransferred = 0; + statusCode = UnsafeNclNativeMethods.HttpApi.HttpReceiveHttpRequest( + Server.RequestQueueHandle, + _nativeRequestContext.RequestBlob->RequestId, + (uint)UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY, + _nativeRequestContext.RequestBlob, + _nativeRequestContext.Size, + &bytesTransferred, + _nativeRequestContext.NativeOverlapped); + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_INVALID_PARAMETER && _nativeRequestContext.RequestBlob->RequestId != 0) + { + // we might get this if somebody stole our RequestId, + // set RequestId to 0 and start all over again with the buffer we just allocated + // BUGBUG: how can someone steal our request ID? seems really bad and in need of fix. + _nativeRequestContext.RequestBlob->RequestId = 0; + retry = true; + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + // the buffer was not big enough to fit the headers, we need + // to read the RequestId returned, allocate a new buffer of the required size + _nativeRequestContext.Reset(_nativeRequestContext.RequestBlob->RequestId, bytesTransferred); + retry = true; + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS + && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + IOCompleted(this, statusCode, bytesTransferred); + } + } + while (retry); + return statusCode; + } + + public object AsyncState + { + get { return _tcs.Task.AsyncState; } + } + + public WaitHandle AsyncWaitHandle + { + get { return ((IAsyncResult)_tcs.Task).AsyncWaitHandle; } + } + + public bool CompletedSynchronously + { + get { return ((IAsyncResult)_tcs.Task).CompletedSynchronously; } + } + + public bool IsCompleted + { + get { return _tcs.Task.IsCompleted; } + } + + public void Dispose() + { + Dispose(true); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + if (_nativeRequestContext != null) + { + _nativeRequestContext.ReleasePins(); + _nativeRequestContext.Dispose(); + } + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/AuthenticationManager.cs b/src/Microsoft.AspNet.Server.WebListener/AuthenticationManager.cs new file mode 100644 index 0000000000..9fa737049a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/AuthenticationManager.cs @@ -0,0 +1,129 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ + // See the native HTTP_SERVER_AUTHENTICATION_INFO structure documentation for additional information. + // http://msdn.microsoft.com/en-us/library/windows/desktop/aa364638(v=vs.85).aspx + + /// + /// Exposes the Http.Sys authentication configurations. + /// + public sealed class AuthenticationManager + { +#if NET45 + private static readonly int AuthInfoSize = + Marshal.SizeOf(typeof(UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_AUTHENTICATION_INFO)); +#else + private static readonly int AuthInfoSize = + Marshal.SizeOf(); +#endif + + private OwinWebListener _server; + AuthenticationType _authTypes; + + internal AuthenticationManager(OwinWebListener context) + { + _server = context; + _authTypes = AuthenticationType.None; + } + + #region Properties + + public AuthenticationType AuthenticationTypes + { + get + { + return _authTypes; + } + set + { + _authTypes = value; + SetServerSecurity(); + } + } + + #endregion Properties + + private unsafe void SetServerSecurity() + { + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_AUTHENTICATION_INFO authInfo = + new UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_AUTHENTICATION_INFO(); + + authInfo.Flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_PROPERTY_FLAG_PRESENT; + authInfo.AuthSchemes = (UnsafeNclNativeMethods.HttpApi.HTTP_AUTH_TYPES)_authTypes; + + // TODO: + // NTLM auth sharing (on by default?) DisableNTLMCredentialCaching + // Kerberos auth sharing (off by default?) HTTP_AUTH_EX_FLAG_ENABLE_KERBEROS_CREDENTIAL_CACHING + // Mutual Auth - ReceiveMutualAuth + // Digest domain and realm - HTTP_SERVER_AUTHENTICATION_DIGEST_PARAMS + // Basic realm - HTTP_SERVER_AUTHENTICATION_BASIC_PARAMS + + IntPtr infoptr = new IntPtr(&authInfo); + + _server.SetUrlGroupProperty( + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerAuthenticationProperty, + infoptr, (uint)AuthInfoSize); + } + + internal void SetAuthenticationChallenge(Response response) + { + if (_authTypes == AuthenticationType.None) + { + return; + } + + IList challenges = new List(); + + // Order by strength. + if ((_authTypes & AuthenticationType.Kerberos) == AuthenticationType.Kerberos) + { + challenges.Add("Kerberos"); + } + if ((_authTypes & AuthenticationType.Negotiate) == AuthenticationType.Negotiate) + { + challenges.Add("Negotiate"); + } + if ((_authTypes & AuthenticationType.Ntlm) == AuthenticationType.Ntlm) + { + challenges.Add("NTLM"); + } + if ((_authTypes & AuthenticationType.Digest) == AuthenticationType.Digest) + { + // TODO: + throw new NotImplementedException("Digest challenge generation has not been implemented."); + // challenges.Add("Digest"); + } + if ((_authTypes & AuthenticationType.Basic) == AuthenticationType.Basic) + { + // TODO: Realm + challenges.Add("Basic"); + } + + // Append to the existing header, if any. Some clients (IE, Chrome) require each challenges to be sent on their own line/header. + string[] oldValues; + string[] newValues; + if (response.Headers.TryGetValue(HttpKnownHeaderNames.WWWAuthenticate, out oldValues)) + { + newValues = new string[oldValues.Length + challenges.Count]; + Array.Copy(oldValues, newValues, oldValues.Length); + challenges.CopyTo(newValues, oldValues.Length); + } + else + { + newValues = new string[challenges.Count]; + challenges.CopyTo(newValues, 0); + } + response.Headers[HttpKnownHeaderNames.WWWAuthenticate] = newValues; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/AuthenticationTypes.cs b/src/Microsoft.AspNet.Server.WebListener/AuthenticationTypes.cs new file mode 100644 index 0000000000..2e47212706 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/AuthenticationTypes.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + [Flags] + public enum AuthenticationType + { + None = 0x0, + Basic = 0x1, + Digest = 0x2, + Ntlm = 0x4, + Negotiate = 0x8, + Kerberos = 0x10, + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Constants.cs b/src/Microsoft.AspNet.Server.WebListener/Constants.cs new file mode 100644 index 0000000000..53e3325471 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Constants.cs @@ -0,0 +1,63 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class Constants + { + internal const string VersionKey = "owin.Version"; + internal const string OwinVersion = "1.0"; + internal const string CallCancelledKey = "owin.CallCancelled"; + + internal const string ServerCapabilitiesKey = "server.Capabilities"; + + internal const string RequestBodyKey = "owin.RequestBody"; + internal const string RequestHeadersKey = "owin.RequestHeaders"; + internal const string RequestSchemeKey = "owin.RequestScheme"; + internal const string RequestMethodKey = "owin.RequestMethod"; + internal const string RequestPathBaseKey = "owin.RequestPathBase"; + internal const string RequestPathKey = "owin.RequestPath"; + internal const string RequestQueryStringKey = "owin.RequestQueryString"; + internal const string HttpRequestProtocolKey = "owin.RequestProtocol"; + + internal const string HttpResponseProtocolKey = "owin.ResponseProtocol"; + internal const string ResponseStatusCodeKey = "owin.ResponseStatusCode"; + internal const string ResponseReasonPhraseKey = "owin.ResponseReasonPhrase"; + internal const string ResponseHeadersKey = "owin.ResponseHeaders"; + internal const string ResponseBodyKey = "owin.ResponseBody"; + + internal const string ClientCertifiateKey = "ssl.ClientCertificate"; + + internal const string RemoteIpAddressKey = "server.RemoteIpAddress"; + internal const string RemotePortKey = "server.RemotePort"; + internal const string LocalIpAddressKey = "server.LocalIpAddress"; + internal const string LocalPortKey = "server.LocalPort"; + internal const string IsLocalKey = "server.IsLocal"; + internal const string ServerOnSendingHeadersKey = "server.OnSendingHeaders"; + internal const string ServerLoggerFactoryKey = "server.LoggerFactory"; + + internal const string OpaqueVersionKey = "opaque.Version"; + internal const string OpaqueVersion = "1.0"; + internal const string OpaqueFuncKey = "opaque.Upgrade"; + internal const string OpaqueStreamKey = "opaque.Stream"; + internal const string OpaqueCallCancelledKey = "opaque.CallCancelled"; + + internal const string SendFileVersionKey = "sendfile.Version"; + internal const string SendFileVersion = "1.0"; + internal const string SendFileSupportKey = "sendfile.Support"; + internal const string SendFileConcurrencyKey = "sendfile.Concurrency"; + internal const string Overlapped = "Overlapped"; + + internal const string HttpScheme = "http"; + internal const string HttpsScheme = "https"; + internal const string SchemeDelimiter = "://"; + + internal static Version V1_0 = new Version(1, 0); + internal static Version V1_1 = new Version(1, 1); + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/CustomDictionary.xml b/src/Microsoft.AspNet.Server.WebListener/CustomDictionary.xml new file mode 100644 index 0000000000..78a76142f7 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/CustomDictionary.xml @@ -0,0 +1,10 @@ + + + + + Owin + + + + + diff --git a/src/Microsoft.AspNet.Server.WebListener/DictionaryExtensions.cs b/src/Microsoft.AspNet.Server.WebListener/DictionaryExtensions.cs new file mode 100644 index 0000000000..a1a6e3577a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/DictionaryExtensions.cs @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Linq; +using System.Text; + +namespace System.Collections.Generic +{ + internal static class DictionaryExtensions + { + internal static void Append(this IDictionary dictionary, string key, string value) + { + string[] orriginalValues; + if (dictionary.TryGetValue(key, out orriginalValues)) + { + string[] newValues = new string[orriginalValues.Length + 1]; + orriginalValues.CopyTo(newValues, 0); + newValues[newValues.Length - 1] = value; + dictionary[key] = newValues; + } + else + { + dictionary[key] = new string[] { value }; + } + } + + internal static string Get(this IDictionary dictionary, string key) + { + string[] values; + if (dictionary.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + internal static T Get(this IDictionary dictionary, string key, T fallback = default(T)) + { + object values; + if (dictionary.TryGetValue(key, out values)) + { + return (T)values; + } + return fallback; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/GlobalSuppressions.cs b/src/Microsoft.AspNet.Server.WebListener/GlobalSuppressions.cs new file mode 100644 index 0000000000000000000000000000000000000000..2b0ddbdbf4a6b419d14999f431a6e44ef4126bf5 GIT binary patch literal 1926 zcmd6o-HX#u5XI+N@PD}EX$8BhKI*Pe6u)3a=&~TP4{2K4Xw#HUD)gUMe`jvPO`Gof z63EBgxie?ZnYsD*_rfkL*yvGRPwfe)gzM4@t8HUDySLI7wzL-OPu^EHw=2u+BX6}$ zE$2D0ExsFk=hkqQtgLyS6Q#j7c(e+S9Q)4qU*lD>7 z(l6}4w_f#(l_{!3=4LPuY>ZA)c85pY+rnqC3Tw$)yMKm(_SgF!$mtqIN;srN>&#E- zm)89P&!wvrxT}BkR4=W_BQ4~i%<_v!)MvH*Vzt29*XrBB%7_?qS5%5O5AW*SGZJ*CO|!vkS^FM-!-;=oyxYsUa+p2DOeX{DURkqwKeQ|a$7yA_v z6eq`LZqGf&&g1G`s};HlwZ4U~=&_OT+;|;1`IJ%7PjfM3x|?^!PZKjDkJKsbH+JbF zwj$bw_)&LtNm@;4G@TFiSUf+rL~p^UteWE3Zyma9&~Z8i;<*!ZXTGyzxQ{iV$n3M@ zJAHBm`#@H8Pu62J*muxuK{azm)gS7KHGYTsMW#=)tT{aH=#J+&Uu$w`_B4I5yM +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class Helpers + { + internal static Task CompletedTask() + { + return Task.FromResult(null); + } + + internal static ConfiguredTaskAwaitable SupressContext(this Task task) + { + return task.ConfigureAwait(continueOnCapturedContext: false); + } + + internal static ConfiguredTaskAwaitable SupressContext(this Task task) + { + return task.ConfigureAwait(continueOnCapturedContext: false); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/LogHelper.cs b/src/Microsoft.AspNet.Server.WebListener/LogHelper.cs new file mode 100644 index 0000000000..a6513df571 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/LogHelper.cs @@ -0,0 +1,82 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Globalization; + +namespace Microsoft.AspNet.Server.WebListener +{ + using LoggerFactoryFunc = Func, bool>>; + using LoggerFunc = Func, bool>; + + internal static class LogHelper + { + private static readonly Func LogState = + (state, error) => Convert.ToString(state, CultureInfo.CurrentCulture); + + private static readonly Func LogStateAndError = + (state, error) => string.Format(CultureInfo.CurrentCulture, "{0}\r\n{1}", state, error); + + internal static LoggerFunc CreateLogger(LoggerFactoryFunc factory, Type type) + { + if (factory == null) + { + return null; + } + + return factory(type.FullName); + } + + internal static void LogInfo(LoggerFunc logger, string data) + { + if (logger == null) + { + Debug.WriteLine(data); + } + else + { + logger(TraceEventType.Information, 0, data, null, LogState); + } + } + + internal static void LogVerbose(LoggerFunc logger, string data) + { + if (logger == null) + { + Debug.WriteLine(data); + } + else + { + logger(TraceEventType.Verbose, 0, data, null, LogState); + } + } + + internal static void LogException(LoggerFunc logger, string location, Exception exception) + { + if (logger == null) + { + Debug.WriteLine(exception); + } + else + { + logger(TraceEventType.Error, 0, location, exception, LogStateAndError); + } + } + + internal static void LogError(LoggerFunc logger, string location, string message) + { + if (logger == null) + { + Debug.WriteLine(message); + } + else + { + logger(TraceEventType.Error, 0, location + ": " + message, null, LogState); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/AddressFamily.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/AddressFamily.cs new file mode 100644 index 0000000000..5f78516900 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/AddressFamily.cs @@ -0,0 +1,172 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + /// + /// + /// Specifies the address families that an instance of the + /// class can use. + /// + /// + internal enum AddressFamily + { + /// + /// [To be supplied.] + /// + Unknown = -1, // Unknown + + /// + /// [To be supplied.] + /// + Unspecified = 0, // unspecified + + /// + /// [To be supplied.] + /// + Unix = 1, // local to host (pipes, portals) + + /// + /// [To be supplied.] + /// + InterNetwork = 2, // internetwork: UDP, TCP, etc. + + /// + /// [To be supplied.] + /// + ImpLink = 3, // arpanet imp addresses + + /// + /// [To be supplied.] + /// + Pup = 4, // pup protocols: e.g. BSP + + /// + /// [To be supplied.] + /// + Chaos = 5, // mit CHAOS protocols + + /// + /// [To be supplied.] + /// + NS = 6, // XEROX NS protocols + + /// + /// [To be supplied.] + /// + Ipx = NS, // IPX and SPX + + /// + /// [To be supplied.] + /// + Iso = 7, // ISO protocols + + /// + /// [To be supplied.] + /// + Osi = Iso, // OSI is ISO + + /// + /// [To be supplied.] + /// + Ecma = 8, // european computer manufacturers + + /// + /// [To be supplied.] + /// + DataKit = 9, // datakit protocols + + /// + /// [To be supplied.] + /// + Ccitt = 10, // CCITT protocols, X.25 etc + + /// + /// [To be supplied.] + /// + Sna = 11, // IBM SNA + + /// + /// [To be supplied.] + /// + DecNet = 12, // DECnet + + /// + /// [To be supplied.] + /// + DataLink = 13, // Direct data link interface + + /// + /// [To be supplied.] + /// + Lat = 14, // LAT + + /// + /// [To be supplied.] + /// + HyperChannel = 15, // NSC Hyperchannel + + /// + /// [To be supplied.] + /// + AppleTalk = 16, // AppleTalk + + /// + /// [To be supplied.] + /// + NetBios = 17, // NetBios-style addresses + + /// + /// [To be supplied.] + /// + VoiceView = 18, // VoiceView + + /// + /// [To be supplied.] + /// + FireFox = 19, // FireFox + + /// + /// [To be supplied.] + /// + Banyan = 21, // Banyan + + /// + /// [To be supplied.] + /// + Atm = 22, // Native ATM Services + + /// + /// [To be supplied.] + /// + InterNetworkV6 = 23, // Internetwork Version 6 + + /// + /// [To be supplied.] + /// + Cluster = 24, // Microsoft Wolfpack + + /// + /// [To be supplied.] + /// + Ieee12844 = 25, // IEEE 1284.4 WG AF + + /// + /// [To be supplied.] + /// + Irda = 26, // IrDA + + /// + /// [To be supplied.] + /// + NetworkDesigners = 28, // Network Designers OSI & gateway enabled protocols + + /// + /// [To be supplied.] + /// + Max = 29, // Max + }; // enum AddressFamily +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ComNetOS.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ComNetOS.cs new file mode 100644 index 0000000000..80e2719065 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ComNetOS.cs @@ -0,0 +1,26 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class ComNetOS + { + // Minimum support for Windows 7 is assumed. + internal static readonly bool IsWin8orLater; + + static ComNetOS() + { +#if NET45 + var win8Version = new Version(6, 2); + IsWin8orLater = (Environment.OSVersion.Version >= win8Version); +#else + IsWin8orLater = true; +#endif + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ContextAttribute.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ContextAttribute.cs new file mode 100644 index 0000000000..a3440bc7a3 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/ContextAttribute.cs @@ -0,0 +1,56 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum ContextAttribute + { + // look into and + Sizes = 0x00, + Names = 0x01, + Lifespan = 0x02, + DceInfo = 0x03, + StreamSizes = 0x04, + // KeyInfo = 0x05, must not be used, see ConnectionInfo instead + Authority = 0x06, + // SECPKG_ATTR_PROTO_INFO = 7, + // SECPKG_ATTR_PASSWORD_EXPIRY = 8, + // SECPKG_ATTR_SESSION_KEY = 9, + PackageInfo = 0x0A, + // SECPKG_ATTR_USER_FLAGS = 11, + NegotiationInfo = 0x0C, + // SECPKG_ATTR_NATIVE_NAMES = 13, + // SECPKG_ATTR_FLAGS = 14, + // SECPKG_ATTR_USE_VALIDATED = 15, + // SECPKG_ATTR_CREDENTIAL_NAME = 16, + // SECPKG_ATTR_TARGET_INFORMATION = 17, + // SECPKG_ATTR_ACCESS_TOKEN = 18, + // SECPKG_ATTR_TARGET = 19, + // SECPKG_ATTR_AUTHENTICATION_ID = 20, + UniqueBindings = 0x19, + EndpointBindings = 0x1A, + ClientSpecifiedSpn = 0x1B, // SECPKG_ATTR_CLIENT_SPECIFIED_TARGET = 27 + RemoteCertificate = 0x53, + LocalCertificate = 0x54, + RootStore = 0x55, + IssuerListInfoEx = 0x59, + ConnectionInfo = 0x5A, + // SECPKG_ATTR_EAP_KEY_BLOCK 0x5b // returns SecPkgContext_EapKeyBlock + // SECPKG_ATTR_MAPPED_CRED_ATTR 0x5c // returns SecPkgContext_MappedCredAttr + // SECPKG_ATTR_SESSION_INFO 0x5d // returns SecPkgContext_SessionInfo + // SECPKG_ATTR_APP_DATA 0x5e // sets/returns SecPkgContext_SessionAppData + // SECPKG_ATTR_REMOTE_CERTIFICATES 0x5F // returns SecPkgContext_Certificates + // SECPKG_ATTR_CLIENT_CERT_POLICY 0x60 // sets SecPkgCred_ClientCertCtlPolicy + // SECPKG_ATTR_CC_POLICY_RESULT 0x61 // returns SecPkgContext_ClientCertPolicyResult + // SECPKG_ATTR_USE_NCRYPT 0x62 // Sets the CRED_FLAG_USE_NCRYPT_PROVIDER FLAG on cred group + // SECPKG_ATTR_LOCAL_CERT_INFO 0x63 // returns SecPkgContext_CertInfo + // SECPKG_ATTR_CIPHER_INFO 0x64 // returns new CNG SecPkgContext_CipherInfo + // SECPKG_ATTR_EAP_PRF_INFO 0x65 // sets SecPkgContext_EapPrfInfo + // SECPKG_ATTR_SUPPORTED_SIGNATURES 0x66 // returns SecPkgContext_SupportedSignatures + // SECPKG_ATTR_REMOTE_CERT_CHAIN 0x67 // returns PCCERT_CONTEXT + UiInfo = 0x68, // sets SEcPkgContext_UiInfo + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpRequestQueueV2Handle.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpRequestQueueV2Handle.cs new file mode 100644 index 0000000000..ad9ae1d391 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpRequestQueueV2Handle.cs @@ -0,0 +1,25 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + // This class is a wrapper for Http.sys V2 request queue handle. + internal sealed class HttpRequestQueueV2Handle : SafeHandleZeroOrMinusOneIsInvalid + { + private HttpRequestQueueV2Handle() + : base(true) + { + } + + protected override bool ReleaseHandle() + { + return (UnsafeNclNativeMethods.SafeNetHandles.HttpCloseRequestQueue(handle) == + UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpServerSessionHandle.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpServerSessionHandle.cs new file mode 100644 index 0000000000..65c0506f5c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpServerSessionHandle.cs @@ -0,0 +1,48 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Threading; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed class HttpServerSessionHandle : CriticalHandleZeroOrMinusOneIsInvalid + { + private int disposed; + private ulong serverSessionId; + + internal HttpServerSessionHandle(ulong id) + : base() + { + serverSessionId = id; + + // This class uses no real handle so we need to set a dummy handle. Otherwise, IsInvalid always remains + // true. + + SetHandle(new IntPtr(1)); + } + + internal ulong DangerousGetServerSessionId() + { + return serverSessionId; + } + + protected override bool ReleaseHandle() + { + if (!IsInvalid) + { + if (Interlocked.Increment(ref disposed) == 1) + { + // Closing server session also closes all open url groups under that server session. + return (UnsafeNclNativeMethods.HttpApi.HttpCloseServerSession(serverSessionId) == + UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS); + } + } + return true; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysRequestHeader.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysRequestHeader.cs new file mode 100644 index 0000000000..1202ab2c80 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysRequestHeader.cs @@ -0,0 +1,54 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum HttpSysRequestHeader + { + CacheControl = 0, // general-header [section 4.5] + Connection = 1, // general-header [section 4.5] + Date = 2, // general-header [section 4.5] + KeepAlive = 3, // general-header [not in rfc] + Pragma = 4, // general-header [section 4.5] + Trailer = 5, // general-header [section 4.5] + TransferEncoding = 6, // general-header [section 4.5] + Upgrade = 7, // general-header [section 4.5] + Via = 8, // general-header [section 4.5] + Warning = 9, // general-header [section 4.5] + Allow = 10, // entity-header [section 7.1] + ContentLength = 11, // entity-header [section 7.1] + ContentType = 12, // entity-header [section 7.1] + ContentEncoding = 13, // entity-header [section 7.1] + ContentLanguage = 14, // entity-header [section 7.1] + ContentLocation = 15, // entity-header [section 7.1] + ContentMd5 = 16, // entity-header [section 7.1] + ContentRange = 17, // entity-header [section 7.1] + Expires = 18, // entity-header [section 7.1] + LastModified = 19, // entity-header [section 7.1] + + Accept = 20, // request-header [section 5.3] + AcceptCharset = 21, // request-header [section 5.3] + AcceptEncoding = 22, // request-header [section 5.3] + AcceptLanguage = 23, // request-header [section 5.3] + Authorization = 24, // request-header [section 5.3] + Cookie = 25, // request-header [not in rfc] + Expect = 26, // request-header [section 5.3] + From = 27, // request-header [section 5.3] + Host = 28, // request-header [section 5.3] + IfMatch = 29, // request-header [section 5.3] + IfModifiedSince = 30, // request-header [section 5.3] + IfNoneMatch = 31, // request-header [section 5.3] + IfRange = 32, // request-header [section 5.3] + IfUnmodifiedSince = 33, // request-header [section 5.3] + MaxForwards = 34, // request-header [section 5.3] + ProxyAuthorization = 35, // request-header [section 5.3] + Referer = 36, // request-header [section 5.3] + Range = 37, // request-header [section 5.3] + Te = 38, // request-header [section 5.3] + Translate = 39, // request-header [webDAV, not in rfc 2518] + UserAgent = 40, // request-header [section 5.3] + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysResponseHeader.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysResponseHeader.cs new file mode 100644 index 0000000000..0c7707e943 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysResponseHeader.cs @@ -0,0 +1,43 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum HttpSysResponseHeader + { + CacheControl = 0, // general-header [section 4.5] + Connection = 1, // general-header [section 4.5] + Date = 2, // general-header [section 4.5] + KeepAlive = 3, // general-header [not in rfc] + Pragma = 4, // general-header [section 4.5] + Trailer = 5, // general-header [section 4.5] + TransferEncoding = 6, // general-header [section 4.5] + Upgrade = 7, // general-header [section 4.5] + Via = 8, // general-header [section 4.5] + Warning = 9, // general-header [section 4.5] + Allow = 10, // entity-header [section 7.1] + ContentLength = 11, // entity-header [section 7.1] + ContentType = 12, // entity-header [section 7.1] + ContentEncoding = 13, // entity-header [section 7.1] + ContentLanguage = 14, // entity-header [section 7.1] + ContentLocation = 15, // entity-header [section 7.1] + ContentMd5 = 16, // entity-header [section 7.1] + ContentRange = 17, // entity-header [section 7.1] + Expires = 18, // entity-header [section 7.1] + LastModified = 19, // entity-header [section 7.1] + + AcceptRanges = 20, // response-header [section 6.2] + Age = 21, // response-header [section 6.2] + ETag = 22, // response-header [section 6.2] + Location = 23, // response-header [section 6.2] + ProxyAuthenticate = 24, // response-header [section 6.2] + RetryAfter = 25, // response-header [section 6.2] + Server = 26, // response-header [section 6.2] + SetCookie = 27, // response-header [not in rfc] + Vary = 28, // response-header [section 6.2] + WwwAuthenticate = 29, // response-header [section 6.2] + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysSettings.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysSettings.cs new file mode 100644 index 0000000000..98067e6fd6 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/HttpSysSettings.cs @@ -0,0 +1,125 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Security; +#if NET45 +using Microsoft.Win32; +#endif + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class HttpSysSettings + { +#if NET45 + private const string HttpSysParametersKey = @"System\CurrentControlSet\Services\HTTP\Parameters"; +#endif + private const bool EnableNonUtf8Default = true; + private const bool FavorUtf8Default = true; + private const string EnableNonUtf8Name = "EnableNonUtf8"; + private const string FavorUtf8Name = "FavorUtf8"; + + private static volatile bool enableNonUtf8 = EnableNonUtf8Default; + private static volatile bool favorUtf8 = FavorUtf8Default; + + static HttpSysSettings() + { + ReadHttpSysRegistrySettings(); + } + + internal static bool EnableNonUtf8 + { + get { return enableNonUtf8; } + } + + internal static bool FavorUtf8 + { + get { return favorUtf8; } + } + + private static void ReadHttpSysRegistrySettings() +#if !NET45 + { + } +#else + { + try + { + RegistryKey httpSysParameters = Registry.LocalMachine.OpenSubKey(HttpSysParametersKey); + + if (httpSysParameters == null) + { + LogWarning("ReadHttpSysRegistrySettings", "The Http.Sys registry key is null.", + HttpSysParametersKey); + } + else + { + using (httpSysParameters) + { + enableNonUtf8 = ReadRegistryValue(httpSysParameters, EnableNonUtf8Name, EnableNonUtf8Default); + favorUtf8 = ReadRegistryValue(httpSysParameters, FavorUtf8Name, FavorUtf8Default); + } + } + } + catch (SecurityException e) + { + LogRegistryException("ReadHttpSysRegistrySettings", e); + } + catch (ObjectDisposedException e) + { + LogRegistryException("ReadHttpSysRegistrySettings", e); + } + } + + private static bool ReadRegistryValue(RegistryKey key, string valueName, bool defaultValue) + { + Debug.Assert(key != null, "'key' must not be null"); + + try + { + if (key.GetValue(valueName) != null && key.GetValueKind(valueName) == RegistryValueKind.DWord) + { + // At this point we know the Registry value exists and it must be valid (any DWORD value + // can be converted to a bool). + return Convert.ToBoolean(key.GetValue(valueName), CultureInfo.InvariantCulture); + } + } + catch (UnauthorizedAccessException e) + { + LogRegistryException("ReadRegistryValue", e); + } + catch (IOException e) + { + LogRegistryException("ReadRegistryValue", e); + } + catch (SecurityException e) + { + LogRegistryException("ReadRegistryValue", e); + } + catch (ObjectDisposedException e) + { + LogRegistryException("ReadRegistryValue", e); + } + + return defaultValue; + } + + private static void LogRegistryException(string methodName, Exception e) + { + LogWarning(methodName, "Unable to access the Http.Sys registry value.", HttpSysParametersKey, e); + } + + private static void LogWarning(string methodName, string message, params object[] args) + { + // TODO: log + // Logging.PrintWarning(Logging.HttpListener, typeof(HttpSysSettings), methodName, SR.GetString(message, args)); + } +#endif + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/IntPtrHelper.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/IntPtrHelper.cs new file mode 100644 index 0000000000..87ee72ce34 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/IntPtrHelper.cs @@ -0,0 +1,23 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class IntPtrHelper + { + internal static IntPtr Add(IntPtr a, int b) + { + return (IntPtr)((long)a + (long)b); + } + + internal static long Subtract(IntPtr a, IntPtr b) + { + return ((long)a - (long)b); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/NclUtilities.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/NclUtilities.cs new file mode 100644 index 0000000000..4e74a95560 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/NclUtilities.cs @@ -0,0 +1,25 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class NclUtilities + { + internal static bool HasShutdownStarted + { + get + { + return Environment.HasShutdownStarted +#if NET45 + || AppDomain.CurrentDomain.IsFinalizingForUnload() +#endif + ; + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SSPIHandle.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SSPIHandle.cs new file mode 100644 index 0000000000..6089860776 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SSPIHandle.cs @@ -0,0 +1,34 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ + [StructLayout(LayoutKind.Sequential, Pack = 1)] + internal struct SSPIHandle + { + private IntPtr handleHi; + private IntPtr handleLo; + + public bool IsZero + { + get { return handleHi == IntPtr.Zero && handleLo == IntPtr.Zero; } + } + + internal void SetToInvalid() + { + handleHi = IntPtr.Zero; + handleLo = IntPtr.Zero; + } + + public override string ToString() + { + return handleHi.ToString("x") + ":" + handleLo.ToString("x"); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLoadLibrary.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLoadLibrary.cs new file mode 100644 index 0000000000..998fbe014a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLoadLibrary.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed class SafeLoadLibrary : SafeHandleZeroOrMinusOneIsInvalid + { + private const string KERNEL32 = "kernel32.dll"; + + public static readonly SafeLoadLibrary Zero = new SafeLoadLibrary(false); + + private SafeLoadLibrary() + : base(true) + { + } + + private SafeLoadLibrary(bool ownsHandle) + : base(ownsHandle) + { + } + + public static unsafe SafeLoadLibrary LoadLibraryEx(string library) + { + SafeLoadLibrary result = UnsafeNclNativeMethods.SafeNetHandles.LoadLibraryExW(library, null, 0); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + } + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.FreeLibrary(handle); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFree.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFree.cs new file mode 100644 index 0000000000..ad35b83ff0 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFree.cs @@ -0,0 +1,46 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed class SafeLocalFree : SafeHandleZeroOrMinusOneIsInvalid + { + private const int LMEM_FIXED = 0; + private const int NULL = 0; + + // This returned handle cannot be modified by the application. + public static SafeLocalFree Zero = new SafeLocalFree(false); + + private SafeLocalFree() + : base(true) + { + } + + private SafeLocalFree(bool ownsHandle) + : base(ownsHandle) + { + } + + public static SafeLocalFree LocalAlloc(int cb) + { + SafeLocalFree result = UnsafeNclNativeMethods.SafeNetHandles.LocalAlloc(LMEM_FIXED, (UIntPtr)cb); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + throw new OutOfMemoryException(); + } + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.LocalFree(handle) == IntPtr.Zero; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFreeChannelBinding.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFreeChannelBinding.cs new file mode 100644 index 0000000000..66124161ca --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalFreeChannelBinding.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Security.Authentication.ExtendedProtection; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class SafeLocalFreeChannelBinding : ChannelBinding + { + private const int LMEM_FIXED = 0; + private int size; + + public override int Size + { + get { return size; } + } + + public static SafeLocalFreeChannelBinding LocalAlloc(int cb) + { + SafeLocalFreeChannelBinding result; + + result = UnsafeNclNativeMethods.SafeNetHandles.LocalAllocChannelBinding(LMEM_FIXED, (UIntPtr)cb); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + throw new OutOfMemoryException(); + } + + result.size = cb; + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.LocalFree(handle) == IntPtr.Zero; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalMemHandle.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalMemHandle.cs new file mode 100644 index 0000000000..6bc76f512f --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeLocalMemHandle.cs @@ -0,0 +1,30 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed class SafeLocalMemHandle : SafeHandleZeroOrMinusOneIsInvalid + { + internal SafeLocalMemHandle() + : base(true) + { + } + + internal SafeLocalMemHandle(IntPtr existingHandle, bool ownsHandle) + : base(ownsHandle) + { + SetHandle(existingHandle); + } + + protected override bool ReleaseHandle() + { + return UnsafeNclNativeMethods.SafeNetHandles.LocalFree(handle) == IntPtr.Zero; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeNativeOverlapped.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeNativeOverlapped.cs new file mode 100644 index 0000000000..c36466ec84 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SafeNativeOverlapped.cs @@ -0,0 +1,68 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class SafeNativeOverlapped : SafeHandle + { + internal static readonly SafeNativeOverlapped Zero = new SafeNativeOverlapped(); + + internal SafeNativeOverlapped() + : this(IntPtr.Zero) + { + } + + internal unsafe SafeNativeOverlapped(NativeOverlapped* handle) + : this((IntPtr)handle) + { + } + + internal SafeNativeOverlapped(IntPtr handle) + : base(IntPtr.Zero, true) + { + SetHandle(handle); + } + + public override bool IsInvalid + { + get { return handle == IntPtr.Zero; } + } + + public void ReinitializeNativeOverlapped() + { + IntPtr handleSnapshot = handle; + + if (handleSnapshot != IntPtr.Zero) + { + unsafe + { + ((NativeOverlapped*)handleSnapshot)->InternalHigh = IntPtr.Zero; + ((NativeOverlapped*)handleSnapshot)->InternalLow = IntPtr.Zero; + ((NativeOverlapped*)handleSnapshot)->EventHandle = IntPtr.Zero; + } + } + } + + protected override bool ReleaseHandle() + { + IntPtr oldHandle = Interlocked.Exchange(ref handle, IntPtr.Zero); + // Do not call free durring AppDomain shutdown, there may be an outstanding operation. + // Overlapped will take care calling free when the native callback completes. + if (oldHandle != IntPtr.Zero && !NclUtilities.HasShutdownStarted) + { + unsafe + { + Overlapped.Free((NativeOverlapped*)oldHandle); + } + } + return true; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SchProtocols.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SchProtocols.cs new file mode 100644 index 0000000000..be1d240cd2 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SchProtocols.cs @@ -0,0 +1,51 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + using System; + using System.Globalization; + using System.Runtime.InteropServices; + + // From Schannel.h + [Flags] + internal enum SchProtocols + { + Zero = 0, + PctClient = 0x00000002, + PctServer = 0x00000001, + Pct = (PctClient | PctServer), + Ssl2Client = 0x00000008, + Ssl2Server = 0x00000004, + Ssl2 = (Ssl2Client | Ssl2Server), + Ssl3Client = 0x00000020, + Ssl3Server = 0x00000010, + Ssl3 = (Ssl3Client | Ssl3Server), + Tls10Client = 0x00000080, + Tls10Server = 0x00000040, + Tls10 = (Tls10Client | Tls10Server), + Tls11Client = 0x00000200, + Tls11Server = 0x00000100, + Tls11 = (Tls11Client | Tls11Server), + Tls12Client = 0x00000800, + Tls12Server = 0x00000400, + Tls12 = (Tls12Client | Tls12Server), + Ssl3Tls = (Ssl3 | Tls10), + UniClient = unchecked((int)0x80000000), + UniServer = 0x40000000, + Unified = (UniClient | UniServer), + ClientMask = (PctClient | Ssl2Client | Ssl3Client | Tls10Client | Tls11Client | Tls12Client | UniClient), + ServerMask = (PctServer | Ssl2Server | Ssl3Server | Tls10Server | Tls11Server | Tls12Server | UniServer) + } + + [StructLayout(LayoutKind.Sequential)] + internal struct Bindings + { + // see SecPkgContext_Bindings in + internal int BindingsLength; + internal IntPtr bindings; + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SecurityStatus.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SecurityStatus.cs new file mode 100644 index 0000000000..294fa1d863 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SecurityStatus.cs @@ -0,0 +1,54 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum SecurityStatus + { + // Success / Informational + OK = 0x00000000, + ContinueNeeded = unchecked((int)0x00090312), + CompleteNeeded = unchecked((int)0x00090313), + CompAndContinue = unchecked((int)0x00090314), + ContextExpired = unchecked((int)0x00090317), + CredentialsNeeded = unchecked((int)0x00090320), + Renegotiate = unchecked((int)0x00090321), + + // Errors + OutOfMemory = unchecked((int)0x80090300), + InvalidHandle = unchecked((int)0x80090301), + Unsupported = unchecked((int)0x80090302), + TargetUnknown = unchecked((int)0x80090303), + InternalError = unchecked((int)0x80090304), + PackageNotFound = unchecked((int)0x80090305), + NotOwner = unchecked((int)0x80090306), + CannotInstall = unchecked((int)0x80090307), + InvalidToken = unchecked((int)0x80090308), + CannotPack = unchecked((int)0x80090309), + QopNotSupported = unchecked((int)0x8009030A), + NoImpersonation = unchecked((int)0x8009030B), + LogonDenied = unchecked((int)0x8009030C), + UnknownCredentials = unchecked((int)0x8009030D), + NoCredentials = unchecked((int)0x8009030E), + MessageAltered = unchecked((int)0x8009030F), + OutOfSequence = unchecked((int)0x80090310), + NoAuthenticatingAuthority = unchecked((int)0x80090311), + IncompleteMessage = unchecked((int)0x80090318), + IncompleteCredentials = unchecked((int)0x80090320), + BufferNotEnough = unchecked((int)0x80090321), + WrongPrincipal = unchecked((int)0x80090322), + TimeSkew = unchecked((int)0x80090324), + UntrustedRoot = unchecked((int)0x80090325), + IllegalMessage = unchecked((int)0x80090326), + CertUnknown = unchecked((int)0x80090327), + CertExpired = unchecked((int)0x80090328), + AlgorithmMismatch = unchecked((int)0x80090331), + SecurityQosFailed = unchecked((int)0x80090332), + SmartcardLogonRequired = unchecked((int)0x8009033E), + UnsupportedPreauth = unchecked((int)0x80090343), + BadBinding = unchecked((int)0x80090346) + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SocketAddress.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SocketAddress.cs new file mode 100644 index 0000000000..9dfdb69ac3 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/SocketAddress.cs @@ -0,0 +1,342 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.Text; + +namespace Microsoft.AspNet.Server.WebListener +{ + // a little perf app measured these times when comparing the internal + // buffer implemented as a managed byte[] or unmanaged memory IntPtr + // that's why we use byte[] + // byte[] total ms:19656 + // IntPtr total ms:25671 + + /// + /// + /// This class is used when subclassing EndPoint, and provides indication + /// on how to format the memory buffers that winsock uses for network addresses. + /// + /// + internal class SocketAddress + { + private const int NumberOfIPv6Labels = 8; + // Lower case hex, no leading zeros + private const string IPv6NumberFormat = "{0:x}"; + private const string IPv6StringSeparator = ":"; + private const string IPv4StringFormat = "{0:d}.{1:d}.{2:d}.{3:d}"; + + internal const int IPv6AddressSize = 28; + internal const int IPv4AddressSize = 16; + + private const int WriteableOffset = 2; + + private int _size; + private byte[] _buffer; + private int _hash; + + /// + /// [To be supplied.] + /// + public SocketAddress(AddressFamily family, int size) + { + if (size < WriteableOffset) + { + // it doesn't make sense to create a socket address with less tha + // 2 bytes, that's where we store the address family. + + throw new ArgumentOutOfRangeException("size"); + } + _size = size; + _buffer = new byte[((size / IntPtr.Size) + 2) * IntPtr.Size]; // sizeof DWORD + +#if BIGENDIAN + m_Buffer[0] = unchecked((byte)((int)family>>8)); + m_Buffer[1] = unchecked((byte)((int)family )); +#else + _buffer[0] = unchecked((byte)((int)family)); + _buffer[1] = unchecked((byte)((int)family >> 8)); +#endif + } + + internal byte[] Buffer + { + get { return _buffer; } + } + + internal AddressFamily Family + { + get + { + int family; +#if BIGENDIAN + family = ((int)m_Buffer[0]<<8) | m_Buffer[1]; +#else + family = _buffer[0] | ((int)_buffer[1] << 8); +#endif + return (AddressFamily)family; + } + } + + internal int Size + { + get + { + return _size; + } + } + + // access to unmanaged serialized data. this doesn't + // allow access to the first 2 bytes of unmanaged memory + // that are supposed to contain the address family which + // is readonly. + // + // you can still use negative offsets as a back door in case + // winsock changes the way it uses SOCKADDR. maybe we want to prohibit it? + // maybe we should make the class sealed to avoid potentially dangerous calls + // into winsock with unproperly formatted data? + + /// + /// [To be supplied.] + /// + private byte this[int offset] + { + get + { + // access + if (offset < 0 || offset >= Size) + { + throw new ArgumentOutOfRangeException("offset"); + } + return _buffer[offset]; + } + } + + internal int GetPort() + { + return (int)((_buffer[2] << 8 & 0xFF00) | (_buffer[3])); + } + + public override bool Equals(object comparand) + { + SocketAddress castedComparand = comparand as SocketAddress; + if (castedComparand == null || this.Size != castedComparand.Size) + { + return false; + } + for (int i = 0; i < this.Size; i++) + { + if (this[i] != castedComparand[i]) + { + return false; + } + } + return true; + } + + public override int GetHashCode() + { + if (_hash == 0) + { + int i; + int size = Size & ~3; + + for (i = 0; i < size; i += 4) + { + _hash ^= (int)_buffer[i] + | ((int)_buffer[i + 1] << 8) + | ((int)_buffer[i + 2] << 16) + | ((int)_buffer[i + 3] << 24); + } + if ((Size & 3) != 0) + { + int remnant = 0; + int shift = 0; + + for (; i < Size; ++i) + { + remnant |= ((int)_buffer[i]) << shift; + shift += 8; + } + _hash ^= remnant; + } + } + return _hash; + } + + public override string ToString() + { + StringBuilder bytes = new StringBuilder(); + for (int i = WriteableOffset; i < this.Size; i++) + { + if (i > WriteableOffset) + { + bytes.Append(","); + } + bytes.Append(this[i].ToString(NumberFormatInfo.InvariantInfo)); + } + return Family.ToString() + ":" + Size.ToString(NumberFormatInfo.InvariantInfo) + ":{" + bytes.ToString() + "}"; + } + + internal string GetIPAddressString() + { + if (Family == AddressFamily.InterNetworkV6) + { + return GetIpv6AddressString(); + } + else if (Family == AddressFamily.InterNetwork) + { + return GetIPv4AddressString(); + } + else + { + return null; + } + } + + private string GetIPv4AddressString() + { + Contract.Assert(Size >= IPv4AddressSize); + + return string.Format(CultureInfo.InvariantCulture, IPv4StringFormat, + _buffer[4], _buffer[5], _buffer[6], _buffer[7]); + } + + // TODO: Does scope ID ever matter? + private unsafe string GetIpv6AddressString() + { + Contract.Assert(Size >= IPv6AddressSize); + + fixed (byte* rawBytes = _buffer) + { + // Convert from bytes to shorts. + ushort* rawShorts = stackalloc ushort[NumberOfIPv6Labels]; + int numbersOffset = 0; + // The address doesn't start at the beginning of the buffer. + for (int i = 8; i < ((NumberOfIPv6Labels * 2) + 8); i += 2) + { + rawShorts[numbersOffset++] = (ushort)(rawBytes[i] << 8 | rawBytes[i + 1]); + } + return GetIPv6AddressString(rawShorts); + } + } + + private static unsafe string GetIPv6AddressString(ushort* numbers) + { + // RFC 5952 Sections 4 & 5 - Compressed, lower case, with possible embedded IPv4 addresses. + + // Start to finish, inclusive. <-1, -1> for no compression + KeyValuePair range = FindCompressionRange(numbers); + bool ipv4Embedded = ShouldHaveIpv4Embedded(numbers); + + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < NumberOfIPv6Labels; i++) + { + if (ipv4Embedded && i == (NumberOfIPv6Labels - 2)) + { + // Write the remaining digits as an IPv4 address + builder.Append(IPv6StringSeparator); + builder.Append(string.Format(CultureInfo.InvariantCulture, IPv4StringFormat, + numbers[i] >> 8, numbers[i] & 0xFF, numbers[i + 1] >> 8, numbers[i + 1] & 0xFF)); + break; + } + + // Compression; 1::1, ::1, 1:: + if (range.Key == i) + { + // Start compression, add : + builder.Append(IPv6StringSeparator); + } + if (range.Key <= i && range.Value == (NumberOfIPv6Labels - 1)) + { + // Remainder compressed; 1:: + builder.Append(IPv6StringSeparator); + break; + } + if (range.Key <= i && i <= range.Value) + { + continue; // Compressed + } + + if (i != 0) + { + builder.Append(IPv6StringSeparator); + } + builder.Append(string.Format(CultureInfo.InvariantCulture, IPv6NumberFormat, numbers[i])); + } + + return builder.ToString(); + } + + // RFC 5952 Section 4.2.3 + // Longest consecutive sequence of zero segments, minimum 2. + // On equal, first sequence wins. + // <-1, -1> for no compression. + private static unsafe KeyValuePair FindCompressionRange(ushort* numbers) + { + int longestSequenceLength = 0; + int longestSequenceStart = -1; + + int currentSequenceLength = 0; + for (int i = 0; i < NumberOfIPv6Labels; i++) + { + if (numbers[i] == 0) + { + // In a sequence + currentSequenceLength++; + if (currentSequenceLength > longestSequenceLength) + { + longestSequenceLength = currentSequenceLength; + longestSequenceStart = i - currentSequenceLength + 1; + } + } + else + { + currentSequenceLength = 0; + } + } + + if (longestSequenceLength >= 2) + { + return new KeyValuePair(longestSequenceStart, + longestSequenceStart + longestSequenceLength - 1); + } + + return new KeyValuePair(-1, -1); // No compression + } + + // Returns true if the IPv6 address should be formated with an embedded IPv4 address: + // ::192.168.1.1 + private static unsafe bool ShouldHaveIpv4Embedded(ushort* numbers) + { + // 0:0 : 0:0 : x:x : x.x.x.x + if (numbers[0] == 0 && numbers[1] == 0 && numbers[2] == 0 && numbers[3] == 0 && numbers[6] != 0) + { + // RFC 5952 Section 5 - 0:0 : 0:0 : 0:[0 | FFFF] : x.x.x.x + if (numbers[4] == 0 && (numbers[5] == 0 || numbers[5] == 0xFFFF)) + { + return true; + + // SIIT - 0:0 : 0:0 : FFFF:0 : x.x.x.x + } + else if (numbers[4] == 0xFFFF && numbers[5] == 0) + { + return true; + } + } + // ISATAP + if (numbers[4] == 0 && numbers[5] == 0x5EFE) + { + return true; + } + + return false; + } + } // class SocketAddress +} // namespace System.Net diff --git a/src/Microsoft.AspNet.Server.WebListener/NativeInterop/UnsafeNativeMethods.cs b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/UnsafeNativeMethods.cs new file mode 100644 index 0000000000..0eb890cf8b --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/NativeInterop/UnsafeNativeMethods.cs @@ -0,0 +1,1129 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class UnsafeNclNativeMethods + { + private const string KERNEL32 = "kernel32.dll"; + private const string SECUR32 = "secur32.dll"; + private const string HTTPAPI = "httpapi.dll"; + + // CONSIDER: Make this an enum, requires changing a lot of types from uint to ErrorCodes. + internal static class ErrorCodes + { + internal const uint ERROR_SUCCESS = 0; + internal const uint ERROR_HANDLE_EOF = 38; + internal const uint ERROR_NOT_SUPPORTED = 50; + internal const uint ERROR_INVALID_PARAMETER = 87; + internal const uint ERROR_ALREADY_EXISTS = 183; + internal const uint ERROR_MORE_DATA = 234; + internal const uint ERROR_OPERATION_ABORTED = 995; + internal const uint ERROR_IO_PENDING = 997; + internal const uint ERROR_NOT_FOUND = 1168; + internal const uint ERROR_CONNECTION_INVALID = 1229; + } + + [DllImport(KERNEL32, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint GetCurrentThreadId(); + + [DllImport(KERNEL32, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static unsafe extern uint CancelIoEx(SafeHandle handle, SafeNativeOverlapped overlapped); + + [DllImport(KERNEL32, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static unsafe extern bool SetFileCompletionNotificationModes(SafeHandle handle, FileCompletionNotificationModes modes); + + [Flags] + internal enum FileCompletionNotificationModes : byte + { + None = 0, + SkipCompletionPortOnSuccess = 1, + SkipSetEventOnHandle = 2 + } + + internal static class SafeNetHandles + { + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static extern int FreeContextBuffer( + [In] IntPtr contextBuffer); + + [DllImport(SECUR32, ExactSpelling = true, SetLastError = true)] + internal static unsafe extern int QueryContextAttributesW( + ref SSPIHandle contextHandle, + [In] ContextAttribute attribute, + [In] void* buffer); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern unsafe uint HttpCreateRequestQueue(HttpApi.HTTPAPI_VERSION version, string pName, + Microsoft.AspNet.Server.WebListener.UnsafeNclNativeMethods.SECURITY_ATTRIBUTES pSecurityAttributes, uint flags, out HttpRequestQueueV2Handle pReqQueueHandle); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern unsafe uint HttpCloseRequestQueue(IntPtr pReqQueueHandle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern bool CloseHandle(IntPtr handle); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern SafeLocalFree LocalAlloc(int uFlags, UIntPtr sizetdwBytes); + + [DllImport(KERNEL32, EntryPoint = "LocalAlloc", SetLastError = true)] + internal static extern SafeLocalFreeChannelBinding LocalAllocChannelBinding(int uFlags, UIntPtr sizetdwBytes); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern IntPtr LocalFree(IntPtr handle); + + [DllImport(KERNEL32, ExactSpelling = true, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern unsafe SafeLoadLibrary LoadLibraryExW([In] string lpwLibFileName, [In] void* hFile, [In] uint dwFlags); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern unsafe bool FreeLibrary([In] IntPtr hModule); + } + + internal static unsafe class HttpApi + { + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpInitialize(HTTPAPI_VERSION version, uint flags, void* pReserved); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveRequestEntityBody(SafeHandle requestQueueHandle, ulong requestId, uint flags, IntPtr pEntityBuffer, uint entityBufferLength, out uint bytesReturned, SafeNativeOverlapped pOverlapped); + + [DllImport(HTTPAPI, EntryPoint = "HttpReceiveRequestEntityBody", ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveRequestEntityBody2(SafeHandle requestQueueHandle, ulong requestId, uint flags, void* pEntityBuffer, uint entityBufferLength, out uint bytesReturned, [In] SafeHandle pOverlapped); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveClientCertificate(SafeHandle requestQueueHandle, ulong connectionId, uint flags, HTTP_SSL_CLIENT_CERT_INFO* pSslClientCertInfo, uint sslClientCertInfoSize, uint* pBytesReceived, SafeNativeOverlapped pOverlapped); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveClientCertificate(SafeHandle requestQueueHandle, ulong connectionId, uint flags, byte* pSslClientCertInfo, uint sslClientCertInfoSize, uint* pBytesReceived, SafeNativeOverlapped pOverlapped); + + [SuppressMessage("Microsoft.Interoperability", "CA1415:DeclarePInvokesCorrectly", Justification = "NativeOverlapped is now wrapped by SafeNativeOverlapped")] + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpReceiveHttpRequest(SafeHandle requestQueueHandle, ulong requestId, uint flags, HTTP_REQUEST* pRequestBuffer, uint requestBufferLength, uint* pBytesReturned, SafeNativeOverlapped pOverlapped); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSendHttpResponse(SafeHandle requestQueueHandle, ulong requestId, uint flags, HTTP_RESPONSE* pHttpResponse, void* pCachePolicy, uint* pBytesSent, SafeLocalFree pRequestBuffer, uint requestBufferLength, SafeNativeOverlapped pOverlapped, IntPtr pLogData); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSendResponseEntityBody(SafeHandle requestQueueHandle, ulong requestId, uint flags, ushort entityChunkCount, HTTP_DATA_CHUNK* pEntityChunks, uint* pBytesSent, SafeLocalFree pRequestBuffer, uint requestBufferLength, SafeNativeOverlapped pOverlapped, IntPtr pLogData); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCancelHttpRequest(SafeHandle requestQueueHandle, ulong requestId, IntPtr pOverlapped); + + [DllImport(HTTPAPI, EntryPoint = "HttpSendResponseEntityBody", ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSendResponseEntityBody2(SafeHandle requestQueueHandle, ulong requestId, uint flags, ushort entityChunkCount, IntPtr pEntityChunks, out uint pBytesSent, SafeLocalFree pRequestBuffer, uint requestBufferLength, SafeHandle pOverlapped, IntPtr pLogData); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpWaitForDisconnect(SafeHandle requestQueueHandle, ulong connectionId, SafeNativeOverlapped pOverlapped); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCreateServerSession(HTTPAPI_VERSION version, ulong* serverSessionId, uint reserved); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCreateUrlGroup(ulong serverSessionId, ulong* urlGroupId, uint reserved); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern uint HttpAddUrlToUrlGroup(ulong urlGroupId, string pFullyQualifiedUrl, ulong context, uint pReserved); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSetUrlGroupProperty(ulong urlGroupId, HTTP_SERVER_PROPERTY serverProperty, IntPtr pPropertyInfo, uint propertyInfoLength); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Unicode, SetLastError = true)] + internal static extern uint HttpRemoveUrlFromUrlGroup(ulong urlGroupId, string pFullyQualifiedUrl, uint flags); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCloseServerSession(ulong serverSessionId); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpCloseUrlGroup(ulong urlGroupId); + + [DllImport(HTTPAPI, ExactSpelling = true, CallingConvention = CallingConvention.StdCall, SetLastError = true)] + internal static extern uint HttpSetRequestQueueProperty(SafeHandle requestQueueHandle, HTTP_SERVER_PROPERTY serverProperty, IntPtr pPropertyInfo, uint propertyInfoLength, uint reserved, IntPtr pReserved); + + internal enum HTTP_API_VERSION + { + Invalid, + Version10, + Version20, + } + + // see http.w for definitions + internal enum HTTP_SERVER_PROPERTY + { + HttpServerAuthenticationProperty, + HttpServerLoggingProperty, + HttpServerQosProperty, + HttpServerTimeoutsProperty, + HttpServerQueueLengthProperty, + HttpServerStateProperty, + HttpServer503VerbosityProperty, + HttpServerBindingProperty, + HttpServerExtendedAuthenticationProperty, + HttpServerListenEndpointProperty, + HttpServerChannelBindProperty, + HttpServerProtectionLevelProperty, + } + + // Currently only one request info type is supported but the enum is for future extensibility. + + internal enum HTTP_REQUEST_INFO_TYPE + { + HttpRequestInfoTypeAuth, + } + + internal enum HTTP_RESPONSE_INFO_TYPE + { + HttpResponseInfoTypeMultipleKnownHeaders, + HttpResponseInfoTypeAuthenticationProperty, + HttpResponseInfoTypeQosProperty, + } + + internal enum HTTP_TIMEOUT_TYPE + { + EntityBody, + DrainEntityBody, + RequestQueue, + IdleConnection, + HeaderWait, + MinSendRate, + } + + internal const int MaxTimeout = 6; + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_VERSION + { + internal ushort MajorVersion; + internal ushort MinorVersion; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_KNOWN_HEADER + { + internal ushort RawValueLength; + internal sbyte* pRawValue; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct HTTP_DATA_CHUNK + { + [SuppressMessage("Microsoft.Performance", "CA1823:AvoidUnusedPrivateFields", Justification = "Used natively")] + [FieldOffset(0)] + internal HTTP_DATA_CHUNK_TYPE DataChunkType; + + [FieldOffset(8)] + internal FromMemory fromMemory; + + [FieldOffset(8)] + internal FromFileHandle fromFile; + } + + [SuppressMessage("Microsoft.Design", "CA1049:TypesThatOwnNativeResourcesShouldBeDisposable", + Justification = "This type does not own the native buffer")] + [StructLayout(LayoutKind.Sequential)] + internal struct FromMemory + { + // 4 bytes for 32bit, 8 bytes for 64bit + internal IntPtr pBuffer; + internal uint BufferLength; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct FromFileHandle + { + internal ulong offset; + internal ulong count; + internal IntPtr fileHandle; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTPAPI_VERSION + { + internal ushort HttpApiMajorVersion; + internal ushort HttpApiMinorVersion; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_COOKED_URL + { + internal ushort FullUrlLength; + internal ushort HostLength; + internal ushort AbsPathLength; + internal ushort QueryStringLength; + internal ushort* pFullUrl; + internal ushort* pHost; + internal ushort* pAbsPath; + internal ushort* pQueryString; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct SOCKADDR + { + internal ushort sa_family; + internal byte sa_data; + internal byte sa_data_02; + internal byte sa_data_03; + internal byte sa_data_04; + internal byte sa_data_05; + internal byte sa_data_06; + internal byte sa_data_07; + internal byte sa_data_08; + internal byte sa_data_09; + internal byte sa_data_10; + internal byte sa_data_11; + internal byte sa_data_12; + internal byte sa_data_13; + internal byte sa_data_14; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_TRANSPORT_ADDRESS + { + internal SOCKADDR* pRemoteAddress; + internal SOCKADDR* pLocalAddress; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SSL_CLIENT_CERT_INFO + { + internal uint CertFlags; + internal uint CertEncodedSize; + internal byte* pCertEncoded; + internal void* Token; + internal byte CertDeniedByMapper; + } + + internal enum HTTP_SERVICE_BINDING_TYPE : uint + { + HttpServiceBindingTypeNone = 0, + HttpServiceBindingTypeW, + HttpServiceBindingTypeA + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SERVICE_BINDING_BASE + { + internal HTTP_SERVICE_BINDING_TYPE Type; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_CHANNEL_BIND_STATUS + { + internal IntPtr ServiceName; + internal IntPtr ChannelToken; + internal uint ChannelTokenSize; + internal uint Flags; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_UNKNOWN_HEADER + { + internal ushort NameLength; + internal ushort RawValueLength; + internal sbyte* pName; + internal sbyte* pRawValue; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SSL_INFO + { + internal ushort ServerCertKeySize; + internal ushort ConnectionKeySize; + internal uint ServerCertIssuerSize; + internal uint ServerCertSubjectSize; + internal sbyte* pServerCertIssuer; + internal sbyte* pServerCertSubject; + internal HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo; + internal uint SslClientCertNegotiated; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_RESPONSE_HEADERS + { + internal ushort UnknownHeaderCount; + internal HTTP_UNKNOWN_HEADER* pUnknownHeaders; + internal ushort TrailerCount; + internal HTTP_UNKNOWN_HEADER* pTrailers; + internal HTTP_KNOWN_HEADER KnownHeaders; + internal HTTP_KNOWN_HEADER KnownHeaders_02; + internal HTTP_KNOWN_HEADER KnownHeaders_03; + internal HTTP_KNOWN_HEADER KnownHeaders_04; + internal HTTP_KNOWN_HEADER KnownHeaders_05; + internal HTTP_KNOWN_HEADER KnownHeaders_06; + internal HTTP_KNOWN_HEADER KnownHeaders_07; + internal HTTP_KNOWN_HEADER KnownHeaders_08; + internal HTTP_KNOWN_HEADER KnownHeaders_09; + internal HTTP_KNOWN_HEADER KnownHeaders_10; + internal HTTP_KNOWN_HEADER KnownHeaders_11; + internal HTTP_KNOWN_HEADER KnownHeaders_12; + internal HTTP_KNOWN_HEADER KnownHeaders_13; + internal HTTP_KNOWN_HEADER KnownHeaders_14; + internal HTTP_KNOWN_HEADER KnownHeaders_15; + internal HTTP_KNOWN_HEADER KnownHeaders_16; + internal HTTP_KNOWN_HEADER KnownHeaders_17; + internal HTTP_KNOWN_HEADER KnownHeaders_18; + internal HTTP_KNOWN_HEADER KnownHeaders_19; + internal HTTP_KNOWN_HEADER KnownHeaders_20; + internal HTTP_KNOWN_HEADER KnownHeaders_21; + internal HTTP_KNOWN_HEADER KnownHeaders_22; + internal HTTP_KNOWN_HEADER KnownHeaders_23; + internal HTTP_KNOWN_HEADER KnownHeaders_24; + internal HTTP_KNOWN_HEADER KnownHeaders_25; + internal HTTP_KNOWN_HEADER KnownHeaders_26; + internal HTTP_KNOWN_HEADER KnownHeaders_27; + internal HTTP_KNOWN_HEADER KnownHeaders_28; + internal HTTP_KNOWN_HEADER KnownHeaders_29; + internal HTTP_KNOWN_HEADER KnownHeaders_30; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_HEADERS + { + internal ushort UnknownHeaderCount; + internal HTTP_UNKNOWN_HEADER* pUnknownHeaders; + internal ushort TrailerCount; + internal HTTP_UNKNOWN_HEADER* pTrailers; + internal HTTP_KNOWN_HEADER KnownHeaders; + internal HTTP_KNOWN_HEADER KnownHeaders_02; + internal HTTP_KNOWN_HEADER KnownHeaders_03; + internal HTTP_KNOWN_HEADER KnownHeaders_04; + internal HTTP_KNOWN_HEADER KnownHeaders_05; + internal HTTP_KNOWN_HEADER KnownHeaders_06; + internal HTTP_KNOWN_HEADER KnownHeaders_07; + internal HTTP_KNOWN_HEADER KnownHeaders_08; + internal HTTP_KNOWN_HEADER KnownHeaders_09; + internal HTTP_KNOWN_HEADER KnownHeaders_10; + internal HTTP_KNOWN_HEADER KnownHeaders_11; + internal HTTP_KNOWN_HEADER KnownHeaders_12; + internal HTTP_KNOWN_HEADER KnownHeaders_13; + internal HTTP_KNOWN_HEADER KnownHeaders_14; + internal HTTP_KNOWN_HEADER KnownHeaders_15; + internal HTTP_KNOWN_HEADER KnownHeaders_16; + internal HTTP_KNOWN_HEADER KnownHeaders_17; + internal HTTP_KNOWN_HEADER KnownHeaders_18; + internal HTTP_KNOWN_HEADER KnownHeaders_19; + internal HTTP_KNOWN_HEADER KnownHeaders_20; + internal HTTP_KNOWN_HEADER KnownHeaders_21; + internal HTTP_KNOWN_HEADER KnownHeaders_22; + internal HTTP_KNOWN_HEADER KnownHeaders_23; + internal HTTP_KNOWN_HEADER KnownHeaders_24; + internal HTTP_KNOWN_HEADER KnownHeaders_25; + internal HTTP_KNOWN_HEADER KnownHeaders_26; + internal HTTP_KNOWN_HEADER KnownHeaders_27; + internal HTTP_KNOWN_HEADER KnownHeaders_28; + internal HTTP_KNOWN_HEADER KnownHeaders_29; + internal HTTP_KNOWN_HEADER KnownHeaders_30; + internal HTTP_KNOWN_HEADER KnownHeaders_31; + internal HTTP_KNOWN_HEADER KnownHeaders_32; + internal HTTP_KNOWN_HEADER KnownHeaders_33; + internal HTTP_KNOWN_HEADER KnownHeaders_34; + internal HTTP_KNOWN_HEADER KnownHeaders_35; + internal HTTP_KNOWN_HEADER KnownHeaders_36; + internal HTTP_KNOWN_HEADER KnownHeaders_37; + internal HTTP_KNOWN_HEADER KnownHeaders_38; + internal HTTP_KNOWN_HEADER KnownHeaders_39; + internal HTTP_KNOWN_HEADER KnownHeaders_40; + internal HTTP_KNOWN_HEADER KnownHeaders_41; + } + + internal enum HTTP_VERB : int + { + HttpVerbUnparsed = 0, + HttpVerbUnknown = 1, + HttpVerbInvalid = 2, + HttpVerbOPTIONS = 3, + HttpVerbGET = 4, + HttpVerbHEAD = 5, + HttpVerbPOST = 6, + HttpVerbPUT = 7, + HttpVerbDELETE = 8, + HttpVerbTRACE = 9, + HttpVerbCONNECT = 10, + HttpVerbTRACK = 11, + HttpVerbMOVE = 12, + HttpVerbCOPY = 13, + HttpVerbPROPFIND = 14, + HttpVerbPROPPATCH = 15, + HttpVerbMKCOL = 16, + HttpVerbLOCK = 17, + HttpVerbUNLOCK = 18, + HttpVerbSEARCH = 19, + HttpVerbMaximum = 20, + } + + internal static readonly string[] HttpVerbs = new string[] + { + null, + "Unknown", + "Invalid", + "OPTIONS", + "GET", + "HEAD", + "POST", + "PUT", + "DELETE", + "TRACE", + "CONNECT", + "TRACK", + "MOVE", + "COPY", + "PROPFIND", + "PROPPATCH", + "MKCOL", + "LOCK", + "UNLOCK", + "SEARCH", + }; + + internal enum HTTP_DATA_CHUNK_TYPE : int + { + HttpDataChunkFromMemory = 0, + HttpDataChunkFromFileHandle = 1, + HttpDataChunkFromFragmentCache = 2, + HttpDataChunkMaximum = 3, + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_RESPONSE_INFO + { + internal HTTP_RESPONSE_INFO_TYPE Type; + internal uint Length; + internal void* pInfo; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_RESPONSE + { + internal uint Flags; + internal HTTP_VERSION Version; + internal ushort StatusCode; + internal ushort ReasonLength; + internal sbyte* pReason; + internal HTTP_RESPONSE_HEADERS Headers; + internal ushort EntityChunkCount; + internal HTTP_DATA_CHUNK* pEntityChunks; + internal ushort ResponseInfoCount; + internal HTTP_RESPONSE_INFO* pResponseInfo; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_AUTH_INFO + { + internal HTTP_AUTH_STATUS AuthStatus; + internal uint SecStatus; + internal uint Flags; + internal HTTP_REQUEST_AUTH_TYPE AuthType; + internal IntPtr AccessToken; + internal uint ContextAttributes; + internal uint PackedContextLength; + internal uint PackedContextType; + internal IntPtr PackedContext; + internal uint MutualAuthDataLength; + internal char* pMutualAuthData; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_INFO + { + internal HTTP_REQUEST_INFO_TYPE InfoType; + internal uint InfoLength; + internal HTTP_REQUEST_AUTH_INFO* pInfo; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST + { + internal uint Flags; + internal ulong ConnectionId; + internal ulong RequestId; + internal ulong UrlContext; + internal HTTP_VERSION Version; + internal HTTP_VERB Verb; + internal ushort UnknownVerbLength; + internal ushort RawUrlLength; + internal sbyte* pUnknownVerb; + internal sbyte* pRawUrl; + internal HTTP_COOKED_URL CookedUrl; + internal HTTP_TRANSPORT_ADDRESS Address; + internal HTTP_REQUEST_HEADERS Headers; + internal ulong BytesReceived; + internal ushort EntityChunkCount; + internal HTTP_DATA_CHUNK* pEntityChunks; + internal ulong RawConnectionId; + internal HTTP_SSL_INFO* pSslInfo; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_REQUEST_V2 + { + internal HTTP_REQUEST Request; + internal ushort RequestInfoCount; + internal HTTP_REQUEST_INFO* pRequestInfo; + } + + internal enum HTTP_AUTH_STATUS + { + HttpAuthStatusSuccess, + HttpAuthStatusNotAuthenticated, + HttpAuthStatusFailure, + } + + internal enum HTTP_REQUEST_AUTH_TYPE + { + HttpRequestAuthTypeNone = 0, + HttpRequestAuthTypeBasic, + HttpRequestAuthTypeDigest, + HttpRequestAuthTypeNTLM, + HttpRequestAuthTypeNegotiate, + HttpRequestAuthTypeKerberos + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SERVER_AUTHENTICATION_INFO + { + internal HTTP_FLAGS Flags; + internal HTTP_AUTH_TYPES AuthSchemes; + internal bool ReceiveMutualAuth; + internal bool ReceiveContextHandle; + internal bool DisableNTLMCredentialCaching; + internal ulong ExFlags; + HTTP_SERVER_AUTHENTICATION_DIGEST_PARAMS DigestParams; + HTTP_SERVER_AUTHENTICATION_BASIC_PARAMS BasicParams; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SERVER_AUTHENTICATION_DIGEST_PARAMS + { + internal ushort DomainNameLength; + internal char* DomainName; + internal ushort RealmLength; + internal char* Realm; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_SERVER_AUTHENTICATION_BASIC_PARAMS + { + ushort RealmLength; + char* Realm; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_TIMEOUT_LIMIT_INFO + { + internal HTTP_FLAGS Flags; + internal ushort EntityBody; + internal ushort DrainEntityBody; + internal ushort RequestQueue; + internal ushort IdleConnection; + internal ushort HeaderWait; + internal uint MinSendRate; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HTTP_BINDING_INFO + { + internal HTTP_FLAGS Flags; + internal IntPtr RequestQueueHandle; + } + + // see http.w for definitions + [Flags] + internal enum HTTP_FLAGS : uint + { + NONE = 0x00000000, + HTTP_RECEIVE_REQUEST_FLAG_COPY_BODY = 0x00000001, + HTTP_RECEIVE_SECURE_CHANNEL_TOKEN = 0x00000001, + HTTP_SEND_RESPONSE_FLAG_DISCONNECT = 0x00000001, + HTTP_SEND_RESPONSE_FLAG_MORE_DATA = 0x00000002, + HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA = 0x00000004, + HTTP_SEND_RESPONSE_FLAG_RAW_HEADER = 0x00000004, + HTTP_SEND_REQUEST_FLAG_MORE_DATA = 0x00000001, + HTTP_PROPERTY_FLAG_PRESENT = 0x00000001, + HTTP_INITIALIZE_SERVER = 0x00000001, + HTTP_INITIALIZE_CBT = 0x00000004, + HTTP_SEND_RESPONSE_FLAG_OPAQUE = 0x00000040, + } + + [Flags] + internal enum HTTP_AUTH_TYPES : uint + { + NONE = 0x00000000, + HTTP_AUTH_ENABLE_BASIC = 0x00000001, + HTTP_AUTH_ENABLE_DIGEST = 0x00000002, + HTTP_AUTH_ENABLE_NTLM = 0x00000004, + HTTP_AUTH_ENABLE_NEGOTIATE = 0x00000008, + HTTP_AUTH_ENABLE_KERBEROS = 0x00000010, + } + + private const int HttpHeaderRequestMaximum = (int)HttpSysRequestHeader.UserAgent + 1; + private const int HttpHeaderResponseMaximum = (int)HttpSysResponseHeader.WwwAuthenticate + 1; + + internal static class HTTP_REQUEST_HEADER_ID + { + internal static string ToString(int position) + { + return _strings[position]; + } + + private static string[] _strings = + { + "Cache-Control", + "Connection", + "Date", + "Keep-Alive", + "Pragma", + "Trailer", + "Transfer-Encoding", + "Upgrade", + "Via", + "Warning", + + "Allow", + "Content-Length", + "Content-Type", + "Content-Encoding", + "Content-Language", + "Content-Location", + "Content-MD5", + "Content-Range", + "Expires", + "Last-Modified", + + "Accept", + "Accept-Charset", + "Accept-Encoding", + "Accept-Language", + "Authorization", + "Cookie", + "Expect", + "From", + "Host", + "If-Match", + + "If-Modified-Since", + "If-None-Match", + "If-Range", + "If-Unmodified-Since", + "Max-Forwards", + "Proxy-Authorization", + "Referer", + "Range", + "Te", + "Translate", + "User-Agent", + }; + } + + internal static class HTTP_RESPONSE_HEADER_ID + { + private static string[] _strings = + { + "Cache-Control", + "Connection", + "Date", + "Keep-Alive", + "Pragma", + "Trailer", + "Transfer-Encoding", + "Upgrade", + "Via", + "Warning", + + "Allow", + "Content-Length", + "Content-Type", + "Content-Encoding", + "Content-Language", + "Content-Location", + "Content-MD5", + "Content-Range", + "Expires", + "Last-Modified", + + "Accept-Ranges", + "Age", + "ETag", + "Location", + "Proxy-Authenticate", + "Retry-After", + "Server", + "Set-Cookie", + "Vary", + "WWW-Authenticate", + }; + + private static Dictionary _lookupTable = CreateLookupTable(); + + private static Dictionary CreateLookupTable() + { + Dictionary lookupTable = new Dictionary((int)Enum.HttpHeaderResponseMaximum, StringComparer.OrdinalIgnoreCase); + for (int i = 0; i < (int)Enum.HttpHeaderResponseMaximum; i++) + { + lookupTable.Add(_strings[i], i); + } + return lookupTable; + } + + internal static int IndexOfKnownHeader(string HeaderName) + { + int index; + return _lookupTable.TryGetValue(HeaderName, out index) ? index : -1; + } + + internal static string ToString(int position) + { + return _strings[position]; + } + + internal enum Enum + { + HttpHeaderCacheControl = 0, // general-header [section 4.5] + HttpHeaderConnection = 1, // general-header [section 4.5] + HttpHeaderDate = 2, // general-header [section 4.5] + HttpHeaderKeepAlive = 3, // general-header [not in rfc] + HttpHeaderPragma = 4, // general-header [section 4.5] + HttpHeaderTrailer = 5, // general-header [section 4.5] + HttpHeaderTransferEncoding = 6, // general-header [section 4.5] + HttpHeaderUpgrade = 7, // general-header [section 4.5] + HttpHeaderVia = 8, // general-header [section 4.5] + HttpHeaderWarning = 9, // general-header [section 4.5] + + HttpHeaderAllow = 10, // entity-header [section 7.1] + HttpHeaderContentLength = 11, // entity-header [section 7.1] + HttpHeaderContentType = 12, // entity-header [section 7.1] + HttpHeaderContentEncoding = 13, // entity-header [section 7.1] + HttpHeaderContentLanguage = 14, // entity-header [section 7.1] + HttpHeaderContentLocation = 15, // entity-header [section 7.1] + HttpHeaderContentMd5 = 16, // entity-header [section 7.1] + HttpHeaderContentRange = 17, // entity-header [section 7.1] + HttpHeaderExpires = 18, // entity-header [section 7.1] + HttpHeaderLastModified = 19, // entity-header [section 7.1] + + // Response Headers + + HttpHeaderAcceptRanges = 20, // response-header [section 6.2] + HttpHeaderAge = 21, // response-header [section 6.2] + HttpHeaderEtag = 22, // response-header [section 6.2] + HttpHeaderLocation = 23, // response-header [section 6.2] + HttpHeaderProxyAuthenticate = 24, // response-header [section 6.2] + HttpHeaderRetryAfter = 25, // response-header [section 6.2] + HttpHeaderServer = 26, // response-header [section 6.2] + HttpHeaderSetCookie = 27, // response-header [not in rfc] + HttpHeaderVary = 28, // response-header [section 6.2] + HttpHeaderWwwAuthenticate = 29, // response-header [section 6.2] + + HttpHeaderResponseMaximum = 30, + + HttpHeaderMaximum = 41 + } + } + + private static HTTPAPI_VERSION version; + + // This property is used by HttpListener to pass the version structure to the native layer in API + // calls. + + internal static HTTPAPI_VERSION Version + { + get + { + return version; + } + } + + // This property is used by HttpListener to get the Api version in use so that it uses appropriate + // Http APIs. + + internal static HTTP_API_VERSION ApiVersion + { + get + { + if (version.HttpApiMajorVersion == 2 && version.HttpApiMinorVersion == 0) + { + return HTTP_API_VERSION.Version20; + } + else if (version.HttpApiMajorVersion == 1 && version.HttpApiMinorVersion == 0) + { + return HTTP_API_VERSION.Version10; + } + else + { + return HTTP_API_VERSION.Invalid; + } + } + } + + static HttpApi() + { + InitHttpApi(2, 0); + } + + private static void InitHttpApi(ushort majorVersion, ushort minorVersion) + { + version.HttpApiMajorVersion = majorVersion; + version.HttpApiMinorVersion = minorVersion; + + // For pre-Win7 OS versions, we need to check whether http.sys contains the CBT patch. + // We do so by passing HTTP_INITIALIZE_CBT flag to HttpInitialize. If the flag is not + // supported, http.sys is not patched. Note that http.sys will return invalid parameter + // also on Win7, even though it shipped with CBT support. Therefore we must not pass + // the flag on Win7 and later. + uint statusCode = ErrorCodes.ERROR_SUCCESS; + + // on Win7 and later, we don't pass the CBT flag. CBT is always supported. + statusCode = HttpApi.HttpInitialize(version, (uint)HTTP_FLAGS.HTTP_INITIALIZE_SERVER, null); + + supported = statusCode == ErrorCodes.ERROR_SUCCESS; + } + + private static volatile bool supported; + internal static bool Supported + { + get + { + return supported; + } + } + + // Server API + + internal static void GetUnknownHeaders(IDictionary unknownHeaders, byte[] memoryBlob, IntPtr originalAddress) + { + // Return value. + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + long fixup = pMemoryBlob - (byte*)originalAddress; + int index; + + // unknown headers + if (request->Headers.UnknownHeaderCount != 0) + { + HTTP_UNKNOWN_HEADER* pUnknownHeader = (HTTP_UNKNOWN_HEADER*)(fixup + (byte*)request->Headers.pUnknownHeaders); + for (index = 0; index < request->Headers.UnknownHeaderCount; index++) + { + // For unknown headers, when header value is empty, RawValueLength will be 0 and + // pRawValue will be null. + if (pUnknownHeader->pName != null && pUnknownHeader->NameLength > 0) + { + string headerName = HeaderEncoding.GetString(pUnknownHeader->pName + fixup, pUnknownHeader->NameLength); + string headerValue; + if (pUnknownHeader->pRawValue != null && pUnknownHeader->RawValueLength > 0) + { + headerValue = HeaderEncoding.GetString(pUnknownHeader->pRawValue + fixup, pUnknownHeader->RawValueLength); + } + else + { + headerValue = string.Empty; + } + // Note that Http.Sys currently collapses all headers of the same name to a single string, so + // append will just set. + unknownHeaders.Append(headerName, headerValue); + } + pUnknownHeader++; + } + } + } + } + + private static string GetKnownHeader(HTTP_REQUEST* request, long fixup, int headerIndex) + { + string header = null; + + HTTP_KNOWN_HEADER* pKnownHeader = (&request->Headers.KnownHeaders) + headerIndex; + // For known headers, when header value is empty, RawValueLength will be 0 and + // pRawValue will point to empty string ("\0") + if (pKnownHeader->pRawValue != null) + { + header = HeaderEncoding.GetString(pKnownHeader->pRawValue + fixup, pKnownHeader->RawValueLength); + } + + return header; + } + + internal static string GetKnownHeader(byte[] memoryBlob, IntPtr originalAddress, int headerIndex) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + return GetKnownHeader((HTTP_REQUEST*)pMemoryBlob, pMemoryBlob - (byte*)originalAddress, headerIndex); + } + } + + private static unsafe string GetVerb(HTTP_REQUEST* request, long fixup) + { + string verb = null; + + if ((int)request->Verb > (int)HTTP_VERB.HttpVerbUnknown && (int)request->Verb < (int)HTTP_VERB.HttpVerbMaximum) + { + verb = HttpVerbs[(int)request->Verb]; + } + else if (request->Verb == HTTP_VERB.HttpVerbUnknown && request->pUnknownVerb != null) + { + verb = HeaderEncoding.GetString(request->pUnknownVerb + fixup, request->UnknownVerbLength); + } + + return verb; + } + + internal static unsafe string GetVerb(byte[] memoryBlob, IntPtr originalAddress) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + return GetVerb((HTTP_REQUEST*)pMemoryBlob, pMemoryBlob - (byte*)originalAddress); + } + } + + internal static HTTP_VERB GetKnownVerb(byte[] memoryBlob, IntPtr originalAddress) + { + // Return value. + HTTP_VERB verb = HTTP_VERB.HttpVerbUnknown; + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + if ((int)request->Verb > (int)HTTP_VERB.HttpVerbUnparsed && (int)request->Verb < (int)HTTP_VERB.HttpVerbMaximum) + { + verb = request->Verb; + } + } + + return verb; + } + + internal static uint GetChunks(byte[] memoryBlob, IntPtr originalAddress, ref int dataChunkIndex, ref uint dataChunkOffset, byte[] buffer, int offset, int size) + { + // Return value. + uint dataRead = 0; + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + long fixup = pMemoryBlob - (byte*)originalAddress; + + if (request->EntityChunkCount > 0 && dataChunkIndex < request->EntityChunkCount && dataChunkIndex != -1) + { + HTTP_DATA_CHUNK* pDataChunk = (HTTP_DATA_CHUNK*)(fixup + (byte*)&request->pEntityChunks[dataChunkIndex]); + + fixed (byte* pReadBuffer = buffer) + { + byte* pTo = &pReadBuffer[offset]; + + while (dataChunkIndex < request->EntityChunkCount && dataRead < size) + { + if (dataChunkOffset >= pDataChunk->fromMemory.BufferLength) + { + dataChunkOffset = 0; + dataChunkIndex++; + pDataChunk++; + } + else + { + byte* pFrom = (byte*)pDataChunk->fromMemory.pBuffer + dataChunkOffset + fixup; + + uint bytesToRead = pDataChunk->fromMemory.BufferLength - (uint)dataChunkOffset; + if (bytesToRead > (uint)size) + { + bytesToRead = (uint)size; + } + for (uint i = 0; i < bytesToRead; i++) + { + *(pTo++) = *(pFrom++); + } + dataRead += bytesToRead; + dataChunkOffset += bytesToRead; + } + } + } + } + // we're finished. + if (dataChunkIndex == request->EntityChunkCount) + { + dataChunkIndex = -1; + } + } + + return dataRead; + } + + internal static SocketAddress GetRemoteEndPoint(byte[] memoryBlob, IntPtr originalAddress) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + return GetEndPoint(memoryBlob, originalAddress, (byte*)request->Address.pRemoteAddress); + } + } + + internal static SocketAddress GetLocalEndPoint(byte[] memoryBlob, IntPtr originalAddress) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + HTTP_REQUEST* request = (HTTP_REQUEST*)pMemoryBlob; + return GetEndPoint(memoryBlob, originalAddress, (byte*)request->Address.pLocalAddress); + } + } + + internal static SocketAddress GetEndPoint(byte[] memoryBlob, IntPtr originalAddress, byte* source) + { + fixed (byte* pMemoryBlob = memoryBlob) + { + IntPtr address = source != null ? + (IntPtr)(pMemoryBlob - (byte*)originalAddress + source) : IntPtr.Zero; + return CopyOutAddress(address); + } + } + + private static SocketAddress CopyOutAddress(IntPtr address) + { + if (address != IntPtr.Zero) + { + ushort addressFamily = *((ushort*)address); + if (addressFamily == (ushort)AddressFamily.InterNetwork) + { + SocketAddress v4address = new SocketAddress(AddressFamily.InterNetwork, SocketAddress.IPv4AddressSize); + fixed (byte* pBuffer = v4address.Buffer) + { + for (int index = 2; index < SocketAddress.IPv4AddressSize; index++) + { + pBuffer[index] = ((byte*)address)[index]; + } + } + return v4address; + } + if (addressFamily == (ushort)AddressFamily.InterNetworkV6) + { + SocketAddress v6address = new SocketAddress(AddressFamily.InterNetworkV6, SocketAddress.IPv6AddressSize); + fixed (byte* pBuffer = v6address.Buffer) + { + for (int index = 2; index < SocketAddress.IPv6AddressSize; index++) + { + pBuffer[index] = ((byte*)address)[index]; + } + } + return v6address; + } + } + + return null; + } + } + + // DACL related stuff + + [SuppressMessage("Microsoft.Performance", "CA1812:AvoidUninstantiatedInternalClasses", Justification = "Instantiated natively")] + [SuppressMessage("Microsoft.Design", "CA1001:TypesThatOwnDisposableFieldsShouldBeDisposable", + Justification = "Does not own the resource.")] + [StructLayout(LayoutKind.Sequential)] + internal class SECURITY_ATTRIBUTES + { + public int nLength = 12; + public SafeLocalMemHandle lpSecurityDescriptor = new SafeLocalMemHandle(IntPtr.Zero, false); + public bool bInheritHandle = false; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/OwinServerFactory.cs b/src/Microsoft.AspNet.Server.WebListener/OwinServerFactory.cs new file mode 100644 index 0000000000..6b0b205db7 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/OwinServerFactory.cs @@ -0,0 +1,108 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- +// Copyright 2011-2012 Katana contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + using AppFunc = Func, Task>; + using LoggerFactoryFunc = Func, bool>>; + + /// + /// Implements the Katana setup pattern for this server. + /// + public static class OwinServerFactory + { + /// + /// Populates the server capabilities. + /// Also included is a configurable instance of the server. + /// + /// + [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Disposed by caller")] + public static void Initialize(IDictionary properties) + { + if (properties == null) + { + throw new ArgumentNullException("properties"); + } + + properties[Constants.VersionKey] = Constants.OwinVersion; + + IDictionary capabilities = + properties.Get>(Constants.ServerCapabilitiesKey) + ?? new Dictionary(); + properties[Constants.ServerCapabilitiesKey] = capabilities; + + // SendFile + capabilities[Constants.SendFileVersionKey] = Constants.SendFileVersion; + IDictionary sendfileSupport = new Dictionary(); + sendfileSupport[Constants.SendFileConcurrencyKey] = Constants.Overlapped; + capabilities[Constants.SendFileSupportKey] = sendfileSupport; + + // Opaque + if (ComNetOS.IsWin8orLater) + { + capabilities[Constants.OpaqueVersionKey] = Constants.OpaqueVersion; + } + + // Directly expose the server for advanced configuration. + properties[typeof(OwinWebListener).FullName] = new OwinWebListener(); + } + + /// + /// Creates a server and starts listening on the given addresses. + /// + /// The application entry point. + /// The configuration. + /// The server. Invoke Dispose to shut down. + [SuppressMessage("Microsoft.Design", "CA1006:DoNotNestGenericTypesInMemberSignatures", Justification = "By design")] + [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Disposed by caller")] + public static IDisposable Create(AppFunc app, IDictionary properties) + { + if (app == null) + { + throw new ArgumentNullException("app"); + } + if (properties == null) + { + throw new ArgumentNullException("properties"); + } + + var addresses = properties.Get>>("host.Addresses") + ?? new List>(); + + OwinWebListener server = properties.Get(typeof(OwinWebListener).FullName) + ?? new OwinWebListener(); + + var capabilities = + properties.Get>(Constants.ServerCapabilitiesKey) + ?? new Dictionary(); + + var loggerFactory = properties.Get(Constants.ServerLoggerFactoryKey); + + server.Start(app, addresses, capabilities, loggerFactory); + return server; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/OwinWebListener.cs b/src/Microsoft.AspNet.Server.WebListener/OwinWebListener.cs new file mode 100644 index 0000000000..d6c1cd916a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/OwinWebListener.cs @@ -0,0 +1,1107 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Runtime.InteropServices; +using System.Security.Authentication.ExtendedProtection; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + using AppFunc = Func, Task>; + using LoggerFactoryFunc = Func, bool>>; + using LoggerFunc = Func, bool>; + + /// + /// An HTTP server wrapping the Http.Sys APIs that accepts requests and passes them on to the given OWIN application. + /// + public sealed class OwinWebListener : IDisposable + { + private const long DefaultRequestQueueLength = 1000; // Http.sys default. +#if NET45 + private static readonly Type ChannelBindingStatusType = typeof(UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_CHANNEL_BIND_STATUS); + private static readonly int RequestChannelBindStatusSize = + Marshal.SizeOf(typeof(UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_CHANNEL_BIND_STATUS)); + private static readonly int BindingInfoSize = + Marshal.SizeOf(typeof(UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO)); +#else + private static readonly int RequestChannelBindStatusSize = + Marshal.SizeOf(); + private static readonly int BindingInfoSize = + Marshal.SizeOf(); +#endif + private static readonly int DefaultMaxAccepts = 5 * Environment.ProcessorCount; + private static readonly int DefaultMaxRequests = Int32.MaxValue; + + // Win8# 559317 fixed a bug in Http.sys's HttpReceiveClientCertificate method. + // Without this fix IOCP callbacks were not being called although ERROR_IO_PENDING was + // returned from HttpReceiveClientCertificate when using the + // FileCompletionNotificationModes.SkipCompletionPortOnSuccess flag. + // This bug was only hit when the buffer passed into HttpReceiveClientCertificate + // (1500 bytes initially) is tool small for the certificate. + // Due to this bug in downlevel operating systems the FileCompletionNotificationModes.SkipCompletionPortOnSuccess + // flag is only used on Win8 and later. + internal static readonly bool SkipIOCPCallbackOnSuccess = ComNetOS.IsWin8orLater; + + // Mitigate potential DOS attacks by limiting the number of unknown headers we accept. Numerous header names + // with hash collisions will cause the server to consume excess CPU. 1000 headers limits CPU time to under + // 0.5 seconds per request. Respond with a 400 Bad Request. + private const int UnknownHeaderLimit = 1000; + + private readonly ConcurrentDictionary _connectionCancellationTokens; + + private AppFunc _appFunc; + private IDictionary _capabilities; + private LoggerFunc _logger; + + private SafeHandle _requestQueueHandle; + private volatile State _state; // m_State is set only within lock blocks, but often read outside locks. + private bool _ignoreWriteExceptions; + private HttpServerSessionHandle _serverSessionHandle; + private ulong _urlGroupId; + private TimeoutManager _timeoutManager; + private AuthenticationManager _authManager; + private bool _v2Initialized; + + private object _internalLock; + + private List _uriPrefixes = new List(); + + private PumpLimits _pumpLimits; + private int _currentOutstandingAccepts; + private int _currentOutstandingRequests; + private Action _offloadListenForNextRequest; + + // The native request queue + private long? _requestQueueLength; + + internal OwinWebListener() + { + if (!UnsafeNclNativeMethods.HttpApi.Supported) + { + throw new PlatformNotSupportedException(); + } + + Debug.Assert(UnsafeNclNativeMethods.HttpApi.ApiVersion == + UnsafeNclNativeMethods.HttpApi.HTTP_API_VERSION.Version20, "Invalid Http api version"); + + _state = State.Stopped; + _internalLock = new object(); + + _timeoutManager = new TimeoutManager(this); + _authManager = new AuthenticationManager(this); + _connectionCancellationTokens = new ConcurrentDictionary(); + + _offloadListenForNextRequest = new Action(ListenForNextRequestAsync); + + _pumpLimits = new PumpLimits(DefaultMaxAccepts, DefaultMaxRequests); + } + + internal enum State + { + Stopped, + Started, + Disposed, + } + + internal LoggerFunc Logger + { + get { return _logger; } + } + + internal List UriPrefixes + { + get { return _uriPrefixes; } + } + + internal IDictionary Capabilities + { + get { return _capabilities; } + } + + internal SafeHandle RequestQueueHandle + { + get + { + return _requestQueueHandle; + } + } + + /// + /// Exposes the Http.Sys timeout configurations. These may also be configured in the registry. + /// + public TimeoutManager TimeoutManager + { + get + { + ValidateV2Property(); + Debug.Assert(_timeoutManager != null, "Timeout manager is not assigned"); + return _timeoutManager; + } + } + + public AuthenticationManager AuthenticationManager + { + get + { + ValidateV2Property(); + Debug.Assert(_authManager != null, "Auth manager is not assigned"); + return _authManager; + } + } + + internal static bool IsSupported + { + get + { + return UnsafeNclNativeMethods.HttpApi.Supported; + } + } + + internal bool IsListening + { + get + { + return _state == State.Started; + } + } + + internal bool IgnoreWriteExceptions + { + get + { + return _ignoreWriteExceptions; + } + set + { + CheckDisposed(); + _ignoreWriteExceptions = value; + } + } + + private bool CanAcceptMoreRequests + { + get + { + PumpLimits limits = _pumpLimits; + return (_currentOutstandingAccepts < limits.MaxOutstandingAccepts + && _currentOutstandingRequests < limits.MaxOutstandingRequests - _currentOutstandingAccepts); + } + } + + /// + /// These are merged as one operation because they should be swapped out atomically. + /// This controls how many requests the server attempts to process concurrently. + /// + /// The maximum number of pending accepts. + /// The maximum number of outstanding requests. + public void SetRequestProcessingLimits(int maxAccepts, int maxRequests) + { + _pumpLimits = new PumpLimits(maxAccepts, maxRequests); + + // Kick the pump in case we went from zero to non-zero limits. + OffloadListenForNextRequestAsync(); + } + + /// + /// Gets the request processing limits. + /// + /// The maximum number of pending accepts. + /// The maximum number of outstanding requests. + [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "0#", Justification = "By design")] + [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "1#", Justification = "By design")] + public void GetRequestProcessingLimits(out int maxAccepts, out int maxRequests) + { + PumpLimits limits = _pumpLimits; + maxAccepts = limits.MaxOutstandingAccepts; + maxRequests = limits.MaxOutstandingRequests; + } + + /// + /// Sets the maximum number of requests that will be queued up in Http.Sys. + /// + /// + public void SetRequestQueueLimit(long limit) + { + if (limit <= 0) + { + throw new ArgumentOutOfRangeException("limit", limit, string.Empty); + } + if ((!_requestQueueLength.HasValue && limit == DefaultRequestQueueLength) + || (_requestQueueLength.HasValue && limit == _requestQueueLength.Value)) + { + return; + } + + _requestQueueLength = limit; + + SetRequestQueueLimit(); + } + + private unsafe void SetRequestQueueLimit() + { + // The listener must be active for this to work. Call from Start after activating. + if (!IsListening || !_requestQueueLength.HasValue) + { + return; + } + + long length = _requestQueueLength.Value; + uint result = UnsafeNclNativeMethods.HttpApi.HttpSetRequestQueueProperty(_requestQueueHandle, + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerQueueLengthProperty, + new IntPtr((void*)&length), (uint)Marshal.SizeOf(length), 0, IntPtr.Zero); + + if (result != 0) + { + throw new WebListenerException((int)result); + } + } + + private void ValidateV2Property() + { + // Make sure that calling CheckDisposed and SetupV2Config is an atomic operation. This + // avoids race conditions if the listener is aborted/closed after CheckDisposed(), but + // before SetupV2Config(). + lock (_internalLock) + { + CheckDisposed(); + SetupV2Config(); + } + } + + internal void SetUrlGroupProperty(UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY property, IntPtr info, uint infosize) + { + ValidateV2Property(); + uint statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + + Debug.Assert(_urlGroupId != 0, "SetUrlGroupProperty called with invalid url group id"); + Debug.Assert(info != IntPtr.Zero, "SetUrlGroupProperty called with invalid pointer"); + + // Set the url group property using Http Api. + + statusCode = UnsafeNclNativeMethods.HttpApi.HttpSetUrlGroupProperty( + _urlGroupId, property, info, infosize); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + WebListenerException exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_logger, "SetUrlGroupProperty", exception); + throw exception; + } + } + + internal void RemoveAll(bool clear) + { + CheckDisposed(); + // go through the uri list and unregister for each one of them + if (_uriPrefixes.Count > 0) + { + LogHelper.LogInfo(_logger, "RemoveAll"); + if (_state == State.Started) + { + foreach (Prefix registeredPrefix in _uriPrefixes) + { + // ignore possible failures + InternalRemovePrefix(registeredPrefix.Whole); + } + } + + if (clear) + { + _uriPrefixes.Clear(); + } + } + } + + private IntPtr DangerousGetHandle() + { + return _requestQueueHandle.DangerousGetHandle(); + } + + private unsafe void SetupV2Config() + { + uint statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + ulong id = 0; + + // If we have already initialized V2 config, then nothing to do. + + if (_v2Initialized) + { + return; + } + + // V2 initialization sequence: + // 1. Create server session + // 2. Create url group + // 3. Create request queue - Done in Start() + // 4. Add urls to url group - Done in Start() + // 5. Attach request queue to url group - Done in Start() + + try + { + statusCode = UnsafeNclNativeMethods.HttpApi.HttpCreateServerSession( + UnsafeNclNativeMethods.HttpApi.Version, &id, 0); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + throw new WebListenerException((int)statusCode); + } + + Debug.Assert(id != 0, "Invalid id returned by HttpCreateServerSession"); + + _serverSessionHandle = new HttpServerSessionHandle(id); + + id = 0; + statusCode = UnsafeNclNativeMethods.HttpApi.HttpCreateUrlGroup( + _serverSessionHandle.DangerousGetServerSessionId(), &id, 0); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + throw new WebListenerException((int)statusCode); + } + + Debug.Assert(id != 0, "Invalid id returned by HttpCreateUrlGroup"); + _urlGroupId = id; + + _v2Initialized = true; + } + catch (Exception exception) + { + // If V2 initialization fails, we mark object as unusable. + _state = State.Disposed; + // If Url group or request queue creation failed, close server session before throwing. + if (_serverSessionHandle != null) + { + _serverSessionHandle.Dispose(); + } + LogHelper.LogException(_logger, "SetupV2Config", exception); + throw; + } + } + + internal void Start(AppFunc app, IList> addresses, IDictionary capabilities, LoggerFactoryFunc loggerFactory) + { + CheckDisposed(); + // Can't call Start twice + Contract.Assert(_appFunc == null); + + Contract.Assert(app != null); + Contract.Assert(addresses != null); + Contract.Assert(capabilities != null); + + _appFunc = app; + _capabilities = capabilities; + _logger = LogHelper.CreateLogger(loggerFactory, typeof(OwinWebListener)); + LogHelper.LogInfo(_logger, "Start"); + + foreach (var address in addresses) + { + // Build addresses from parts + var scheme = address.Get("scheme") ?? Constants.HttpScheme; + var host = address.Get("host") ?? "localhost"; + var port = address.Get("port") ?? "5000"; + var path = address.Get("path") ?? string.Empty; + + Prefix prefix = Prefix.Create(scheme, host, port, path); + _uriPrefixes.Add(prefix); + } + + // Make sure there are no race conditions between Start/Stop/Abort/Close/Dispose and + // calls to SetupV2Config: Start needs to setup all resources (esp. in V2 where besides + // the request handle, there is also a server session and a Url group. Abort/Stop must + // not interfere while Start is allocating those resources. The lock also makes sure + // all methods changing state can read and change the state in an atomic way. + lock (_internalLock) + { + try + { + CheckDisposed(); + if (_state == State.Started) + { + return; + } + + // SetupV2Config() is not called in the ctor, because it may throw. This would + // be a regression since in v1 the ctor never threw. Besides, ctors should do + // minimal work according to the framework design guidelines. + SetupV2Config(); + CreateRequestQueueHandle(); + AttachRequestQueueToUrlGroup(); + + // All resources are set up correctly. Now add all prefixes. + try + { + AddAllPrefixes(); + } + catch (WebListenerException) + { + // If an error occurred while adding prefixes, free all resources allocated by previous steps. + DetachRequestQueueFromUrlGroup(); + throw; + } + + _state = State.Started; + + SetRequestQueueLimit(); + + OffloadListenForNextRequestAsync(); + } + catch (Exception exception) + { + // Make sure the HttpListener instance can't be used if Start() failed. + _state = State.Disposed; + CloseRequestQueueHandle(); + CleanupV2Config(); + LogHelper.LogException(_logger, "Start", exception); + throw; + } + } + } + + // Make sure the next request is processed on another thread as to not recursively + // block the request we just received. + private void OffloadListenForNextRequestAsync() + { + if (IsListening && CanAcceptMoreRequests) + { + Task offloadTask = Task.Run(_offloadListenForNextRequest); + } + } + + // The message pump. + // When we start listening for the next request on one thread, we may need to be sure that the + // completion continues on another thread as to not block the current request processing. + // The awaits will manage stack depth for us. + private async void ListenForNextRequestAsync() + { + while (IsListening && CanAcceptMoreRequests) + { + // Receive a request + RequestContext requestContext; + Interlocked.Increment(ref _currentOutstandingAccepts); + try + { + requestContext = await GetContextAsync().SupressContext(); + Interlocked.Decrement(ref _currentOutstandingAccepts); + } + catch (Exception exception) + { + LogHelper.LogException(_logger, "ListenForNextRequestAsync", exception); + // Assume the server has stopped. + Interlocked.Decrement(ref _currentOutstandingAccepts); + Contract.Assert(!IsListening); + return; + } + + Interlocked.Increment(ref _currentOutstandingRequests); + OffloadListenForNextRequestAsync(); + await ProcessRequestAsync(requestContext).SupressContext(); + Interlocked.Decrement(ref _currentOutstandingRequests); + } + } + + private async Task ProcessRequestAsync(RequestContext requestContext) + { + try + { + try + { + // TODO: Make disconnect registration lazy + RegisterForDisconnectNotification(requestContext); + await _appFunc(requestContext.Environment).SupressContext(); + await requestContext.ProcessResponseAsync().SupressContext(); + } + catch (Exception ex) + { + LogHelper.LogException(_logger, "ProcessRequestAsync", ex); + if (requestContext.Response.SentHeaders) + { + requestContext.Abort(); + } + else + { + // We haven't sent a response yet, try to send a 500 Internal Server Error + requestContext.SetFatalResponse(); + } + } + requestContext.Dispose(); + } + catch (Exception ex) + { + LogHelper.LogException(_logger, "ProcessRequestAsync", ex); + requestContext.Abort(); + requestContext.Dispose(); + } + } + + private void CleanupV2Config() + { + // If we never setup V2, just return. + if (!_v2Initialized) + { + return; + } + + // V2 stopping sequence: + // 1. Detach request queue from url group - Done in Stop()/Abort() + // 2. Remove urls from url group - Done in Stop() + // 3. Close request queue - Done in Stop()/Abort() + // 4. Close Url group. + // 5. Close server session. + + Debug.Assert(_urlGroupId != 0, "HttpCloseUrlGroup called with invalid url group id"); + + uint statusCode = UnsafeNclNativeMethods.HttpApi.HttpCloseUrlGroup(_urlGroupId); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + LogHelper.LogError(_logger, "CleanupV2Config", "Result: " + statusCode); + } + _urlGroupId = 0; + + Debug.Assert(_serverSessionHandle != null, "ServerSessionHandle is null in CloseV2Config"); + Debug.Assert(!_serverSessionHandle.IsInvalid, "ServerSessionHandle is invalid in CloseV2Config"); + + _serverSessionHandle.Dispose(); + } + + private unsafe void AttachRequestQueueToUrlGroup() + { + // Set the association between request queue and url group. After this, requests for registered urls will + // get delivered to this request queue. + + UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO info = new UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO(); + info.Flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_PROPERTY_FLAG_PRESENT; + info.RequestQueueHandle = DangerousGetHandle(); + + IntPtr infoptr = new IntPtr(&info); + + SetUrlGroupProperty(UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerBindingProperty, + infoptr, (uint)BindingInfoSize); + } + + private unsafe void DetachRequestQueueFromUrlGroup() + { + Debug.Assert(_urlGroupId != 0, "DetachRequestQueueFromUrlGroup can't detach using Url group id 0."); + + // Break the association between request queue and url group. After this, requests for registered urls + // will get 503s. + // Note that this method may be called multiple times (Stop() and then Abort()). This + // is fine since http.sys allows to set HttpServerBindingProperty multiple times for valid + // Url groups. + + UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO info = new UnsafeNclNativeMethods.HttpApi.HTTP_BINDING_INFO(); + info.Flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + info.RequestQueueHandle = IntPtr.Zero; + + IntPtr infoptr = new IntPtr(&info); + + uint statusCode = UnsafeNclNativeMethods.HttpApi.HttpSetUrlGroupProperty(_urlGroupId, + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerBindingProperty, + infoptr, (uint)BindingInfoSize); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + LogHelper.LogError(_logger, "DetachRequestQueueFromUrlGroup", "Result: " + statusCode); + } + } + + internal void Stop() + { + try + { + lock (_internalLock) + { + CheckDisposed(); + if (_state == State.Stopped) + { + return; + } + LogHelper.LogInfo(_logger, "Stop"); + + RemoveAll(false); + + _state = State.Stopped; + + DetachRequestQueueFromUrlGroup(); + + // Even though it would be enough to just detach the request queue in v2, in order to + // keep app compat with earlier versions of the framework, we need to close the request queue. + // This will make sure that pending GetContext() calls will complete and throw an exception. Just + // detaching the url group from the request queue would not cause GetContext() to return. + CloseRequestQueueHandle(); + } + } + catch (Exception exception) + { + LogHelper.LogException(_logger, "Stop", exception); + throw; + } + } + + private unsafe void CreateRequestQueueHandle() + { + uint statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + + HttpRequestQueueV2Handle requestQueueHandle = null; + statusCode = + UnsafeNclNativeMethods.SafeNetHandles.HttpCreateRequestQueue( + UnsafeNclNativeMethods.HttpApi.Version, null, null, 0, out requestQueueHandle); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + throw new WebListenerException((int)statusCode); + } + + // Disabling callbacks when IO operation completes synchronously (returns ErrorCodes.ERROR_SUCCESS) + if (SkipIOCPCallbackOnSuccess && + !UnsafeNclNativeMethods.SetFileCompletionNotificationModes( + requestQueueHandle, + UnsafeNclNativeMethods.FileCompletionNotificationModes.SkipCompletionPortOnSuccess | + UnsafeNclNativeMethods.FileCompletionNotificationModes.SkipSetEventOnHandle)) + { + throw new WebListenerException(Marshal.GetLastWin32Error()); + } + + _requestQueueHandle = requestQueueHandle; + ThreadPool.BindHandle(_requestQueueHandle); + } + + private unsafe void CloseRequestQueueHandle() + { + if ((_requestQueueHandle != null) && (!_requestQueueHandle.IsInvalid)) + { + _requestQueueHandle.Dispose(); + } + } + + /// + /// Stop the server and clean up. + /// + public void Dispose() + { + Dispose(true); + } + + // old API, now private, and helper methods + private void Dispose(bool disposing) + { + if (!disposing) + { + return; + } + + lock (_internalLock) + { + try + { + if (_state == State.Disposed) + { + return; + } + LogHelper.LogInfo(_logger, "Dispose"); + + Stop(); + CleanupV2Config(); + } + catch (Exception exception) + { + LogHelper.LogException(_logger, "Dispose", exception); + throw; + } + finally + { + _state = State.Disposed; + } + } + } + + private uint InternalAddPrefix(string uriPrefix, int contextId) + { + uint statusCode = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpAddUrlToUrlGroup( + _urlGroupId, + uriPrefix, + (ulong)contextId, + 0); + + return statusCode; + } + + private bool InternalRemovePrefix(string uriPrefix) + { + uint statusCode = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpRemoveUrlFromUrlGroup( + _urlGroupId, + uriPrefix, + 0); + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_NOT_FOUND) + { + return false; + } + return true; + } + + private void AddAllPrefixes() + { + // go through the uri list and register for each one of them + if (_uriPrefixes.Count > 0) + { + for (int i = 0; i < _uriPrefixes.Count; i++) + { + // We'll get this index back on each request and use it to look up the prefix to calculate PathBase. + Prefix registeredPrefix = _uriPrefixes[i]; + uint statusCode = InternalAddPrefix(registeredPrefix.Whole, i); + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_ALREADY_EXISTS) + { + throw new WebListenerException((int)statusCode, String.Format(Resources.Exception_PrefixAlreadyRegistered, registeredPrefix.Whole)); + } + else + { + throw new WebListenerException((int)statusCode); + } + } + } + } + } + + internal unsafe bool ValidateRequest(NativeRequestContext requestMemory) + { + // Block potential DOS attacks + if (requestMemory.RequestBlob->Headers.UnknownHeaderCount > UnknownHeaderLimit) + { + SendError(requestMemory.RequestBlob->RequestId, HttpStatusCode.BadRequest); + return false; + } + return true; + } + + [SuppressMessage("Microsoft.Reliability", "CA2000:Dispose objects before losing scope", Justification = "Disposed by callback")] + internal Task GetContextAsync() + { + AsyncAcceptContext asyncResult = null; + try + { + CheckDisposed(); + Debug.Assert(_state != State.Stopped, "Listener has been stopped."); + // prepare the ListenerAsyncResult object (this will have it's own + // event that the user can wait on for IO completion - which means we + // need to signal it when IO completes) + asyncResult = new AsyncAcceptContext(this); + uint statusCode = asyncResult.QueueBeginGetContext(); + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + // someother bad error, possible(?) return values are: + // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED + asyncResult.Dispose(); + throw new WebListenerException((int)statusCode); + } + } + catch (Exception exception) + { + LogHelper.LogException(_logger, "GetContextAsync", exception); + throw; + } + + return asyncResult.Task; + } + + private void RegisterForDisconnectNotification(RequestContext requestContext) + { + try + { + // Create exactly one CancellationToken per connection. + ulong connectionId = requestContext.Request.ConnectionId; + CancellationToken ct = GetConnectionCancellation(connectionId); + requestContext.Request.RegisterForDisconnect(ct); + requestContext.Environment.ConnectionDisconnect = ct; + } + catch (Win32Exception exception) + { + LogHelper.LogException(_logger, "RegisterForDisconnectNotification", exception); + } + } + + private CancellationToken GetConnectionCancellation(ulong connectionId) + { + // Read case is performance senstive + ConnectionCancellation cancellation; + if (!_connectionCancellationTokens.TryGetValue(connectionId, out cancellation)) + { + cancellation = GetCreatedConnectionCancellation(connectionId); + } + return cancellation.GetCancellationToken(connectionId); + } + + private ConnectionCancellation GetCreatedConnectionCancellation(ulong connectionId) + { + // Race condition on creation has no side effects + ConnectionCancellation cancellation = new ConnectionCancellation(this); + return _connectionCancellationTokens.GetOrAdd(connectionId, cancellation); + } + + private unsafe CancellationToken CreateDisconnectToken(ulong connectionId) + { + // Debug.WriteLine("Server: Registering connection for disconnect for connection ID: " + connectionId); + + // Create a nativeOverlapped callback so we can register for disconnect callback + var overlapped = new Overlapped(); + var cts = new CancellationTokenSource(); + + SafeNativeOverlapped nativeOverlapped = null; + nativeOverlapped = new SafeNativeOverlapped(overlapped.UnsafePack( + (errorCode, numBytes, overlappedPtr) => + { + // Debug.WriteLine("Server: http.sys disconnect callback fired for connection ID: " + connectionId); + + // Free the overlapped + nativeOverlapped.Dispose(); + + // Pull the token out of the list and Cancel it. + ConnectionCancellation token; + _connectionCancellationTokens.TryRemove(connectionId, out token); + try + { + cts.Cancel(); + } + catch (AggregateException exception) + { + LogHelper.LogException(_logger, "CreateDisconnectToken::Disconnected", exception); + } + + cts.Dispose(); + }, + null)); + + uint statusCode; + try + { + statusCode = UnsafeNclNativeMethods.HttpApi.HttpWaitForDisconnect(_requestQueueHandle, connectionId, nativeOverlapped); + } + catch (Win32Exception exception) + { + statusCode = (uint)exception.NativeErrorCode; + LogHelper.LogException(_logger, "CreateDisconnectToken", exception); + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + // We got an unknown result so return a None + // TODO: return a canceled token? + return CancellationToken.None; + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + // TODO: return a canceled token? + return CancellationToken.None; + } + + return cts.Token; + } + + private unsafe void SendError(ulong requestId, HttpStatusCode httpStatusCode) + { + UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE httpResponse = new UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE(); + httpResponse.Version = new UnsafeNclNativeMethods.HttpApi.HTTP_VERSION(); + httpResponse.Version.MajorVersion = (ushort)1; + httpResponse.Version.MinorVersion = (ushort)1; + httpResponse.StatusCode = (ushort)httpStatusCode; + string statusDescription = HttpReasonPhrase.Get(httpStatusCode); + uint dataWritten = 0; + uint statusCode; + byte[] byteReason = HeaderEncoding.GetBytes(statusDescription); + fixed (byte* pReason = byteReason) + { + httpResponse.pReason = (sbyte*)pReason; + httpResponse.ReasonLength = (ushort)byteReason.Length; + + byte[] byteContentLength = new byte[] { (byte)'0' }; + fixed (byte* pContentLength = byteContentLength) + { + (&httpResponse.Headers.KnownHeaders)[(int)HttpSysResponseHeader.ContentLength].pRawValue = (sbyte*)pContentLength; + (&httpResponse.Headers.KnownHeaders)[(int)HttpSysResponseHeader.ContentLength].RawValueLength = (ushort)byteContentLength.Length; + httpResponse.Headers.UnknownHeaderCount = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendHttpResponse( + _requestQueueHandle, + requestId, + 0, + &httpResponse, + null, + &dataWritten, + SafeLocalFree.Zero, + 0, + SafeNativeOverlapped.Zero, + IntPtr.Zero); + } + } + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + // if we fail to send a 401 something's seriously wrong, abort the request + RequestContext.CancelRequest(_requestQueueHandle, requestId); + } + } + + private static int GetTokenOffsetFromBlob(IntPtr blob) + { + Debug.Assert(blob != IntPtr.Zero); +#if NET45 + IntPtr tokenPointer = Marshal.ReadIntPtr(blob, (int)Marshal.OffsetOf(ChannelBindingStatusType, "ChannelToken")); +#else + IntPtr tokenPointer = Marshal.ReadIntPtr(blob, (int)Marshal.OffsetOf("ChannelToken")); +#endif + Debug.Assert(tokenPointer != IntPtr.Zero); + return (int)IntPtrHelper.Subtract(tokenPointer, blob); + } + + private static int GetTokenSizeFromBlob(IntPtr blob) + { + Debug.Assert(blob != IntPtr.Zero); +#if NET45 + return Marshal.ReadInt32(blob, (int)Marshal.OffsetOf(ChannelBindingStatusType, "ChannelTokenSize")); +#else + return Marshal.ReadInt32(blob, (int)Marshal.OffsetOf("ChannelTokenSize")); +#endif + } + + internal ChannelBinding GetChannelBinding(ulong connectionId, bool isSecureConnection) + { + if (!isSecureConnection) + { + LogHelper.LogInfo(_logger, "Channel binding is not supported for HTTP."); + return null; + } + + ChannelBinding result = GetChannelBindingFromTls(connectionId); + + Debug.Assert(result != null, "GetChannelBindingFromTls returned null even though OS supposedly supports Extended Protection"); + LogHelper.LogInfo(_logger, "Channel binding retrieved."); + return result; + } + + private unsafe ChannelBinding GetChannelBindingFromTls(ulong connectionId) + { + // +128 since a CBT is usually <128 thus we need to call HRCC just once. If the CBT + // is >128 we will get ERROR_MORE_DATA and call again + int size = RequestChannelBindStatusSize + 128; + + Debug.Assert(size >= 0); + + byte[] blob = null; + SafeLocalFreeChannelBinding token = null; + + uint bytesReceived = 0; + uint statusCode; + + do + { + blob = new byte[size]; + fixed (byte* blobPtr = blob) + { + // Http.sys team: ServiceName will always be null if + // HTTP_RECEIVE_SECURE_CHANNEL_TOKEN flag is set. + statusCode = UnsafeNclNativeMethods.HttpApi.HttpReceiveClientCertificate( + RequestQueueHandle, + connectionId, + (uint)UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_RECEIVE_SECURE_CHANNEL_TOKEN, + blobPtr, + (uint)size, + &bytesReceived, + SafeNativeOverlapped.Zero); + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + int tokenOffset = GetTokenOffsetFromBlob((IntPtr)blobPtr); + int tokenSize = GetTokenSizeFromBlob((IntPtr)blobPtr); + Debug.Assert(tokenSize < Int32.MaxValue); + + token = SafeLocalFreeChannelBinding.LocalAlloc(tokenSize); + + Marshal.Copy(blob, tokenOffset, token.DangerousGetHandle(), tokenSize); + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + int tokenSize = GetTokenSizeFromBlob((IntPtr)blobPtr); + Debug.Assert(tokenSize < Int32.MaxValue); + + size = RequestChannelBindStatusSize + tokenSize; + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_INVALID_PARAMETER) + { + LogHelper.LogError(_logger, "GetChannelBindingFromTls", "Channel binding is not supported."); + return null; // old schannel library which doesn't support CBT + } + else + { + // It's up to the consumer to fail if the missing ChannelBinding matters to them. + LogHelper.LogException(_logger, "GetChannelBindingFromTls", new WebListenerException((int)statusCode)); + break; + } + } + } + while (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS); + + return token; + } + + internal void CheckDisposed() + { + if (_state == State.Disposed) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + } + + private class ConnectionCancellation + { + private readonly OwinWebListener _parent; + private volatile bool _initialized; // Must be volatile because initialization is synchronized + private CancellationToken _cancellationToken; + + public ConnectionCancellation(OwinWebListener parent) + { + _parent = parent; + } + + internal CancellationToken GetCancellationToken(ulong connectionId) + { + // Initialized case is performance sensitive + if (_initialized) + { + return _cancellationToken; + } + return InitializeCancellationToken(connectionId); + } + + private CancellationToken InitializeCancellationToken(ulong connectionId) + { + object syncObject = this; +#pragma warning disable 420 // Disable warning about volatile by reference since EnsureInitialized does volatile operations + return LazyInitializer.EnsureInitialized(ref _cancellationToken, ref _initialized, ref syncObject, () => _parent.CreateDisconnectToken(connectionId)); +#pragma warning restore 420 + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Prefix.cs b/src/Microsoft.AspNet.Server.WebListener/Prefix.cs new file mode 100644 index 0000000000..08a0fe61c5 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Prefix.cs @@ -0,0 +1,102 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Globalization; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class Prefix + { + private Prefix(bool isHttps, string scheme, string host, string port, int portValue, string path) + { + IsHttps = isHttps; + Scheme = scheme; + Host = host; + Port = port; + PortValue = portValue; + Path = path; + Whole = string.Format(CultureInfo.InvariantCulture, "{0}://{1}:{2}{3}", Scheme, Host, Port, Path); + } + + /// + /// http://msdn.microsoft.com/en-us/library/windows/desktop/aa364698(v=vs.85).aspx + /// + /// http or https. Will be normalized to lower case. + /// +, *, IPv4, [IPv6], or a dns name. Http.Sys does not permit punycode (xn--), use Unicode instead. + /// If empty, the default port for the given scheme will be used (80 or 443). + /// Should start and end with a '/', though a missing trailing slash will be added. This value must be un-escaped. + public static Prefix Create(string scheme, string host, string port, string path) + { + bool isHttps; + if (string.Equals(Constants.HttpScheme, scheme, StringComparison.OrdinalIgnoreCase)) + { + scheme = Constants.HttpScheme; // Always use a lower case scheme + isHttps = false; + } + else if (string.Equals(Constants.HttpsScheme, scheme, StringComparison.OrdinalIgnoreCase)) + { + scheme = Constants.HttpsScheme; // Always use a lower case scheme + isHttps = true; + } + else + { + throw new ArgumentOutOfRangeException("scheme", scheme, Resources.Exception_UnsupportedScheme); + } + + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentNullException("host"); + } + + int portValue; + if (string.IsNullOrWhiteSpace(port)) + { + port = isHttps ? "443" : "80"; + portValue = isHttps ? 443 : 80; + } + else + { + portValue = int.Parse(port, NumberStyles.None, CultureInfo.InvariantCulture); + } + + // Http.Sys requires the path end with a slash. + if (string.IsNullOrWhiteSpace(path)) + { + path = "/"; + } + else if (!path.EndsWith("/", StringComparison.Ordinal)) + { + path += "/"; + } + + return new Prefix(isHttps, scheme, host, port, portValue, path); + } + + public bool IsHttps { get; private set; } + public string Scheme { get; private set; } + public string Host { get; private set; } + public string Port { get; private set; } + public int PortValue { get; private set; } + public string Path { get; private set; } + public string Whole { get; private set; } + + public override bool Equals(object obj) + { + return string.Equals(Whole, obj as string, StringComparison.OrdinalIgnoreCase); + } + + public override int GetHashCode() + { + return StringComparer.OrdinalIgnoreCase.GetHashCode(Whole); + } + + public override string ToString() + { + return Whole; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Properties/AssemblyInfo.cs b/src/Microsoft.AspNet.Server.WebListener/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..ed361b4f1c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Properties/AssemblyInfo.cs @@ -0,0 +1,44 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.Server.WebListener")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.Server.WebListener")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("1f471909-581f-4060-a147-430891e9c3c1")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] +[assembly: CLSCompliant(true)] diff --git a/src/Microsoft.AspNet.Server.WebListener/PumpLimits.cs b/src/Microsoft.AspNet.Server.WebListener/PumpLimits.cs new file mode 100644 index 0000000000..28b99c11cc --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/PumpLimits.cs @@ -0,0 +1,31 @@ +// +// Copyright 2011-2012 Katana contributors +// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class PumpLimits + { + internal PumpLimits(int maxAccepts, int maxRequests) + { + MaxOutstandingAccepts = maxAccepts; + MaxOutstandingRequests = maxRequests; + } + + internal int MaxOutstandingAccepts { get; private set; } + + internal int MaxOutstandingRequests { get; private set; } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/BoundaryType.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/BoundaryType.cs new file mode 100644 index 0000000000..6d9ca0b4e2 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/BoundaryType.cs @@ -0,0 +1,18 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum BoundaryType + { + ContentLength = 0, // Content-Length: XXX + Chunked = 1, // Transfer-Encoding: chunked + Raw = 2, // the app is responsible for sending the correct headers and body encoding + Multipart = 3, + None = 4, + Invalid = 5, + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.cs new file mode 100644 index 0000000000..d0fe8d05ea --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.cs @@ -0,0 +1,1913 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Katana Contributors. All rights reserved. +// +//----------------------------------------------------------------------- +// + +using System; +using System.CodeDom.Compiler; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Security.Authentication.ExtendedProtection; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + using OpaqueUpgrade = Action, Func, Task>>; + + [GeneratedCode("TextTemplatingFileGenerator", "")] + internal partial class CallEnvironment + { + // Mark all fields with delay initialization support as set. + private UInt32 _flag0 = 0x43e00200u; + private UInt32 _flag1 = 0x1u; + // Mark all fields with delay initialization support as requiring initialization. + private UInt32 _initFlag0 = 0x43e00200u; + private UInt32 _initFlag1 = 0x1u; + + internal interface IPropertySource + { + Stream GetRequestBody(); + string GetRemoteIpAddress(); + string GetRemotePort(); + string GetLocalIpAddress(); + string GetLocalPort(); + bool GetIsLocal(); + bool TryGetChannelBinding(ref ChannelBinding value); + bool TryGetOpaqueUpgrade(ref OpaqueUpgrade value); + } + + private string _OwinVersion; + private CancellationToken _CallCancelled; + private string _RequestProtocol; + private string _RequestMethod; + private string _RequestScheme; + private string _RequestPathBase; + private string _RequestPath; + private string _RequestQueryString; + private IDictionary _RequestHeaders; + private Stream _RequestBody; + private IDictionary _ResponseHeaders; + private Stream _ResponseBody; + private int? _ResponseStatusCode; + private string _ResponseReasonPhrase; + private TextWriter _HostTraceOutput; + private string _HostAppName; + private string _HostAppMode; + private CancellationToken _OnAppDisposing; + private System.Security.Principal.IPrincipal _User; + private Action, object> _OnSendingHeaders; + private IDictionary _ServerCapabilities; + private string _RemoteIpAddress; + private string _RemotePort; + private string _LocalIpAddress; + private string _LocalPort; + private bool _IsLocal; + private object _ConnectionId; + private CancellationToken _ConnectionDisconnect; + private object _ClientCert; + private Func _LoadClientCert; + private ChannelBinding _ChannelBinding; + private Func _SendFileAsync; + private OpaqueUpgrade _OpaqueUpgrade; + private OwinWebListener _Listener; + + bool InitPropertyChannelBinding() + { + if (!_propertySource.TryGetChannelBinding(ref _ChannelBinding)) + { + _flag0 &= ~0x40000000u; + _initFlag0 &= ~0x40000000u; + return false; + } + _initFlag0 &= ~0x40000000u; + return true; + } + + bool InitPropertyOpaqueUpgrade() + { + if (!_propertySource.TryGetOpaqueUpgrade(ref _OpaqueUpgrade)) + { + _flag1 &= ~0x1u; + _initFlag1 &= ~0x1u; + return false; + } + _initFlag1 &= ~0x1u; + return true; + } + + internal string OwinVersion + { + get + { + return _OwinVersion; + } + set + { + _flag0 |= 0x1u; + _OwinVersion = value; + } + } + + internal CancellationToken CallCancelled + { + get + { + return _CallCancelled; + } + set + { + _flag0 |= 0x2u; + _CallCancelled = value; + } + } + + internal string RequestProtocol + { + get + { + return _RequestProtocol; + } + set + { + _flag0 |= 0x4u; + _RequestProtocol = value; + } + } + + internal string RequestMethod + { + get + { + return _RequestMethod; + } + set + { + _flag0 |= 0x8u; + _RequestMethod = value; + } + } + + internal string RequestScheme + { + get + { + return _RequestScheme; + } + set + { + _flag0 |= 0x10u; + _RequestScheme = value; + } + } + + internal string RequestPathBase + { + get + { + return _RequestPathBase; + } + set + { + _flag0 |= 0x20u; + _RequestPathBase = value; + } + } + + internal string RequestPath + { + get + { + return _RequestPath; + } + set + { + _flag0 |= 0x40u; + _RequestPath = value; + } + } + + internal string RequestQueryString + { + get + { + return _RequestQueryString; + } + set + { + _flag0 |= 0x80u; + _RequestQueryString = value; + } + } + + internal IDictionary RequestHeaders + { + get + { + return _RequestHeaders; + } + set + { + _flag0 |= 0x100u; + _RequestHeaders = value; + } + } + + internal Stream RequestBody + { + get + { + if (((_initFlag0 & 0x200u) != 0)) + { + _RequestBody = _propertySource.GetRequestBody(); + _initFlag0 &= ~0x200u; + } + return _RequestBody; + } + set + { + _initFlag0 &= ~0x200u; + _flag0 |= 0x200u; + _RequestBody = value; + } + } + + internal IDictionary ResponseHeaders + { + get + { + return _ResponseHeaders; + } + set + { + _flag0 |= 0x400u; + _ResponseHeaders = value; + } + } + + internal Stream ResponseBody + { + get + { + return _ResponseBody; + } + set + { + _flag0 |= 0x800u; + _ResponseBody = value; + } + } + + internal int? ResponseStatusCode + { + get + { + return _ResponseStatusCode; + } + set + { + _flag0 |= 0x1000u; + _ResponseStatusCode = value; + } + } + + internal string ResponseReasonPhrase + { + get + { + return _ResponseReasonPhrase; + } + set + { + _flag0 |= 0x2000u; + _ResponseReasonPhrase = value; + } + } + + internal TextWriter HostTraceOutput + { + get + { + return _HostTraceOutput; + } + set + { + _flag0 |= 0x4000u; + _HostTraceOutput = value; + } + } + + internal string HostAppName + { + get + { + return _HostAppName; + } + set + { + _flag0 |= 0x8000u; + _HostAppName = value; + } + } + + internal string HostAppMode + { + get + { + return _HostAppMode; + } + set + { + _flag0 |= 0x10000u; + _HostAppMode = value; + } + } + + internal CancellationToken OnAppDisposing + { + get + { + return _OnAppDisposing; + } + set + { + _flag0 |= 0x20000u; + _OnAppDisposing = value; + } + } + + internal System.Security.Principal.IPrincipal User + { + get + { + return _User; + } + set + { + _flag0 |= 0x40000u; + _User = value; + } + } + + internal Action, object> OnSendingHeaders + { + get + { + return _OnSendingHeaders; + } + set + { + _flag0 |= 0x80000u; + _OnSendingHeaders = value; + } + } + + internal IDictionary ServerCapabilities + { + get + { + return _ServerCapabilities; + } + set + { + _flag0 |= 0x100000u; + _ServerCapabilities = value; + } + } + + internal string RemoteIpAddress + { + get + { + if (((_initFlag0 & 0x200000u) != 0)) + { + _RemoteIpAddress = _propertySource.GetRemoteIpAddress(); + _initFlag0 &= ~0x200000u; + } + return _RemoteIpAddress; + } + set + { + _initFlag0 &= ~0x200000u; + _flag0 |= 0x200000u; + _RemoteIpAddress = value; + } + } + + internal string RemotePort + { + get + { + if (((_initFlag0 & 0x400000u) != 0)) + { + _RemotePort = _propertySource.GetRemotePort(); + _initFlag0 &= ~0x400000u; + } + return _RemotePort; + } + set + { + _initFlag0 &= ~0x400000u; + _flag0 |= 0x400000u; + _RemotePort = value; + } + } + + internal string LocalIpAddress + { + get + { + if (((_initFlag0 & 0x800000u) != 0)) + { + _LocalIpAddress = _propertySource.GetLocalIpAddress(); + _initFlag0 &= ~0x800000u; + } + return _LocalIpAddress; + } + set + { + _initFlag0 &= ~0x800000u; + _flag0 |= 0x800000u; + _LocalIpAddress = value; + } + } + + internal string LocalPort + { + get + { + if (((_initFlag0 & 0x1000000u) != 0)) + { + _LocalPort = _propertySource.GetLocalPort(); + _initFlag0 &= ~0x1000000u; + } + return _LocalPort; + } + set + { + _initFlag0 &= ~0x1000000u; + _flag0 |= 0x1000000u; + _LocalPort = value; + } + } + + internal bool IsLocal + { + get + { + if (((_initFlag0 & 0x2000000u) != 0)) + { + _IsLocal = _propertySource.GetIsLocal(); + _initFlag0 &= ~0x2000000u; + } + return _IsLocal; + } + set + { + _initFlag0 &= ~0x2000000u; + _flag0 |= 0x2000000u; + _IsLocal = value; + } + } + + internal object ConnectionId + { + get + { + return _ConnectionId; + } + set + { + _flag0 |= 0x4000000u; + _ConnectionId = value; + } + } + + internal CancellationToken ConnectionDisconnect + { + get + { + return _ConnectionDisconnect; + } + set + { + _flag0 |= 0x8000000u; + _ConnectionDisconnect = value; + } + } + + internal object ClientCert + { + get + { + return _ClientCert; + } + set + { + _flag0 |= 0x10000000u; + _ClientCert = value; + } + } + + internal Func LoadClientCert + { + get + { + return _LoadClientCert; + } + set + { + _flag0 |= 0x20000000u; + _LoadClientCert = value; + } + } + + internal ChannelBinding ChannelBinding + { + get + { + if (((_initFlag0 & 0x40000000u) != 0)) + { + InitPropertyChannelBinding(); + } + return _ChannelBinding; + } + set + { + _initFlag0 &= ~0x40000000u; + _flag0 |= 0x40000000u; + _ChannelBinding = value; + } + } + + internal Func SendFileAsync + { + get + { + return _SendFileAsync; + } + set + { + _flag0 |= 0x80000000u; + _SendFileAsync = value; + } + } + + internal OpaqueUpgrade OpaqueUpgrade + { + get + { + if (((_initFlag1 & 0x1u) != 0)) + { + InitPropertyOpaqueUpgrade(); + } + return _OpaqueUpgrade; + } + set + { + _initFlag1 &= ~0x1u; + _flag1 |= 0x1u; + _OpaqueUpgrade = value; + } + } + + internal OwinWebListener Listener + { + get + { + return _Listener; + } + set + { + _flag1 |= 0x2u; + _Listener = value; + } + } + + private bool PropertiesContainsKey(string key) + { + switch (key.Length) + { + case 11: + if (((_flag0 & 0x40000u) != 0) && string.Equals(key, "server.User", StringComparison.Ordinal)) + { + return true; + } + break; + case 12: + if (((_flag0 & 0x1u) != 0) && string.Equals(key, "owin.Version", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x8000u) != 0) && string.Equals(key, "host.AppName", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x10000u) != 0) && string.Equals(key, "host.AppMode", StringComparison.Ordinal)) + { + return true; + } + break; + case 14: + if (((_flag0 & 0x2000000u) != 0) && string.Equals(key, "server.IsLocal", StringComparison.Ordinal)) + { + return true; + } + if (((_flag1 & 0x1u) != 0) && string.Equals(key, "opaque.Upgrade", StringComparison.Ordinal)) + { + if (((_initFlag1 & 0x1u) == 0) || InitPropertyOpaqueUpgrade()) + { + return true; + } + } + break; + case 16: + if (((_flag0 & 0x40u) != 0) && string.Equals(key, "owin.RequestPath", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x200u) != 0) && string.Equals(key, "owin.RequestBody", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x4000u) != 0) && string.Equals(key, "host.TraceOutput", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x1000000u) != 0) && string.Equals(key, "server.LocalPort", StringComparison.Ordinal)) + { + return true; + } + break; + case 17: + if (((_flag0 & 0x800u) != 0) && string.Equals(key, "owin.ResponseBody", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x400000u) != 0) && string.Equals(key, "server.RemotePort", StringComparison.Ordinal)) + { + return true; + } + break; + case 18: + if (((_flag0 & 0x2u) != 0) && string.Equals(key, "owin.CallCancelled", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x8u) != 0) && string.Equals(key, "owin.RequestMethod", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x10u) != 0) && string.Equals(key, "owin.RequestScheme", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x40000000u) != 0) && string.Equals(key, "ssl.ChannelBinding", StringComparison.Ordinal)) + { + if (((_initFlag0 & 0x40000000u) == 0) || InitPropertyChannelBinding()) + { + return true; + } + } + if (((_flag0 & 0x80000000u) != 0) && string.Equals(key, "sendfile.SendAsync", StringComparison.Ordinal)) + { + return true; + } + break; + case 19: + if (((_flag0 & 0x100u) != 0) && string.Equals(key, "owin.RequestHeaders", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x20000u) != 0) && string.Equals(key, "host.OnAppDisposing", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x100000u) != 0) && string.Equals(key, "server.Capabilities", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x4000000u) != 0) && string.Equals(key, "server.ConnectionId", StringComparison.Ordinal)) + { + return true; + } + break; + case 20: + if (((_flag0 & 0x4u) != 0) && string.Equals(key, "owin.RequestProtocol", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x20u) != 0) && string.Equals(key, "owin.RequestPathBase", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x400u) != 0) && string.Equals(key, "owin.ResponseHeaders", StringComparison.Ordinal)) + { + return true; + } + break; + case 21: + if (((_flag0 & 0x800000u) != 0) && string.Equals(key, "server.LocalIpAddress", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x10000000u) != 0) && string.Equals(key, "ssl.ClientCertificate", StringComparison.Ordinal)) + { + return true; + } + break; + case 22: + if (((_flag0 & 0x200000u) != 0) && string.Equals(key, "server.RemoteIpAddress", StringComparison.Ordinal)) + { + return true; + } + break; + case 23: + if (((_flag0 & 0x80u) != 0) && string.Equals(key, "owin.RequestQueryString", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x1000u) != 0) && string.Equals(key, "owin.ResponseStatusCode", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x80000u) != 0) && string.Equals(key, "server.OnSendingHeaders", StringComparison.Ordinal)) + { + return true; + } + if (((_flag0 & 0x20000000u) != 0) && string.Equals(key, "ssl.LoadClientCertAsync", StringComparison.Ordinal)) + { + return true; + } + break; + case 25: + if (((_flag0 & 0x2000u) != 0) && string.Equals(key, "owin.ResponseReasonPhrase", StringComparison.Ordinal)) + { + return true; + } + break; + case 27: + if (((_flag0 & 0x8000000u) != 0) && string.Equals(key, "server.ConnectionDisconnect", StringComparison.Ordinal)) + { + return true; + } + break; + case 51: + if (((_flag1 & 0x2u) != 0) && string.Equals(key, "Microsoft.AspNet.Server.WebListener.OwinWebListener", StringComparison.Ordinal)) + { + return true; + } + break; + } + return false; + } + + private bool PropertiesTryGetValue(string key, out object value) + { + switch (key.Length) + { + case 11: + if (((_flag0 & 0x40000u) != 0) && string.Equals(key, "server.User", StringComparison.Ordinal)) + { + value = User; + return true; + } + break; + case 12: + if (((_flag0 & 0x1u) != 0) && string.Equals(key, "owin.Version", StringComparison.Ordinal)) + { + value = OwinVersion; + return true; + } + if (((_flag0 & 0x8000u) != 0) && string.Equals(key, "host.AppName", StringComparison.Ordinal)) + { + value = HostAppName; + return true; + } + if (((_flag0 & 0x10000u) != 0) && string.Equals(key, "host.AppMode", StringComparison.Ordinal)) + { + value = HostAppMode; + return true; + } + break; + case 14: + if (((_flag0 & 0x2000000u) != 0) && string.Equals(key, "server.IsLocal", StringComparison.Ordinal)) + { + value = IsLocal; + return true; + } + if (((_flag1 & 0x1u) != 0) && string.Equals(key, "opaque.Upgrade", StringComparison.Ordinal)) + { + value = OpaqueUpgrade; + // Delayed initialization in the property getter may determine that the element is not actually present + if (!((_flag1 & 0x1u) != 0)) + { + value = default(OpaqueUpgrade); + return false; + } + return true; + } + break; + case 16: + if (((_flag0 & 0x40u) != 0) && string.Equals(key, "owin.RequestPath", StringComparison.Ordinal)) + { + value = RequestPath; + return true; + } + if (((_flag0 & 0x200u) != 0) && string.Equals(key, "owin.RequestBody", StringComparison.Ordinal)) + { + value = RequestBody; + return true; + } + if (((_flag0 & 0x4000u) != 0) && string.Equals(key, "host.TraceOutput", StringComparison.Ordinal)) + { + value = HostTraceOutput; + return true; + } + if (((_flag0 & 0x1000000u) != 0) && string.Equals(key, "server.LocalPort", StringComparison.Ordinal)) + { + value = LocalPort; + return true; + } + break; + case 17: + if (((_flag0 & 0x800u) != 0) && string.Equals(key, "owin.ResponseBody", StringComparison.Ordinal)) + { + value = ResponseBody; + return true; + } + if (((_flag0 & 0x400000u) != 0) && string.Equals(key, "server.RemotePort", StringComparison.Ordinal)) + { + value = RemotePort; + return true; + } + break; + case 18: + if (((_flag0 & 0x2u) != 0) && string.Equals(key, "owin.CallCancelled", StringComparison.Ordinal)) + { + value = CallCancelled; + return true; + } + if (((_flag0 & 0x8u) != 0) && string.Equals(key, "owin.RequestMethod", StringComparison.Ordinal)) + { + value = RequestMethod; + return true; + } + if (((_flag0 & 0x10u) != 0) && string.Equals(key, "owin.RequestScheme", StringComparison.Ordinal)) + { + value = RequestScheme; + return true; + } + if (((_flag0 & 0x40000000u) != 0) && string.Equals(key, "ssl.ChannelBinding", StringComparison.Ordinal)) + { + value = ChannelBinding; + // Delayed initialization in the property getter may determine that the element is not actually present + if (!((_flag0 & 0x40000000u) != 0)) + { + value = default(ChannelBinding); + return false; + } + return true; + } + if (((_flag0 & 0x80000000u) != 0) && string.Equals(key, "sendfile.SendAsync", StringComparison.Ordinal)) + { + value = SendFileAsync; + return true; + } + break; + case 19: + if (((_flag0 & 0x100u) != 0) && string.Equals(key, "owin.RequestHeaders", StringComparison.Ordinal)) + { + value = RequestHeaders; + return true; + } + if (((_flag0 & 0x20000u) != 0) && string.Equals(key, "host.OnAppDisposing", StringComparison.Ordinal)) + { + value = OnAppDisposing; + return true; + } + if (((_flag0 & 0x100000u) != 0) && string.Equals(key, "server.Capabilities", StringComparison.Ordinal)) + { + value = ServerCapabilities; + return true; + } + if (((_flag0 & 0x4000000u) != 0) && string.Equals(key, "server.ConnectionId", StringComparison.Ordinal)) + { + value = ConnectionId; + return true; + } + break; + case 20: + if (((_flag0 & 0x4u) != 0) && string.Equals(key, "owin.RequestProtocol", StringComparison.Ordinal)) + { + value = RequestProtocol; + return true; + } + if (((_flag0 & 0x20u) != 0) && string.Equals(key, "owin.RequestPathBase", StringComparison.Ordinal)) + { + value = RequestPathBase; + return true; + } + if (((_flag0 & 0x400u) != 0) && string.Equals(key, "owin.ResponseHeaders", StringComparison.Ordinal)) + { + value = ResponseHeaders; + return true; + } + break; + case 21: + if (((_flag0 & 0x800000u) != 0) && string.Equals(key, "server.LocalIpAddress", StringComparison.Ordinal)) + { + value = LocalIpAddress; + return true; + } + if (((_flag0 & 0x10000000u) != 0) && string.Equals(key, "ssl.ClientCertificate", StringComparison.Ordinal)) + { + value = ClientCert; + return true; + } + break; + case 22: + if (((_flag0 & 0x200000u) != 0) && string.Equals(key, "server.RemoteIpAddress", StringComparison.Ordinal)) + { + value = RemoteIpAddress; + return true; + } + break; + case 23: + if (((_flag0 & 0x80u) != 0) && string.Equals(key, "owin.RequestQueryString", StringComparison.Ordinal)) + { + value = RequestQueryString; + return true; + } + if (((_flag0 & 0x1000u) != 0) && string.Equals(key, "owin.ResponseStatusCode", StringComparison.Ordinal)) + { + value = ResponseStatusCode; + return true; + } + if (((_flag0 & 0x80000u) != 0) && string.Equals(key, "server.OnSendingHeaders", StringComparison.Ordinal)) + { + value = OnSendingHeaders; + return true; + } + if (((_flag0 & 0x20000000u) != 0) && string.Equals(key, "ssl.LoadClientCertAsync", StringComparison.Ordinal)) + { + value = LoadClientCert; + return true; + } + break; + case 25: + if (((_flag0 & 0x2000u) != 0) && string.Equals(key, "owin.ResponseReasonPhrase", StringComparison.Ordinal)) + { + value = ResponseReasonPhrase; + return true; + } + break; + case 27: + if (((_flag0 & 0x8000000u) != 0) && string.Equals(key, "server.ConnectionDisconnect", StringComparison.Ordinal)) + { + value = ConnectionDisconnect; + return true; + } + break; + case 51: + if (((_flag1 & 0x2u) != 0) && string.Equals(key, "Microsoft.AspNet.Server.WebListener.OwinWebListener", StringComparison.Ordinal)) + { + value = Listener; + return true; + } + break; + } + value = null; + return false; + } + + private bool PropertiesTrySetValue(string key, object value) + { + switch (key.Length) + { + case 11: + if (string.Equals(key, "server.User", StringComparison.Ordinal)) + { + User = (System.Security.Principal.IPrincipal)value; + return true; + } + break; + case 12: + if (string.Equals(key, "owin.Version", StringComparison.Ordinal)) + { + OwinVersion = (string)value; + return true; + } + if (string.Equals(key, "host.AppName", StringComparison.Ordinal)) + { + HostAppName = (string)value; + return true; + } + if (string.Equals(key, "host.AppMode", StringComparison.Ordinal)) + { + HostAppMode = (string)value; + return true; + } + break; + case 14: + if (string.Equals(key, "server.IsLocal", StringComparison.Ordinal)) + { + IsLocal = (bool)value; + return true; + } + if (string.Equals(key, "opaque.Upgrade", StringComparison.Ordinal)) + { + OpaqueUpgrade = (OpaqueUpgrade)value; + return true; + } + break; + case 16: + if (string.Equals(key, "owin.RequestPath", StringComparison.Ordinal)) + { + RequestPath = (string)value; + return true; + } + if (string.Equals(key, "owin.RequestBody", StringComparison.Ordinal)) + { + RequestBody = (Stream)value; + return true; + } + if (string.Equals(key, "host.TraceOutput", StringComparison.Ordinal)) + { + HostTraceOutput = (TextWriter)value; + return true; + } + if (string.Equals(key, "server.LocalPort", StringComparison.Ordinal)) + { + LocalPort = (string)value; + return true; + } + break; + case 17: + if (string.Equals(key, "owin.ResponseBody", StringComparison.Ordinal)) + { + ResponseBody = (Stream)value; + return true; + } + if (string.Equals(key, "server.RemotePort", StringComparison.Ordinal)) + { + RemotePort = (string)value; + return true; + } + break; + case 18: + if (string.Equals(key, "owin.CallCancelled", StringComparison.Ordinal)) + { + CallCancelled = (CancellationToken)value; + return true; + } + if (string.Equals(key, "owin.RequestMethod", StringComparison.Ordinal)) + { + RequestMethod = (string)value; + return true; + } + if (string.Equals(key, "owin.RequestScheme", StringComparison.Ordinal)) + { + RequestScheme = (string)value; + return true; + } + if (string.Equals(key, "ssl.ChannelBinding", StringComparison.Ordinal)) + { + ChannelBinding = (ChannelBinding)value; + return true; + } + if (string.Equals(key, "sendfile.SendAsync", StringComparison.Ordinal)) + { + SendFileAsync = (Func)value; + return true; + } + break; + case 19: + if (string.Equals(key, "owin.RequestHeaders", StringComparison.Ordinal)) + { + RequestHeaders = (IDictionary)value; + return true; + } + if (string.Equals(key, "host.OnAppDisposing", StringComparison.Ordinal)) + { + OnAppDisposing = (CancellationToken)value; + return true; + } + if (string.Equals(key, "server.Capabilities", StringComparison.Ordinal)) + { + ServerCapabilities = (IDictionary)value; + return true; + } + if (string.Equals(key, "server.ConnectionId", StringComparison.Ordinal)) + { + ConnectionId = (object)value; + return true; + } + break; + case 20: + if (string.Equals(key, "owin.RequestProtocol", StringComparison.Ordinal)) + { + RequestProtocol = (string)value; + return true; + } + if (string.Equals(key, "owin.RequestPathBase", StringComparison.Ordinal)) + { + RequestPathBase = (string)value; + return true; + } + if (string.Equals(key, "owin.ResponseHeaders", StringComparison.Ordinal)) + { + ResponseHeaders = (IDictionary)value; + return true; + } + break; + case 21: + if (string.Equals(key, "server.LocalIpAddress", StringComparison.Ordinal)) + { + LocalIpAddress = (string)value; + return true; + } + if (string.Equals(key, "ssl.ClientCertificate", StringComparison.Ordinal)) + { + ClientCert = (object)value; + return true; + } + break; + case 22: + if (string.Equals(key, "server.RemoteIpAddress", StringComparison.Ordinal)) + { + RemoteIpAddress = (string)value; + return true; + } + break; + case 23: + if (string.Equals(key, "owin.RequestQueryString", StringComparison.Ordinal)) + { + RequestQueryString = (string)value; + return true; + } + if (string.Equals(key, "owin.ResponseStatusCode", StringComparison.Ordinal)) + { + ResponseStatusCode = (int?)value; + return true; + } + if (string.Equals(key, "server.OnSendingHeaders", StringComparison.Ordinal)) + { + OnSendingHeaders = (Action, object>)value; + return true; + } + if (string.Equals(key, "ssl.LoadClientCertAsync", StringComparison.Ordinal)) + { + LoadClientCert = (Func)value; + return true; + } + break; + case 25: + if (string.Equals(key, "owin.ResponseReasonPhrase", StringComparison.Ordinal)) + { + ResponseReasonPhrase = (string)value; + return true; + } + break; + case 27: + if (string.Equals(key, "server.ConnectionDisconnect", StringComparison.Ordinal)) + { + ConnectionDisconnect = (CancellationToken)value; + return true; + } + break; + case 51: + if (string.Equals(key, "Microsoft.AspNet.Server.WebListener.OwinWebListener", StringComparison.Ordinal)) + { + Listener = (OwinWebListener)value; + return true; + } + break; + } + return false; + } + + private bool PropertiesTryRemove(string key) + { + switch (key.Length) + { + case 11: + if (((_flag0 & 0x40000u) != 0) && string.Equals(key, "server.User", StringComparison.Ordinal)) + { + _flag0 &= ~0x40000u; + _User = default(System.Security.Principal.IPrincipal); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 12: + if (((_flag0 & 0x1u) != 0) && string.Equals(key, "owin.Version", StringComparison.Ordinal)) + { + _flag0 &= ~0x1u; + _OwinVersion = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x8000u) != 0) && string.Equals(key, "host.AppName", StringComparison.Ordinal)) + { + _flag0 &= ~0x8000u; + _HostAppName = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x10000u) != 0) && string.Equals(key, "host.AppMode", StringComparison.Ordinal)) + { + _flag0 &= ~0x10000u; + _HostAppMode = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 14: + if (((_flag0 & 0x2000000u) != 0) && string.Equals(key, "server.IsLocal", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x2000000u; + _flag0 &= ~0x2000000u; + _IsLocal = default(bool); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag1 & 0x1u) != 0) && string.Equals(key, "opaque.Upgrade", StringComparison.Ordinal)) + { + _initFlag1 &= ~0x1u; + _flag1 &= ~0x1u; + _OpaqueUpgrade = default(OpaqueUpgrade); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 16: + if (((_flag0 & 0x40u) != 0) && string.Equals(key, "owin.RequestPath", StringComparison.Ordinal)) + { + _flag0 &= ~0x40u; + _RequestPath = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x200u) != 0) && string.Equals(key, "owin.RequestBody", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x200u; + _flag0 &= ~0x200u; + _RequestBody = default(Stream); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x4000u) != 0) && string.Equals(key, "host.TraceOutput", StringComparison.Ordinal)) + { + _flag0 &= ~0x4000u; + _HostTraceOutput = default(TextWriter); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x1000000u) != 0) && string.Equals(key, "server.LocalPort", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x1000000u; + _flag0 &= ~0x1000000u; + _LocalPort = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 17: + if (((_flag0 & 0x800u) != 0) && string.Equals(key, "owin.ResponseBody", StringComparison.Ordinal)) + { + _flag0 &= ~0x800u; + _ResponseBody = default(Stream); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x400000u) != 0) && string.Equals(key, "server.RemotePort", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x400000u; + _flag0 &= ~0x400000u; + _RemotePort = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 18: + if (((_flag0 & 0x2u) != 0) && string.Equals(key, "owin.CallCancelled", StringComparison.Ordinal)) + { + _flag0 &= ~0x2u; + _CallCancelled = default(CancellationToken); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x8u) != 0) && string.Equals(key, "owin.RequestMethod", StringComparison.Ordinal)) + { + _flag0 &= ~0x8u; + _RequestMethod = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x10u) != 0) && string.Equals(key, "owin.RequestScheme", StringComparison.Ordinal)) + { + _flag0 &= ~0x10u; + _RequestScheme = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x40000000u) != 0) && string.Equals(key, "ssl.ChannelBinding", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x40000000u; + _flag0 &= ~0x40000000u; + _ChannelBinding = default(ChannelBinding); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x80000000u) != 0) && string.Equals(key, "sendfile.SendAsync", StringComparison.Ordinal)) + { + _flag0 &= ~0x80000000u; + _SendFileAsync = default(Func); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 19: + if (((_flag0 & 0x100u) != 0) && string.Equals(key, "owin.RequestHeaders", StringComparison.Ordinal)) + { + _flag0 &= ~0x100u; + _RequestHeaders = default(IDictionary); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x20000u) != 0) && string.Equals(key, "host.OnAppDisposing", StringComparison.Ordinal)) + { + _flag0 &= ~0x20000u; + _OnAppDisposing = default(CancellationToken); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x100000u) != 0) && string.Equals(key, "server.Capabilities", StringComparison.Ordinal)) + { + _flag0 &= ~0x100000u; + _ServerCapabilities = default(IDictionary); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x4000000u) != 0) && string.Equals(key, "server.ConnectionId", StringComparison.Ordinal)) + { + _flag0 &= ~0x4000000u; + _ConnectionId = default(object); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 20: + if (((_flag0 & 0x4u) != 0) && string.Equals(key, "owin.RequestProtocol", StringComparison.Ordinal)) + { + _flag0 &= ~0x4u; + _RequestProtocol = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x20u) != 0) && string.Equals(key, "owin.RequestPathBase", StringComparison.Ordinal)) + { + _flag0 &= ~0x20u; + _RequestPathBase = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x400u) != 0) && string.Equals(key, "owin.ResponseHeaders", StringComparison.Ordinal)) + { + _flag0 &= ~0x400u; + _ResponseHeaders = default(IDictionary); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 21: + if (((_flag0 & 0x800000u) != 0) && string.Equals(key, "server.LocalIpAddress", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x800000u; + _flag0 &= ~0x800000u; + _LocalIpAddress = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x10000000u) != 0) && string.Equals(key, "ssl.ClientCertificate", StringComparison.Ordinal)) + { + _flag0 &= ~0x10000000u; + _ClientCert = default(object); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 22: + if (((_flag0 & 0x200000u) != 0) && string.Equals(key, "server.RemoteIpAddress", StringComparison.Ordinal)) + { + _initFlag0 &= ~0x200000u; + _flag0 &= ~0x200000u; + _RemoteIpAddress = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 23: + if (((_flag0 & 0x80u) != 0) && string.Equals(key, "owin.RequestQueryString", StringComparison.Ordinal)) + { + _flag0 &= ~0x80u; + _RequestQueryString = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x1000u) != 0) && string.Equals(key, "owin.ResponseStatusCode", StringComparison.Ordinal)) + { + _flag0 &= ~0x1000u; + _ResponseStatusCode = default(int?); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x80000u) != 0) && string.Equals(key, "server.OnSendingHeaders", StringComparison.Ordinal)) + { + _flag0 &= ~0x80000u; + _OnSendingHeaders = default(Action, object>); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + if (((_flag0 & 0x20000000u) != 0) && string.Equals(key, "ssl.LoadClientCertAsync", StringComparison.Ordinal)) + { + _flag0 &= ~0x20000000u; + _LoadClientCert = default(Func); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 25: + if (((_flag0 & 0x2000u) != 0) && string.Equals(key, "owin.ResponseReasonPhrase", StringComparison.Ordinal)) + { + _flag0 &= ~0x2000u; + _ResponseReasonPhrase = default(string); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 27: + if (((_flag0 & 0x8000000u) != 0) && string.Equals(key, "server.ConnectionDisconnect", StringComparison.Ordinal)) + { + _flag0 &= ~0x8000000u; + _ConnectionDisconnect = default(CancellationToken); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + case 51: + if (((_flag1 & 0x2u) != 0) && string.Equals(key, "Microsoft.AspNet.Server.WebListener.OwinWebListener", StringComparison.Ordinal)) + { + _flag1 &= ~0x2u; + _Listener = default(OwinWebListener); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } + break; + } + return false; + } + + private IEnumerable PropertiesKeys() + { + if (((_flag0 & 0x1u) != 0)) + { + yield return "owin.Version"; + } + if (((_flag0 & 0x2u) != 0)) + { + yield return "owin.CallCancelled"; + } + if (((_flag0 & 0x4u) != 0)) + { + yield return "owin.RequestProtocol"; + } + if (((_flag0 & 0x8u) != 0)) + { + yield return "owin.RequestMethod"; + } + if (((_flag0 & 0x10u) != 0)) + { + yield return "owin.RequestScheme"; + } + if (((_flag0 & 0x20u) != 0)) + { + yield return "owin.RequestPathBase"; + } + if (((_flag0 & 0x40u) != 0)) + { + yield return "owin.RequestPath"; + } + if (((_flag0 & 0x80u) != 0)) + { + yield return "owin.RequestQueryString"; + } + if (((_flag0 & 0x100u) != 0)) + { + yield return "owin.RequestHeaders"; + } + if (((_flag0 & 0x200u) != 0)) + { + yield return "owin.RequestBody"; + } + if (((_flag0 & 0x400u) != 0)) + { + yield return "owin.ResponseHeaders"; + } + if (((_flag0 & 0x800u) != 0)) + { + yield return "owin.ResponseBody"; + } + if (((_flag0 & 0x1000u) != 0)) + { + yield return "owin.ResponseStatusCode"; + } + if (((_flag0 & 0x2000u) != 0)) + { + yield return "owin.ResponseReasonPhrase"; + } + if (((_flag0 & 0x4000u) != 0)) + { + yield return "host.TraceOutput"; + } + if (((_flag0 & 0x8000u) != 0)) + { + yield return "host.AppName"; + } + if (((_flag0 & 0x10000u) != 0)) + { + yield return "host.AppMode"; + } + if (((_flag0 & 0x20000u) != 0)) + { + yield return "host.OnAppDisposing"; + } + if (((_flag0 & 0x40000u) != 0)) + { + yield return "server.User"; + } + if (((_flag0 & 0x80000u) != 0)) + { + yield return "server.OnSendingHeaders"; + } + if (((_flag0 & 0x100000u) != 0)) + { + yield return "server.Capabilities"; + } + if (((_flag0 & 0x200000u) != 0)) + { + yield return "server.RemoteIpAddress"; + } + if (((_flag0 & 0x400000u) != 0)) + { + yield return "server.RemotePort"; + } + if (((_flag0 & 0x800000u) != 0)) + { + yield return "server.LocalIpAddress"; + } + if (((_flag0 & 0x1000000u) != 0)) + { + yield return "server.LocalPort"; + } + if (((_flag0 & 0x2000000u) != 0)) + { + yield return "server.IsLocal"; + } + if (((_flag0 & 0x4000000u) != 0)) + { + yield return "server.ConnectionId"; + } + if (((_flag0 & 0x8000000u) != 0)) + { + yield return "server.ConnectionDisconnect"; + } + if (((_flag0 & 0x10000000u) != 0)) + { + yield return "ssl.ClientCertificate"; + } + if (((_flag0 & 0x20000000u) != 0)) + { + yield return "ssl.LoadClientCertAsync"; + } + if (((_flag0 & 0x40000000u) != 0)) + { + if (((_initFlag0 & 0x40000000u) == 0) || InitPropertyChannelBinding()) + { + yield return "ssl.ChannelBinding"; + } + } + if (((_flag0 & 0x80000000u) != 0)) + { + yield return "sendfile.SendAsync"; + } + if (((_flag1 & 0x1u) != 0)) + { + if (((_initFlag1 & 0x1u) == 0) || InitPropertyOpaqueUpgrade()) + { + yield return "opaque.Upgrade"; + } + } + if (((_flag1 & 0x2u) != 0)) + { + yield return "Microsoft.AspNet.Server.WebListener.OwinWebListener"; + } + } + + private IEnumerable PropertiesValues() + { + if (((_flag0 & 0x1u) != 0)) + { + yield return OwinVersion; + } + if (((_flag0 & 0x2u) != 0)) + { + yield return CallCancelled; + } + if (((_flag0 & 0x4u) != 0)) + { + yield return RequestProtocol; + } + if (((_flag0 & 0x8u) != 0)) + { + yield return RequestMethod; + } + if (((_flag0 & 0x10u) != 0)) + { + yield return RequestScheme; + } + if (((_flag0 & 0x20u) != 0)) + { + yield return RequestPathBase; + } + if (((_flag0 & 0x40u) != 0)) + { + yield return RequestPath; + } + if (((_flag0 & 0x80u) != 0)) + { + yield return RequestQueryString; + } + if (((_flag0 & 0x100u) != 0)) + { + yield return RequestHeaders; + } + if (((_flag0 & 0x200u) != 0)) + { + yield return RequestBody; + } + if (((_flag0 & 0x400u) != 0)) + { + yield return ResponseHeaders; + } + if (((_flag0 & 0x800u) != 0)) + { + yield return ResponseBody; + } + if (((_flag0 & 0x1000u) != 0)) + { + yield return ResponseStatusCode; + } + if (((_flag0 & 0x2000u) != 0)) + { + yield return ResponseReasonPhrase; + } + if (((_flag0 & 0x4000u) != 0)) + { + yield return HostTraceOutput; + } + if (((_flag0 & 0x8000u) != 0)) + { + yield return HostAppName; + } + if (((_flag0 & 0x10000u) != 0)) + { + yield return HostAppMode; + } + if (((_flag0 & 0x20000u) != 0)) + { + yield return OnAppDisposing; + } + if (((_flag0 & 0x40000u) != 0)) + { + yield return User; + } + if (((_flag0 & 0x80000u) != 0)) + { + yield return OnSendingHeaders; + } + if (((_flag0 & 0x100000u) != 0)) + { + yield return ServerCapabilities; + } + if (((_flag0 & 0x200000u) != 0)) + { + yield return RemoteIpAddress; + } + if (((_flag0 & 0x400000u) != 0)) + { + yield return RemotePort; + } + if (((_flag0 & 0x800000u) != 0)) + { + yield return LocalIpAddress; + } + if (((_flag0 & 0x1000000u) != 0)) + { + yield return LocalPort; + } + if (((_flag0 & 0x2000000u) != 0)) + { + yield return IsLocal; + } + if (((_flag0 & 0x4000000u) != 0)) + { + yield return ConnectionId; + } + if (((_flag0 & 0x8000000u) != 0)) + { + yield return ConnectionDisconnect; + } + if (((_flag0 & 0x10000000u) != 0)) + { + yield return ClientCert; + } + if (((_flag0 & 0x20000000u) != 0)) + { + yield return LoadClientCert; + } + if (((_flag0 & 0x40000000u) != 0)) + { + if (((_initFlag0 & 0x40000000u) == 0) || InitPropertyChannelBinding()) + { + yield return ChannelBinding; + } + } + if (((_flag0 & 0x80000000u) != 0)) + { + yield return SendFileAsync; + } + if (((_flag1 & 0x1u) != 0)) + { + if (((_initFlag1 & 0x1u) == 0) || InitPropertyOpaqueUpgrade()) + { + yield return OpaqueUpgrade; + } + } + if (((_flag1 & 0x2u) != 0)) + { + yield return Listener; + } + } + + private IEnumerable> PropertiesEnumerable() + { + if (((_flag0 & 0x1u) != 0)) + { + yield return new KeyValuePair("owin.Version", OwinVersion); + } + if (((_flag0 & 0x2u) != 0)) + { + yield return new KeyValuePair("owin.CallCancelled", CallCancelled); + } + if (((_flag0 & 0x4u) != 0)) + { + yield return new KeyValuePair("owin.RequestProtocol", RequestProtocol); + } + if (((_flag0 & 0x8u) != 0)) + { + yield return new KeyValuePair("owin.RequestMethod", RequestMethod); + } + if (((_flag0 & 0x10u) != 0)) + { + yield return new KeyValuePair("owin.RequestScheme", RequestScheme); + } + if (((_flag0 & 0x20u) != 0)) + { + yield return new KeyValuePair("owin.RequestPathBase", RequestPathBase); + } + if (((_flag0 & 0x40u) != 0)) + { + yield return new KeyValuePair("owin.RequestPath", RequestPath); + } + if (((_flag0 & 0x80u) != 0)) + { + yield return new KeyValuePair("owin.RequestQueryString", RequestQueryString); + } + if (((_flag0 & 0x100u) != 0)) + { + yield return new KeyValuePair("owin.RequestHeaders", RequestHeaders); + } + if (((_flag0 & 0x200u) != 0)) + { + yield return new KeyValuePair("owin.RequestBody", RequestBody); + } + if (((_flag0 & 0x400u) != 0)) + { + yield return new KeyValuePair("owin.ResponseHeaders", ResponseHeaders); + } + if (((_flag0 & 0x800u) != 0)) + { + yield return new KeyValuePair("owin.ResponseBody", ResponseBody); + } + if (((_flag0 & 0x1000u) != 0)) + { + yield return new KeyValuePair("owin.ResponseStatusCode", ResponseStatusCode); + } + if (((_flag0 & 0x2000u) != 0)) + { + yield return new KeyValuePair("owin.ResponseReasonPhrase", ResponseReasonPhrase); + } + if (((_flag0 & 0x4000u) != 0)) + { + yield return new KeyValuePair("host.TraceOutput", HostTraceOutput); + } + if (((_flag0 & 0x8000u) != 0)) + { + yield return new KeyValuePair("host.AppName", HostAppName); + } + if (((_flag0 & 0x10000u) != 0)) + { + yield return new KeyValuePair("host.AppMode", HostAppMode); + } + if (((_flag0 & 0x20000u) != 0)) + { + yield return new KeyValuePair("host.OnAppDisposing", OnAppDisposing); + } + if (((_flag0 & 0x40000u) != 0)) + { + yield return new KeyValuePair("server.User", User); + } + if (((_flag0 & 0x80000u) != 0)) + { + yield return new KeyValuePair("server.OnSendingHeaders", OnSendingHeaders); + } + if (((_flag0 & 0x100000u) != 0)) + { + yield return new KeyValuePair("server.Capabilities", ServerCapabilities); + } + if (((_flag0 & 0x200000u) != 0)) + { + yield return new KeyValuePair("server.RemoteIpAddress", RemoteIpAddress); + } + if (((_flag0 & 0x400000u) != 0)) + { + yield return new KeyValuePair("server.RemotePort", RemotePort); + } + if (((_flag0 & 0x800000u) != 0)) + { + yield return new KeyValuePair("server.LocalIpAddress", LocalIpAddress); + } + if (((_flag0 & 0x1000000u) != 0)) + { + yield return new KeyValuePair("server.LocalPort", LocalPort); + } + if (((_flag0 & 0x2000000u) != 0)) + { + yield return new KeyValuePair("server.IsLocal", IsLocal); + } + if (((_flag0 & 0x4000000u) != 0)) + { + yield return new KeyValuePair("server.ConnectionId", ConnectionId); + } + if (((_flag0 & 0x8000000u) != 0)) + { + yield return new KeyValuePair("server.ConnectionDisconnect", ConnectionDisconnect); + } + if (((_flag0 & 0x10000000u) != 0)) + { + yield return new KeyValuePair("ssl.ClientCertificate", ClientCert); + } + if (((_flag0 & 0x20000000u) != 0)) + { + yield return new KeyValuePair("ssl.LoadClientCertAsync", LoadClientCert); + } + if (((_flag0 & 0x40000000u) != 0)) + { + if (((_initFlag0 & 0x40000000u) == 0) || InitPropertyChannelBinding()) + { + yield return new KeyValuePair("ssl.ChannelBinding", ChannelBinding); + } + } + if (((_flag0 & 0x80000000u) != 0)) + { + yield return new KeyValuePair("sendfile.SendAsync", SendFileAsync); + } + if (((_flag1 & 0x1u) != 0)) + { + if (((_initFlag1 & 0x1u) == 0) || InitPropertyOpaqueUpgrade()) + { + yield return new KeyValuePair("opaque.Upgrade", OpaqueUpgrade); + } + } + if (((_flag1 & 0x2u) != 0)) + { + yield return new KeyValuePair("Microsoft.AspNet.Server.WebListener.OwinWebListener", Listener); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.tt b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.tt new file mode 100644 index 0000000000..41a1b3db06 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.Generated.tt @@ -0,0 +1,316 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core.dll" #> +<#@ import namespace="System.Linq" #> +<# +var Init = new {Yes = new object(), No = new object(), Maybe = new object()}; + +var props = new[] +{ +// owin standard keys + new {Key="owin.Version", Type="string", Name="OwinVersion", Init=Init.No}, + new {Key="owin.CallCancelled", Type="CancellationToken", Name="CallCancelled", Init=Init.No}, + + new {Key="owin.RequestProtocol", Type="string", Name="RequestProtocol", Init=Init.No}, + new {Key="owin.RequestMethod", Type="string", Name="RequestMethod", Init=Init.No}, + new {Key="owin.RequestScheme", Type="string", Name="RequestScheme", Init=Init.No}, + new {Key="owin.RequestPathBase", Type="string", Name="RequestPathBase", Init=Init.No}, + new {Key="owin.RequestPath", Type="string", Name="RequestPath", Init=Init.No}, + new {Key="owin.RequestQueryString", Type="string", Name="RequestQueryString", Init=Init.No}, + new {Key="owin.RequestHeaders", Type="IDictionary", Name="RequestHeaders", Init=Init.No}, + new {Key="owin.RequestBody", Type="Stream", Name="RequestBody", Init=Init.Yes}, + + new {Key="owin.ResponseHeaders", Type="IDictionary", Name="ResponseHeaders", Init=Init.No}, + new {Key="owin.ResponseBody", Type="Stream", Name="ResponseBody", Init=Init.No}, + new {Key="owin.ResponseStatusCode", Type="int?", Name="ResponseStatusCode", Init=Init.No}, + new {Key="owin.ResponseReasonPhrase", Type="string", Name="ResponseReasonPhrase", Init=Init.No}, + +// defacto host keys + new {Key="host.TraceOutput", Type="TextWriter", Name="HostTraceOutput", Init=Init.No}, + new {Key="host.AppName", Type="string", Name="HostAppName", Init=Init.No}, + new {Key="host.AppMode", Type="string", Name="HostAppMode", Init=Init.No}, + new {Key="host.OnAppDisposing", Type="CancellationToken", Name="OnAppDisposing", Init=Init.No}, + new {Key="server.User", Type="System.Security.Principal.IPrincipal", Name="User", Init=Init.No}, + new {Key="server.OnSendingHeaders", Type="Action, object>", Name="OnSendingHeaders", Init=Init.No}, + new {Key="server.Capabilities", Type="IDictionary", Name="ServerCapabilities", Init=Init.No}, + +// ServerVariable keys + new {Key="server.RemoteIpAddress", Type="string", Name="RemoteIpAddress", Init=Init.Yes}, + new {Key="server.RemotePort", Type="string", Name="RemotePort", Init=Init.Yes}, + new {Key="server.LocalIpAddress", Type="string", Name="LocalIpAddress", Init=Init.Yes}, + new {Key="server.LocalPort", Type="string", Name="LocalPort", Init=Init.Yes}, + new {Key="server.IsLocal", Type="bool", Name="IsLocal", Init=Init.Yes}, + new {Key="server.ConnectionId", Type="object", Name="ConnectionId", Init=Init.No}, + new {Key="server.ConnectionDisconnect", Type="CancellationToken", Name="ConnectionDisconnect", Init=Init.No}, + +// SSL + new { Key="ssl.ClientCertificate", Type="object", Name="ClientCert", Init=Init.No}, + new { Key="ssl.LoadClientCertAsync", Type="Func", Name="LoadClientCert", Init=Init.No }, + new { Key="ssl.ChannelBinding", Type="ChannelBinding", Name="ChannelBinding", Init=Init.Maybe }, + +// SendFile keys + new {Key="sendfile.SendAsync", Type="Func", Name="SendFileAsync", Init=Init.No}, + +// Opaque keys + new {Key="opaque.Upgrade", Type="OpaqueUpgrade", Name="OpaqueUpgrade", Init=Init.Maybe}, + +// Server specific keys + new { Key="Microsoft.AspNet.Server.WebListener.OwinWebListener", Type="OwinWebListener", Name="Listener", Init=Init.No}, +}.Select((prop, Index)=>new {prop.Key, prop.Type, prop.Name, prop.Init, Index}); + +var lengths = props.OrderBy(prop=>prop.Key.Length).GroupBy(prop=>prop.Key.Length); + +Func IsSet = Index => "((_flag" + (Index / 32) + " & 0x" + (1<<(Index % 32)).ToString("x") + "u) != 0)"; +Func Set = Index => "_flag" + (Index / 32) + " |= 0x" + (1<<(Index % 32)).ToString("x") + "u"; +Func Clear = Index => "_flag" + (Index / 32) + " &= ~0x" + (1<<(Index % 32)).ToString("x") + "u"; + +Func IsInitRequired = Index => "((_initFlag" + (Index / 32) + " & 0x" + (1<<(Index % 32)).ToString("x") + "u) != 0)"; +Func IsInitCompleted = Index => "((_initFlag" + (Index / 32) + " & 0x" + (1<<(Index % 32)).ToString("x") + "u) == 0)"; +Func CompleteInit = Index => "_initFlag" + (Index / 32) + " &= ~0x" + (1<<(Index % 32)).ToString("x") + "u"; + +#> +//----------------------------------------------------------------------- +// +// Copyright (c) Katana Contributors. All rights reserved. +// +//----------------------------------------------------------------------- +// + +using System; +using System.CodeDom.Compiler; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Security.Authentication.ExtendedProtection; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Owin.Host.WebListener +{ + using OpaqueUpgrade = Action, Func, Task>>; + + [GeneratedCode("TextTemplatingFileGenerator", "")] + internal partial class CallEnvironment + { + // Mark all fields with delay initialization support as set. + private UInt32 _flag0 = 0x<#=props.Aggregate(0, (agg,p) => agg | (((p.Init != Init.No) && (p.Index/32==0) ? 1:0)<u; + private UInt32 _flag1 = 0x<#=props.Aggregate(0, (agg,p) => agg | (((p.Init != Init.No) && (p.Index/32==1) ? 1:0)<u; + // Mark all fields with delay initialization support as requiring initialization. + private UInt32 _initFlag0 = 0x<#=props.Aggregate(0, (agg,p) => agg | (((p.Init != Init.No) && (p.Index/32==0) ? 1:0)<u; + private UInt32 _initFlag1 = 0x<#=props.Aggregate(0, (agg,p) => agg | (((p.Init != Init.No) && (p.Index/32==1) ? 1:0)<u; + + internal interface IPropertySource + { +<# foreach(var prop in props) { #> +<# if (prop.Init == Init.Yes) { #> + <#=prop.Type#> Get<#=prop.Name#>(); +<# } #> +<# if (prop.Init == Init.Maybe) { #> + bool TryGet<#=prop.Name#>(ref <#=prop.Type#> value); +<# } #> +<# } #> + } + +<# foreach(var prop in props) { #> + private <#=prop.Type#> _<#=prop.Name#>; +<# } #> + +<# foreach(var prop in props) { #> +<# // call TryGet once if init flag is set, clear value flag if TryGet returns false +if (prop.Init == Init.Maybe) { #> + bool InitProperty<#=prop.Name#>() + { + if (!_propertySource.TryGet<#=prop.Name#>(ref _<#=prop.Name#>)) + { + <#=Clear(prop.Index)#>; + <#=CompleteInit(prop.Index)#>; + return false; + } + <#=CompleteInit(prop.Index)#>; + return true; + } + +<# } #> +<# } #> +<# foreach(var prop in props) { #> + internal <#=prop.Type#> <#=prop.Name#> + { + get + { +<# // call Get once if init flag is set +if (prop.Init == Init.Yes) { #> + if (<#=IsInitRequired(prop.Index)#>) + { + _<#=prop.Name#> = _propertySource.Get<#=prop.Name#>(); + <#=CompleteInit(prop.Index)#>; + } +<# } #> +<# // call TryGet once if init flag is set, clear value flag if TryGet returns false +if (prop.Init == Init.Maybe) { #> + if (<#=IsInitRequired(prop.Index)#>) + { + InitProperty<#=prop.Name#>(); + } +<# } #> + return _<#=prop.Name#>; + } + set + { +<# // clear init flag - the assigned value is definitive +if (prop.Init != Init.No) { #> + <#=CompleteInit(prop.Index)#>; +<# } #> + <#=Set(prop.Index)#>; + _<#=prop.Name#> = value; + } + } + +<# } #> + private bool PropertiesContainsKey(string key) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (<#=IsSet(prop.Index)#> && string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { +<# // variable maybe init might revert +if (prop.Init == Init.Maybe) { #> + if (<#=IsInitCompleted(prop.Index)#> || InitProperty<#=prop.Name#>()) + { + return true; + } +<# } else { #> + return true; +<# } #> + } +<# } #> + break; +<# } #> + } + return false; + } + + private bool PropertiesTryGetValue(string key, out object value) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (<#=IsSet(prop.Index)#> && string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { + value = <#=prop.Name#>; +<# if (prop.Init == Init.Maybe) { #> + // Delayed initialization in the property getter may determine that the element is not actually present + if (!<#=IsSet(prop.Index)#>) + { + value = default(<#=prop.Type#>); + return false; + } +<# } #> + return true; + } +<# } #> + break; +<# } #> + } + value = null; + return false; + } + + private bool PropertiesTrySetValue(string key, object value) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { + <#=prop.Name#> = (<#=prop.Type#>)value; + return true; + } +<# } #> + break; +<# } #> + } + return false; + } + + private bool PropertiesTryRemove(string key) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (<#=IsSet(prop.Index)#> && string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { +<# if (prop.Init != Init.No) { #> + <#=CompleteInit(prop.Index)#>; +<# } #> + <#=Clear(prop.Index)#>; + _<#=prop.Name#> = default(<#=prop.Type#>); + // This can return true incorrectly for values that delayed initialization may determine are not actually present. + return true; + } +<# } #> + break; +<# } #> + } + return false; + } + + private IEnumerable PropertiesKeys() + { +<# foreach(var prop in props) { #> + if (<#=IsSet(prop.Index)#>) + { +<# if (prop.Init == Init.Maybe) { #> + if (<#=IsInitCompleted(prop.Index)#> || InitProperty<#=prop.Name#>()) + { + yield return "<#=prop.Key#>"; + } +<# } else { #> + yield return "<#=prop.Key#>"; +<# } #> + } +<# } #> + } + + private IEnumerable PropertiesValues() + { +<# foreach(var prop in props) { #> + if (<#=IsSet(prop.Index)#>) + { +<# if (prop.Init == Init.Maybe) { #> + if (<#=IsInitCompleted(prop.Index)#> || InitProperty<#=prop.Name#>()) + { + yield return <#=prop.Name#>; + } +<# } else { #> + yield return <#=prop.Name#>; +<# } #> + } +<# } #> + } + + private IEnumerable> PropertiesEnumerable() + { +<# foreach(var prop in props) { #> + if (<#=IsSet(prop.Index)#>) + { +<# if (prop.Init == Init.Maybe) { #> + if (<#=IsInitCompleted(prop.Index)#> || InitProperty<#=prop.Name#>()) + { + yield return new KeyValuePair("<#=prop.Key#>", <#=prop.Name#>); + } +<# } else { #> + yield return new KeyValuePair("<#=prop.Key#>", <#=prop.Name#>); +<# } #> + } +<# } #> + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.cs new file mode 100644 index 0000000000..360a2ea245 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/CallEnvironment.cs @@ -0,0 +1,165 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- +// Copyright 2011-2012 Katana contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal partial class CallEnvironment : IDictionary + { + private static readonly IDictionary WeakNilEnvironment = new NilEnvDictionary(); + + private readonly IPropertySource _propertySource; + private IDictionary _extra = WeakNilEnvironment; + + internal CallEnvironment(IPropertySource propertySource) + { + _propertySource = propertySource; + } + + private IDictionary Extra + { + get { return _extra; } + } + + private IDictionary StrongExtra + { + get + { + if (_extra == WeakNilEnvironment) + { + Interlocked.CompareExchange(ref _extra, new Dictionary(), WeakNilEnvironment); + } + return _extra; + } + } + + internal bool IsExtraDictionaryCreated + { + get { return _extra != WeakNilEnvironment; } + } + + public object this[string key] + { + get + { + object value; + return PropertiesTryGetValue(key, out value) ? value : Extra[key]; + } + set + { + if (!PropertiesTrySetValue(key, value)) + { + StrongExtra[key] = value; + } + } + } + + public void Add(string key, object value) + { + if (!PropertiesTrySetValue(key, value)) + { + StrongExtra.Add(key, value); + } + } + + public bool ContainsKey(string key) + { + return PropertiesContainsKey(key) || Extra.ContainsKey(key); + } + + public ICollection Keys + { + get { return PropertiesKeys().Concat(Extra.Keys).ToArray(); } + } + + public bool Remove(string key) + { + // Although this is a mutating operation, Extra is used instead of StrongExtra, + // because if a real dictionary has not been allocated the default behavior of the + // nil dictionary is perfectly fine. + return PropertiesTryRemove(key) || Extra.Remove(key); + } + + public bool TryGetValue(string key, out object value) + { + return PropertiesTryGetValue(key, out value) || Extra.TryGetValue(key, out value); + } + + public ICollection Values + { + get { return PropertiesValues().Concat(Extra.Values).ToArray(); } + } + + public void Add(KeyValuePair item) + { + ((IDictionary)this).Add(item.Key, item.Value); + } + + public void Clear() + { + foreach (var key in PropertiesKeys()) + { + PropertiesTryRemove(key); + } + Extra.Clear(); + } + + public bool Contains(KeyValuePair item) + { + object value; + return ((IDictionary)this).TryGetValue(item.Key, out value) && Object.Equals(value, item.Value); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + PropertiesEnumerable().Concat(Extra).ToArray().CopyTo(array, arrayIndex); + } + + public int Count + { + get { return PropertiesKeys().Count() + Extra.Count; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public bool Remove(KeyValuePair item) + { + return ((IDictionary)this).Contains(item) && + ((IDictionary)this).Remove(item.Key); + } + + public IEnumerator> GetEnumerator() + { + return PropertiesEnumerable().Concat(Extra).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IDictionary)this).GetEnumerator(); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ClientCertLoader.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ClientCertLoader.cs new file mode 100644 index 0000000000..3cd17db788 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ClientCertLoader.cs @@ -0,0 +1,347 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +#if NET45 + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Runtime.InteropServices; +using System.Security; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + // This class is used to load the client certificate on-demand. Because client certs are optional, all + // failures are handled internally and reported via ClientCertException or ClientCertError. + internal unsafe sealed class ClientCertLoader : IAsyncResult, IDisposable + { + private const uint CertBoblSize = 1500; + private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(WaitCallback); + + private SafeNativeOverlapped _overlapped; + private byte[] _backingBuffer; + private UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* _memoryBlob; + private uint _size; + private TaskCompletionSource _tcs; + private RequestContext _requestContext; + + private int _clientCertError; + private X509Certificate2 _clientCert; + private Exception _clientCertException; + + internal ClientCertLoader(RequestContext requestContext) + { + _requestContext = requestContext; + _tcs = new TaskCompletionSource(); + // we will use this overlapped structure to issue async IO to ul + // the event handle will be put in by the BeginHttpApi2.ERROR_SUCCESS() method + Reset(CertBoblSize); + } + + internal X509Certificate2 ClientCert + { + get + { + Contract.Assert(Task.IsCompleted); + return _clientCert; + } + } + + internal int ClientCertError + { + get + { + Contract.Assert(Task.IsCompleted); + return _clientCertError; + } + } + + internal Exception ClientCertException + { + get + { + Contract.Assert(Task.IsCompleted); + return _clientCertException; + } + } + + private RequestContext RequestContext + { + get + { + return _requestContext; + } + } + + private Task Task + { + get + { + return _tcs.Task; + } + } + + private SafeNativeOverlapped NativeOverlapped + { + get + { + return _overlapped; + } + } + + private UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* RequestBlob + { + get + { + return _memoryBlob; + } + } + + private void Reset(uint size) + { + if (size == _size) + { + return; + } + if (_size != 0) + { + _overlapped.Dispose(); + } + _size = size; + if (size == 0) + { + _overlapped = null; + _memoryBlob = null; + _backingBuffer = null; + return; + } + _backingBuffer = new byte[checked((int)size)]; + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = this; + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, _backingBuffer)); + _memoryBlob = (UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO*)Marshal.UnsafeAddrOfPinnedArrayElement(_backingBuffer, 0); + } + + // When you use netsh to configure HTTP.SYS with clientcertnegotiation = enable + // which means negotiate client certificates, when the client makes the + // initial SSL connection, the server (HTTP.SYS) requests the client certificate. + // + // Some apps may not want to negotiate the client cert at the beginning, + // perhaps serving the default.htm. In this case the HTTP.SYS is configured + // with clientcertnegotiation = disabled, which means that the client certificate is + // optional so initially when SSL is established HTTP.SYS won't ask for client + // certificate. This works fine for the default.htm in the case above, + // however, if the app wants to demand a client certificate at a later time + // perhaps showing "YOUR ORDERS" page, then the server wants to negotiate + // Client certs. This will in turn makes HTTP.SYS to do the + // SEC_I_RENOGOTIATE through which the client cert demand is made + // + // NOTE: When calling HttpReceiveClientCertificate you can get + // ERROR_NOT_FOUND - which means the client did not provide the cert + // If this is important, the server should respond with 403 forbidden + // HTTP.SYS will not do this for you automatically + internal Task LoadClientCertificateAsync() + { + uint size = CertBoblSize; + bool retry; + do + { + retry = false; + uint bytesReceived = 0; + + uint statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveClientCertificate( + RequestContext.RequestQueueHandle, + RequestContext.Request.ConnectionId, + (uint)UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE, + RequestBlob, + size, + &bytesReceived, + NativeOverlapped); + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = RequestBlob; + size = bytesReceived + pClientCertInfo->CertEncodedSize; + Reset(size); + retry = true; + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_NOT_FOUND) + { + // The client did not send a cert. + Complete(0, null); + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + IOCompleted(statusCode, bytesReceived); + } + else if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + // Some other bad error, possible(?) return values are: + // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED + // Also ERROR_BAD_DATA if we got it twice or it reported smaller size buffer required. + Fail(new WebListenerException((int)statusCode)); + } + } + while (retry); + + return Task; + } + + private void Complete(int certErrors, X509Certificate2 cert) + { + // May be null + _clientCert = cert; + _clientCertError = certErrors; + _tcs.TrySetResult(null); + Dispose(); + } + + private void Fail(Exception ex) + { + // TODO: Log + _clientCertException = ex; + _tcs.TrySetResult(null); + } + + private unsafe void IOCompleted(uint errorCode, uint numBytes) + { + IOCompleted(this, errorCode, numBytes); + } + + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirected to callback")] + private static unsafe void IOCompleted(ClientCertLoader asyncResult, uint errorCode, uint numBytes) + { + RequestContext requestContext = asyncResult.RequestContext; + try + { + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA) + { + // There is a bug that has existed in http.sys since w2k3. Bytesreceived will only + // return the size of the initial cert structure. To get the full size, + // we need to add the certificate encoding size as well. + + UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = asyncResult.RequestBlob; + asyncResult.Reset(numBytes + pClientCertInfo->CertEncodedSize); + + uint bytesReceived = 0; + errorCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveClientCertificate( + requestContext.RequestQueueHandle, + requestContext.Request.ConnectionId, + (uint)UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE, + asyncResult._memoryBlob, + asyncResult._size, + &bytesReceived, + asyncResult._overlapped); + + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING || + (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && !OwinWebListener.SkipIOCPCallbackOnSuccess)) + { + return; + } + } + + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_NOT_FOUND) + { + // The client did not send a cert. + asyncResult.Complete(0, null); + } + else if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + asyncResult.Fail(new WebListenerException((int)errorCode)); + } + else + { + UnsafeNclNativeMethods.HttpApi.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = asyncResult._memoryBlob; + if (pClientCertInfo == null) + { + asyncResult.Complete(0, null); + } + else + { + if (pClientCertInfo->pCertEncoded != null) + { + try + { + byte[] certEncoded = new byte[pClientCertInfo->CertEncodedSize]; + Marshal.Copy((IntPtr)pClientCertInfo->pCertEncoded, certEncoded, 0, certEncoded.Length); + asyncResult.Complete((int)pClientCertInfo->CertFlags, new X509Certificate2(certEncoded)); + } + catch (CryptographicException exception) + { + // TODO: Log + asyncResult.Fail(exception); + } + catch (SecurityException exception) + { + // TODO: Log + asyncResult.Fail(exception); + } + } + } + } + } + catch (Exception exception) + { + asyncResult.Fail(exception); + } + } + + private static unsafe void WaitCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + Overlapped callbackOverlapped = Overlapped.Unpack(nativeOverlapped); + ClientCertLoader asyncResult = (ClientCertLoader)callbackOverlapped.AsyncResult; + + IOCompleted(asyncResult, errorCode, numBytes); + } + + public void Dispose() + { + Dispose(true); + } + + private void Dispose(bool disposing) + { + if (disposing) + { + if (_overlapped != null) + { + _memoryBlob = null; + _overlapped.Dispose(); + } + } + } + + public object AsyncState + { + get { return _tcs.Task.AsyncState; } + } + + public WaitHandle AsyncWaitHandle + { + get { return ((IAsyncResult)_tcs.Task).AsyncWaitHandle; } + } + + public bool CompletedSynchronously + { + get { return ((IAsyncResult)_tcs.Task).CompletedSynchronously; } + } + + public bool IsCompleted + { + get { return _tcs.Task.IsCompleted; } + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/EntitySendFormat.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/EntitySendFormat.cs new file mode 100644 index 0000000000..f6861f9bd6 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/EntitySendFormat.cs @@ -0,0 +1,17 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum EntitySendFormat + { + ContentLength = 0, // Content-Length: XXX + Chunked = 1, // Transfer-Encoding: chunked + /* + Raw = 2, // the app is responsible for sending the correct headers and body encoding + */ + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HeaderEncoding.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HeaderEncoding.cs new file mode 100644 index 0000000000..24a8df80ff --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HeaderEncoding.cs @@ -0,0 +1,97 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Microsoft.AspNet.Server.WebListener +{ + // we use this static class as a helper class to encode/decode HTTP headers. + // what we need is a 1-1 correspondence between a char in the range U+0000-U+00FF + // and a byte in the range 0x00-0xFF (which is the range that can hit the network). + // The Latin-1 encoding (ISO-88591-1) (GetEncoding(28591)) works for byte[] to string, but is a little slow. + // It doesn't work for string -> byte[] because of best-fit-mapping problems. + internal static class HeaderEncoding + { + internal static unsafe string GetString(byte[] bytes, int byteIndex, int byteCount) + { + fixed (byte* pBytes = bytes) + return GetString(pBytes + byteIndex, byteCount); + } + + internal static unsafe string GetString(sbyte* pBytes, int byteCount) + { + return GetString((byte*)pBytes, byteCount); + } + + internal static unsafe string GetString(byte* pBytes, int byteCount) + { + if (byteCount < 1) + { + return string.Empty; + } + + string s = new String('\0', byteCount); + + fixed (char* pStr = s) + { + char* pString = pStr; + while (byteCount >= 8) + { + pString[0] = (char)pBytes[0]; + pString[1] = (char)pBytes[1]; + pString[2] = (char)pBytes[2]; + pString[3] = (char)pBytes[3]; + pString[4] = (char)pBytes[4]; + pString[5] = (char)pBytes[5]; + pString[6] = (char)pBytes[6]; + pString[7] = (char)pBytes[7]; + pString += 8; + pBytes += 8; + byteCount -= 8; + } + for (int i = 0; i < byteCount; i++) + { + pString[i] = (char)pBytes[i]; + } + } + + return s; + } + + internal static int GetByteCount(string myString) + { + return myString.Length; + } + internal static unsafe void GetBytes(string myString, int charIndex, int charCount, byte[] bytes, int byteIndex) + { + if (myString.Length == 0) + { + return; + } + fixed (byte* bufferPointer = bytes) + { + byte* newBufferPointer = bufferPointer + byteIndex; + int finalIndex = charIndex + charCount; + while (charIndex < finalIndex) + { + *newBufferPointer++ = (byte)myString[charIndex++]; + } + } + } + internal static unsafe byte[] GetBytes(string myString) + { + byte[] bytes = new byte[myString.Length]; + if (myString.Length != 0) + { + GetBytes(myString, 0, myString.Length, bytes, 0); + } + return bytes; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpKnownHeaderNames.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpKnownHeaderNames.cs new file mode 100644 index 0000000000..9885de2fcd --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpKnownHeaderNames.cs @@ -0,0 +1,75 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class HttpKnownHeaderNames + { + internal const string CacheControl = "Cache-Control"; + internal const string Connection = "Connection"; + internal const string Date = "Date"; + internal const string KeepAlive = "Keep-Alive"; + internal const string Pragma = "Pragma"; + internal const string ProxyConnection = "Proxy-Connection"; + internal const string Trailer = "Trailer"; + internal const string TransferEncoding = "Transfer-Encoding"; + internal const string Upgrade = "Upgrade"; + internal const string Via = "Via"; + internal const string Warning = "Warning"; + internal const string ContentLength = "Content-Length"; + internal const string ContentType = "Content-Type"; + internal const string ContentDisposition = "Content-Disposition"; + internal const string ContentEncoding = "Content-Encoding"; + internal const string ContentLanguage = "Content-Language"; + internal const string ContentLocation = "Content-Location"; + internal const string ContentRange = "Content-Range"; + internal const string Expires = "Expires"; + internal const string LastModified = "Last-Modified"; + internal const string Age = "Age"; + internal const string Location = "Location"; + internal const string ProxyAuthenticate = "Proxy-Authenticate"; + internal const string RetryAfter = "Retry-After"; + internal const string Server = "Server"; + internal const string SetCookie = "Set-Cookie"; + internal const string SetCookie2 = "Set-Cookie2"; + internal const string Vary = "Vary"; + internal const string WWWAuthenticate = "WWW-Authenticate"; + internal const string Accept = "Accept"; + internal const string AcceptCharset = "Accept-Charset"; + internal const string AcceptEncoding = "Accept-Encoding"; + internal const string AcceptLanguage = "Accept-Language"; + internal const string Authorization = "Authorization"; + internal const string Cookie = "Cookie"; + internal const string Cookie2 = "Cookie2"; + internal const string Expect = "Expect"; + internal const string From = "From"; + internal const string Host = "Host"; + internal const string IfMatch = "If-Match"; + internal const string IfModifiedSince = "If-Modified-Since"; + internal const string IfNoneMatch = "If-None-Match"; + internal const string IfRange = "If-Range"; + internal const string IfUnmodifiedSince = "If-Unmodified-Since"; + internal const string MaxForwards = "Max-Forwards"; + internal const string ProxyAuthorization = "Proxy-Authorization"; + internal const string Referer = "Referer"; + internal const string Range = "Range"; + internal const string UserAgent = "User-Agent"; + internal const string ContentMD5 = "Content-MD5"; + internal const string ETag = "ETag"; + internal const string TE = "TE"; + internal const string Allow = "Allow"; + internal const string AcceptRanges = "Accept-Ranges"; + internal const string P3P = "P3P"; + internal const string XPoweredBy = "X-Powered-By"; + internal const string XAspNetVersion = "X-AspNet-Version"; + internal const string SecWebSocketKey = "Sec-WebSocket-Key"; + internal const string SecWebSocketExtensions = "Sec-WebSocket-Extensions"; + internal const string SecWebSocketAccept = "Sec-WebSocket-Accept"; + internal const string Origin = "Origin"; + internal const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; + internal const string SecWebSocketVersion = "Sec-WebSocket-Version"; + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpReasonPhrase.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpReasonPhrase.cs new file mode 100644 index 0000000000..27370832ff --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpReasonPhrase.cs @@ -0,0 +1,109 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class HttpReasonPhrase + { + private static readonly string[][] HttpReasonPhrases = new string[][] + { + null, + + new string[] + { + /* 100 */ "Continue", + /* 101 */ "Switching Protocols", + /* 102 */ "Processing" + }, + + new string[] + { + /* 200 */ "OK", + /* 201 */ "Created", + /* 202 */ "Accepted", + /* 203 */ "Non-Authoritative Information", + /* 204 */ "No Content", + /* 205 */ "Reset Content", + /* 206 */ "Partial Content", + /* 207 */ "Multi-Status" + }, + + new string[] + { + /* 300 */ "Multiple Choices", + /* 301 */ "Moved Permanently", + /* 302 */ "Found", + /* 303 */ "See Other", + /* 304 */ "Not Modified", + /* 305 */ "Use Proxy", + /* 306 */ null, + /* 307 */ "Temporary Redirect" + }, + + new string[] + { + /* 400 */ "Bad Request", + /* 401 */ "Unauthorized", + /* 402 */ "Payment Required", + /* 403 */ "Forbidden", + /* 404 */ "Not Found", + /* 405 */ "Method Not Allowed", + /* 406 */ "Not Acceptable", + /* 407 */ "Proxy Authentication Required", + /* 408 */ "Request Timeout", + /* 409 */ "Conflict", + /* 410 */ "Gone", + /* 411 */ "Length Required", + /* 412 */ "Precondition Failed", + /* 413 */ "Request Entity Too Large", + /* 414 */ "Request-Uri Too Long", + /* 415 */ "Unsupported Media Type", + /* 416 */ "Requested Range Not Satisfiable", + /* 417 */ "Expectation Failed", + /* 418 */ null, + /* 419 */ null, + /* 420 */ null, + /* 421 */ null, + /* 422 */ "Unprocessable Entity", + /* 423 */ "Locked", + /* 424 */ "Failed Dependency", + /* 425 */ null, + /* 426 */ "Upgrade Required", // RFC 2817 + }, + + new string[] + { + /* 500 */ "Internal Server Error", + /* 501 */ "Not Implemented", + /* 502 */ "Bad Gateway", + /* 503 */ "Service Unavailable", + /* 504 */ "Gateway Timeout", + /* 505 */ "Http Version Not Supported", + /* 506 */ null, + /* 507 */ "Insufficient Storage" + } + }; + + internal static string Get(HttpStatusCode code) + { + return Get((int)code); + } + + internal static string Get(int code) + { + if (code >= 100 && code < 600) + { + int i = code / 100; + int j = code % 100; + if (j < HttpReasonPhrases[i].Length) + { + return HttpReasonPhrases[i][j]; + } + } + return null; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpStatusCode.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpStatusCode.cs new file mode 100644 index 0000000000..88e6fe4b18 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/HttpStatusCode.cs @@ -0,0 +1,314 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener +{ + // Redirect Status code numbers that need to be defined. + + /// + /// Contains the values of status + /// codes defined for the HTTP protocol. + /// + // UEUE : Any int can be cast to a HttpStatusCode to allow checking for non http1.1 codes. + internal enum HttpStatusCode + { + // Informational 1xx + + /// + /// [To be supplied.] + /// + Continue = 100, + + /// + /// [To be supplied.] + /// + SwitchingProtocols = 101, + + // Successful 2xx + + /// + /// [To be supplied.] + /// + OK = 200, + + /// + /// [To be supplied.] + /// + Created = 201, + + /// + /// [To be supplied.] + /// + Accepted = 202, + + /// + /// [To be supplied.] + /// + NonAuthoritativeInformation = 203, + + /// + /// [To be supplied.] + /// + NoContent = 204, + + /// + /// [To be supplied.] + /// + ResetContent = 205, + + /// + /// [To be supplied.] + /// + PartialContent = 206, + + // Redirection 3xx + + /// + /// [To be supplied.] + /// + MultipleChoices = 300, + + /// + /// [To be supplied.] + /// + Ambiguous = 300, + + /// + /// [To be supplied.] + /// + MovedPermanently = 301, + + /// + /// [To be supplied.] + /// + Moved = 301, + + /// + /// [To be supplied.] + /// + Found = 302, + + /// + /// [To be supplied.] + /// + Redirect = 302, + + /// + /// [To be supplied.] + /// + SeeOther = 303, + + /// + /// [To be supplied.] + /// + RedirectMethod = 303, + + /// + /// [To be supplied.] + /// + NotModified = 304, + + /// + /// [To be supplied.] + /// + UseProxy = 305, + + /// + /// [To be supplied.] + /// + Unused = 306, + + /// + /// [To be supplied.] + /// + TemporaryRedirect = 307, + + /// + /// [To be supplied.] + /// + RedirectKeepVerb = 307, + + // Client Error 4xx + + /// + /// [To be supplied.] + /// + BadRequest = 400, + + /// + /// [To be supplied.] + /// + Unauthorized = 401, + + /// + /// [To be supplied.] + /// + PaymentRequired = 402, + + /// + /// [To be supplied.] + /// + Forbidden = 403, + + /// + /// [To be supplied.] + /// + NotFound = 404, + + /// + /// [To be supplied.] + /// + MethodNotAllowed = 405, + + /// + /// [To be supplied.] + /// + NotAcceptable = 406, + + /// + /// [To be supplied.] + /// + ProxyAuthenticationRequired = 407, + + /// + /// [To be supplied.] + /// + RequestTimeout = 408, + + /// + /// [To be supplied.] + /// + Conflict = 409, + + /// + /// [To be supplied.] + /// + Gone = 410, + + /// + /// [To be supplied.] + /// + LengthRequired = 411, + + /// + /// [To be supplied.] + /// + PreconditionFailed = 412, + + /// + /// [To be supplied.] + /// + RequestEntityTooLarge = 413, + + /// + /// [To be supplied.] + /// + RequestUriTooLong = 414, + + /// + /// [To be supplied.] + /// + UnsupportedMediaType = 415, + + /// + /// [To be supplied.] + /// + RequestedRangeNotSatisfiable = 416, + + /// + /// [To be supplied.] + /// + ExpectationFailed = 417, + + UpgradeRequired = 426, + + // Server Error 5xx + + /// + /// [To be supplied.] + /// + InternalServerError = 500, + + /// + /// [To be supplied.] + /// + NotImplemented = 501, + + /// + /// [To be supplied.] + /// + BadGateway = 502, + + /// + /// [To be supplied.] + /// + ServiceUnavailable = 503, + + /// + /// [To be supplied.] + /// + GatewayTimeout = 504, + + /// + /// [To be supplied.] + /// + HttpVersionNotSupported = 505, + } // enum HttpStatusCode + + /* + Fielding, et al. Standards Track [Page 3] + + RFC 2616 HTTP/1.1 June 1999 + + + 10.1 Informational 1xx ...........................................57 + 10.1.1 100 Continue .............................................58 + 10.1.2 101 Switching Protocols ..................................58 + 10.2 Successful 2xx ..............................................58 + 10.2.1 200 OK ...................................................58 + 10.2.2 201 Created ..............................................59 + 10.2.3 202 Accepted .............................................59 + 10.2.4 203 Non-Authoritative Information ........................59 + 10.2.5 204 No Content ...........................................60 + 10.2.6 205 Reset Content ........................................60 + 10.2.7 206 Partial Content ......................................60 + 10.3 Redirection 3xx .............................................61 + 10.3.1 300 Multiple Choices .....................................61 + 10.3.2 301 Moved Permanently ....................................62 + 10.3.3 302 Found ................................................62 + 10.3.4 303 See Other ............................................63 + 10.3.5 304 Not Modified .........................................63 + 10.3.6 305 Use Proxy ............................................64 + 10.3.7 306 (Unused) .............................................64 + 10.3.8 307 Temporary Redirect ...................................65 + 10.4 Client Error 4xx ............................................65 + 10.4.1 400 Bad Request .........................................65 + 10.4.2 401 Unauthorized ........................................66 + 10.4.3 402 Payment Required ....................................66 + 10.4.4 403 Forbidden ...........................................66 + 10.4.5 404 Not Found ...........................................66 + 10.4.6 405 Method Not Allowed ..................................66 + 10.4.7 406 Not Acceptable ......................................67 + 10.4.8 407 Proxy Authentication Required .......................67 + 10.4.9 408 Request Timeout .....................................67 + 10.4.10 409 Conflict ............................................67 + 10.4.11 410 Gone ................................................68 + 10.4.12 411 Length Required .....................................68 + 10.4.13 412 Precondition Failed .................................68 + 10.4.14 413 Request Entity Too Large ............................69 + 10.4.15 414 Request-URI Too Long ................................69 + 10.4.16 415 Unsupported Media Type ..............................69 + 10.4.17 416 Requested Range Not Satisfiable .....................69 + 10.4.18 417 Expectation Failed ..................................70 + 10.5 Server Error 5xx ............................................70 + 10.5.1 500 Internal Server Error ................................70 + 10.5.2 501 Not Implemented ......................................70 + 10.5.3 502 Bad Gateway ..........................................70 + 10.5.4 503 Service Unavailable ..................................70 + 10.5.5 504 Gateway Timeout ......................................71 + 10.5.6 505 HTTP Version Not Supported ...........................71 + */ +} // namespace System.Net diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NativeRequestContext.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NativeRequestContext.cs new file mode 100644 index 0000000000..8ab4f029cb --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NativeRequestContext.cs @@ -0,0 +1,170 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal unsafe class NativeRequestContext : IDisposable + { + private const int DefaultBufferSize = 4096; + private UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* _memoryBlob; + private IntPtr _originalBlobAddress; + private byte[] _backingBuffer; + private SafeNativeOverlapped _nativeOverlapped; + private AsyncAcceptContext _acceptResult; + + internal NativeRequestContext(AsyncAcceptContext result) + { + _acceptResult = result; + UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* requestBlob = Allocate(0); + if (requestBlob == null) + { + GC.SuppressFinalize(this); + } + else + { + _memoryBlob = requestBlob; + } + } + + internal SafeNativeOverlapped NativeOverlapped + { + get + { + return _nativeOverlapped; + } + } + + internal UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* RequestBlob + { + get + { + Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestBlob requested after ReleasePins()."); + return _memoryBlob; + } + } + + internal byte[] RequestBuffer + { + get + { + return _backingBuffer; + } + } + + internal uint Size + { + get + { + return (uint)_backingBuffer.Length; + } + } + + internal IntPtr OriginalBlobAddress + { + get + { + UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* blob = _memoryBlob; + return (blob == null ? _originalBlobAddress : (IntPtr)blob); + } + } + + // ReleasePins() should be called exactly once. It must be called before Dispose() is called, which means it must be called + // before an object (Request) which closes the RequestContext on demand is returned to the application. + internal void ReleasePins() + { + Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestContextBase::ReleasePins()|ReleasePins() called twice."); + _originalBlobAddress = (IntPtr)_memoryBlob; + UnsetBlob(); + OnReleasePins(); + } + + private void OnReleasePins() + { + if (_nativeOverlapped != null) + { + SafeNativeOverlapped nativeOverlapped = _nativeOverlapped; + _nativeOverlapped = null; + nativeOverlapped.Dispose(); + } + } + + public void Dispose() + { + Debug.Assert(_memoryBlob == null, "RequestContextBase::Dispose()|Dispose() called before ReleasePins()."); + Dispose(true); + } + + protected void Dispose(bool disposing) + { + if (_nativeOverlapped != null) + { + Debug.Assert(!disposing, "AsyncRequestContext::Dispose()|Must call ReleasePins() before calling Dispose()."); + _nativeOverlapped.Dispose(); + } + } + + private void SetBlob(UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* requestBlob) + { + Debug.Assert(_memoryBlob != null || _backingBuffer == null, "RequestContextBase::Dispose()|SetBlob() called after ReleasePins()."); + if (requestBlob == null) + { + UnsetBlob(); + return; + } + + if (_memoryBlob == null) + { + GC.ReRegisterForFinalize(this); + } + _memoryBlob = requestBlob; + } + + private void UnsetBlob() + { + if (_memoryBlob != null) + { + GC.SuppressFinalize(this); + } + _memoryBlob = null; + } + + private void SetBuffer(int size) + { + _backingBuffer = size == 0 ? null : new byte[size]; + } + + private UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST* Allocate(uint size) + { + uint newSize = size != 0 ? size : RequestBuffer == null ? DefaultBufferSize : Size; + if (_nativeOverlapped != null && newSize != RequestBuffer.Length) + { + SafeNativeOverlapped nativeOverlapped = _nativeOverlapped; + _nativeOverlapped = null; + nativeOverlapped.Dispose(); + } + if (_nativeOverlapped == null) + { + SetBuffer(checked((int)newSize)); + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = _acceptResult; + _nativeOverlapped = new SafeNativeOverlapped(overlapped.Pack(AsyncAcceptContext.IOCallback, RequestBuffer)); + return (UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST*)Marshal.UnsafeAddrOfPinnedArrayElement(RequestBuffer, 0); + } + return RequestBlob; + } + + internal void Reset(ulong requestId, uint size) + { + SetBlob(Allocate(size)); + RequestBlob->RequestId = requestId; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NilEnvDictionary.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NilEnvDictionary.cs new file mode 100644 index 0000000000..af6cc6aff5 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/NilEnvDictionary.cs @@ -0,0 +1,115 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- +// Copyright 2011-2012 Katana contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class NilEnvDictionary : IDictionary + { + private static readonly string[] EmptyKeys = new string[0]; + private static readonly object[] EmptyValues = new object[0]; + private static readonly IEnumerable> EmptyKeyValuePairs = Enumerable.Empty>(); + + public int Count + { + get { return 0; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public ICollection Keys + { + get { return EmptyKeys; } + } + + public ICollection Values + { + get { return EmptyValues; } + } + + [SuppressMessage("Microsoft.Design", "CA1065:DoNotRaiseExceptionsInUnexpectedLocations", Justification = "Not Implemented")] + public object this[string key] + { + get { throw new NotImplementedException(); } + set { throw new NotImplementedException(); } + } + + public IEnumerator> GetEnumerator() + { + return EmptyKeyValuePairs.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return EmptyKeyValuePairs.GetEnumerator(); + } + + public void Add(KeyValuePair item) + { + throw new NotImplementedException(); + } + + public void Clear() + { + } + + public bool Contains(KeyValuePair item) + { + return false; + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + } + + public bool Remove(KeyValuePair item) + { + return false; + } + + public bool ContainsKey(string key) + { + return false; + } + + public void Add(string key, object value) + { + throw new NotImplementedException(); + } + + public bool Remove(string key) + { + return false; + } + + public bool TryGetValue(string key, out object value) + { + value = null; + return false; + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/OpaqueStream.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/OpaqueStream.cs new file mode 100644 index 0000000000..6238d9185d --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/OpaqueStream.cs @@ -0,0 +1,168 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + // A duplex wrapper around RequestStream and ResponseStream. + // TODO: Consider merging RequestStream and ResponseStream instead. + internal class OpaqueStream : Stream + { + private readonly Stream _requestStream; + private readonly Stream _responseStream; + + internal OpaqueStream(Stream requestStream, Stream responseStream) + { + _requestStream = requestStream; + _responseStream = responseStream; + } + +#region Properties + + public override bool CanRead + { + get { return _requestStream.CanRead; } + } + + public override bool CanSeek + { + get { return false; } + } + + public override bool CanTimeout + { + get { return _requestStream.CanTimeout || _responseStream.CanTimeout; } + } + + public override bool CanWrite + { + get { return _responseStream.CanWrite; } + } + + public override long Length + { + get { throw new NotSupportedException(Resources.Exception_NoSeek); } + } + + public override long Position + { + get { throw new NotSupportedException(Resources.Exception_NoSeek); } + set { throw new NotSupportedException(Resources.Exception_NoSeek); } + } + + public override int ReadTimeout + { + get { return _requestStream.ReadTimeout; } + set { _requestStream.ReadTimeout = value; } + } + + public override int WriteTimeout + { + get { return _responseStream.WriteTimeout; } + set { _responseStream.WriteTimeout = value; } + } + +#endregion Properties + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + +#region Read + + public override int Read(byte[] buffer, int offset, int count) + { + return _requestStream.Read(buffer, offset, count); + } + + public override int ReadByte() + { + return _requestStream.ReadByte(); + } +#if NET45 + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _requestStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _requestStream.EndRead(asyncResult); + } +#endif + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _requestStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return _requestStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + +#endregion Read + +#region Write + + public override void Write(byte[] buffer, int offset, int count) + { + _responseStream.Write(buffer, offset, count); + } + + public override void WriteByte(byte value) + { + _responseStream.WriteByte(value); + } +#if NET45 + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _responseStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _responseStream.EndWrite(asyncResult); + } +#endif + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _responseStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override void Flush() + { + _responseStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _responseStream.FlushAsync(cancellationToken); + } + +#endregion Write + + protected override void Dispose(bool disposing) + { + // TODO: Suppress dispose? + if (disposing) + { + _requestStream.Dispose(); + _responseStream.Dispose(); + } + base.Dispose(disposing); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Request.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Request.cs new file mode 100644 index 0000000000..86b25185fb --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Request.cs @@ -0,0 +1,456 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Net; +using System.Runtime.InteropServices; +using System.Security.Authentication.ExtendedProtection; +using System.Security.Principal; +using System.Text; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed unsafe class Request : IDisposable + { + private RequestContext _requestContext; + private NativeRequestContext _nativeRequestContext; + + private ulong _requestId; + private ulong _connectionId; + private ulong _contextId; + + private SslStatus _sslStatus; + + private string _httpMethod; + private Version _httpVersion; + + private Uri _requestUri; + private string _rawUrl; + private string _cookedUrlHost; + private string _cookedUrlPath; + private string _cookedUrlQuery; + + private RequestHeaders _headers; + private BoundaryType _contentBoundaryType; + private long _contentLength; + private Stream _requestStream; + private SocketAddress _localEndPoint; + private SocketAddress _remoteEndPoint; + + private IPrincipal _user; + + private bool _isDisposed = false; + private CancellationTokenRegistration _disconnectRegistration; + + internal Request(RequestContext httpContext, NativeRequestContext memoryBlob) + { + // TODO: Verbose log + _requestContext = httpContext; + _nativeRequestContext = memoryBlob; + _contentBoundaryType = BoundaryType.None; + + // Set up some of these now to avoid refcounting on memory blob later. + _requestId = memoryBlob.RequestBlob->RequestId; + _connectionId = memoryBlob.RequestBlob->ConnectionId; + _contextId = memoryBlob.RequestBlob->UrlContext; + _sslStatus = memoryBlob.RequestBlob->pSslInfo == null ? SslStatus.Insecure : + memoryBlob.RequestBlob->pSslInfo->SslClientCertNegotiated == 0 ? SslStatus.NoClientCert : + SslStatus.ClientCert; + if (memoryBlob.RequestBlob->pRawUrl != null && memoryBlob.RequestBlob->RawUrlLength > 0) + { + _rawUrl = Marshal.PtrToStringAnsi((IntPtr)memoryBlob.RequestBlob->pRawUrl, memoryBlob.RequestBlob->RawUrlLength); + } + + UnsafeNclNativeMethods.HttpApi.HTTP_COOKED_URL cookedUrl = memoryBlob.RequestBlob->CookedUrl; + if (cookedUrl.pHost != null && cookedUrl.HostLength > 0) + { + _cookedUrlHost = Marshal.PtrToStringUni((IntPtr)cookedUrl.pHost, cookedUrl.HostLength / 2); + } + if (cookedUrl.pAbsPath != null && cookedUrl.AbsPathLength > 0) + { + _cookedUrlPath = Marshal.PtrToStringUni((IntPtr)cookedUrl.pAbsPath, cookedUrl.AbsPathLength / 2); + } + if (cookedUrl.pQueryString != null && cookedUrl.QueryStringLength > 0) + { + _cookedUrlQuery = Marshal.PtrToStringUni((IntPtr)cookedUrl.pQueryString, cookedUrl.QueryStringLength / 2); + } + + int major = memoryBlob.RequestBlob->Version.MajorVersion; + int minor = memoryBlob.RequestBlob->Version.MinorVersion; + if (major == 1 && minor == 1) + { + _httpVersion = Constants.V1_1; + } + else if (major == 1 && minor == 0) + { + _httpVersion = Constants.V1_0; + } + else + { + _httpVersion = new Version(major, minor); + } + + _httpMethod = UnsafeNclNativeMethods.HttpApi.GetVerb(RequestBuffer, OriginalBlobAddress); + _headers = new RequestHeaders(_nativeRequestContext); + + UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_V2* requestV2 = (UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_V2*)memoryBlob.RequestBlob; + _user = GetUser(requestV2->pRequestInfo); + + // TODO: Verbose log parameters + + // TODO: Verbose log headers + } + + internal SslStatus SslStatus + { + get + { + return _sslStatus; + } + } + + internal ulong ConnectionId + { + get + { + return _connectionId; + } + } + + internal ulong ContextId + { + get { return _contextId; } + } + + internal RequestContext RequestContext + { + get + { + return _requestContext; + } + } + + internal byte[] RequestBuffer + { + get + { + CheckDisposed(); + return _nativeRequestContext.RequestBuffer; + } + } + + internal IntPtr OriginalBlobAddress + { + get + { + CheckDisposed(); + return _nativeRequestContext.OriginalBlobAddress; + } + } + + // Without the leading ? + internal string Query + { + get + { + if (!string.IsNullOrWhiteSpace(_cookedUrlQuery)) + { + return _cookedUrlQuery.Substring(1); + } + return string.Empty; + } + } + + internal ulong RequestId + { + get + { + return _requestId; + } + } + + // TODO: Move this to the constructor, that's where it will be called. + internal long ContentLength64 + { + get + { + if (_contentBoundaryType == BoundaryType.None) + { + string transferEncoding = Headers.Get(HttpKnownHeaderNames.TransferEncoding) ?? string.Empty; + if ("chunked".Equals(transferEncoding.Trim(), StringComparison.OrdinalIgnoreCase)) + { + _contentBoundaryType = BoundaryType.Chunked; + _contentLength = -1; + } + else + { + _contentLength = 0; + _contentBoundaryType = BoundaryType.ContentLength; + string length = Headers.Get(HttpKnownHeaderNames.ContentLength) ?? string.Empty; + if (length != null) + { + if (!long.TryParse(length.Trim(), NumberStyles.None, + CultureInfo.InvariantCulture.NumberFormat, out _contentLength)) + { + _contentLength = 0; + _contentBoundaryType = BoundaryType.Invalid; + } + } + } + } + + return _contentLength; + } + } + + internal IDictionary Headers + { + get + { + return _headers; + } + } + + internal string HttpMethod + { + get + { + return _httpMethod; + } + } + + internal Stream InputStream + { + get + { + if (_requestStream == null) + { + // TODO: Move this to the constructor (or a lazy Env dictionary) + _requestStream = HasEntityBody ? new RequestStream(RequestContext) : Stream.Null; + } + return _requestStream; + } + } + + internal bool IsLocal + { + get + { + return LocalEndPoint.GetIPAddressString().Equals(RemoteEndPoint.GetIPAddressString()); + } + } + + internal bool IsSecureConnection + { + get + { + return _sslStatus != SslStatus.Insecure; + } + } + + internal string RawUrl + { + get + { + return _rawUrl; + } + } + + internal Version ProtocolVersion + { + get + { + return _httpVersion; + } + } + + internal string Protocol + { + get + { + if (_httpVersion.Major == 1) + { + if (_httpVersion.Minor == 1) + { + return "HTTP/1.1"; + } + else if (_httpVersion.Minor == 0) + { + return "HTTP/1.0"; + } + } + return "HTTP/" + _httpVersion.ToString(2); + } + } + + // TODO: Move this to the constructor + internal bool HasEntityBody + { + get + { + // accessing the ContentLength property delay creates m_BoundaryType + return (ContentLength64 > 0 && _contentBoundaryType == BoundaryType.ContentLength) || + _contentBoundaryType == BoundaryType.Chunked || _contentBoundaryType == BoundaryType.Multipart; + } + } + + internal SocketAddress RemoteEndPoint + { + get + { + if (_remoteEndPoint == null) + { + _remoteEndPoint = UnsafeNclNativeMethods.HttpApi.GetRemoteEndPoint(RequestBuffer, OriginalBlobAddress); + } + + return _remoteEndPoint; + } + } + + internal SocketAddress LocalEndPoint + { + get + { + if (_localEndPoint == null) + { + _localEndPoint = UnsafeNclNativeMethods.HttpApi.GetLocalEndPoint(RequestBuffer, OriginalBlobAddress); + } + + return _localEndPoint; + } + } + + internal string RequestScheme + { + get + { + return IsSecureConnection ? Constants.HttpsScheme : Constants.HttpScheme; + } + } + + internal Uri RequestUri + { + get + { + if (_requestUri == null) + { + _requestUri = RequestUriBuilder.GetRequestUri( + _rawUrl, RequestScheme, _cookedUrlHost, _cookedUrlPath, _cookedUrlQuery); + } + + return _requestUri; + } + } + + internal string RequestPath + { + get + { + return RequestUriBuilder.GetRequestPath(_rawUrl, _cookedUrlPath); + } + } + + internal bool IsUpgradable + { + get + { + // HTTP.Sys allows you to upgrade anything to opaque unless content-length > 0 or chunked are specified. + return !HasEntityBody; + } + } + + internal IPrincipal User + { + get { return _user; } + } + + private unsafe IPrincipal GetUser(UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_INFO* requestInfo) + { + if (requestInfo == null + || requestInfo->InfoType != UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_INFO_TYPE.HttpRequestInfoTypeAuth) + { + return null; + } + + if (requestInfo->pInfo->AuthStatus != UnsafeNclNativeMethods.HttpApi.HTTP_AUTH_STATUS.HttpAuthStatusSuccess) + { + return null; + } + +#if NET45 + return new WindowsPrincipal(new WindowsIdentity(requestInfo->pInfo->AccessToken)); +#else + return null; +#endif + } + + // Use this to save the blob from dispose if this object was never used (never given to a user) and is about to be + // disposed. + internal void DetachBlob(NativeRequestContext memoryBlob) + { + if (memoryBlob != null && (object)memoryBlob == (object)_nativeRequestContext) + { + _nativeRequestContext = null; + } + } + + // Finalizes ownership of the memory blob. DetachBlob can't be called after this. + internal void ReleasePins() + { + _nativeRequestContext.ReleasePins(); + } + + // should only be called from RequestContext + public void Dispose() + { + // TODO: Verbose log + _isDisposed = true; + NativeRequestContext memoryBlob = _nativeRequestContext; + if (memoryBlob != null) + { + memoryBlob.Dispose(); + _nativeRequestContext = null; + } + _disconnectRegistration.Dispose(); + if (_requestStream != null) + { + _requestStream.Dispose(); + } + } + + internal void CheckDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + } + + internal void SwitchToOpaqueMode() + { + if (_requestStream == null || _requestStream == Stream.Null) + { + _requestStream = new RequestStream(RequestContext); + } + } + + internal void RegisterForDisconnect(CancellationToken cancellationToken) + { + _disconnectRegistration = cancellationToken.Register(Cancel, this); + } + + private static void Cancel(object obj) + { + Request request = (Request)obj; + // Cancels owin.CallCanceled + request.RequestContext.Abort(); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestContext.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestContext.cs new file mode 100644 index 0000000000..5cb122006c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestContext.cs @@ -0,0 +1,388 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.IO; +using System.Runtime.InteropServices; +using System.Security.Authentication.ExtendedProtection; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + using LoggerFunc = Func, bool>; + using OpaqueFunc = Func, Task>; + + internal sealed class RequestContext : IDisposable, CallEnvironment.IPropertySource + { + private static readonly string[] ZeroContentLength = new[] { "0" }; + + private CallEnvironment _environment; + private OwinWebListener _server; + private Request _request; + private Response _response; + private CancellationTokenSource _cts; + private NativeRequestContext _memoryBlob; + private OpaqueFunc _opaqueCallback; + private bool _disposed; + + internal RequestContext(OwinWebListener httpListener, NativeRequestContext memoryBlob) + { + // TODO: Verbose log + _server = httpListener; + _memoryBlob = memoryBlob; + _request = new Request(this, _memoryBlob); + _response = new Response(this); + _environment = new CallEnvironment(this); + _cts = new CancellationTokenSource(); + + PopulateEnvironment(); + + _request.ReleasePins(); + } + + internal CallEnvironment Environment + { + get { return _environment; } + } + + internal Request Request + { + get + { + return _request; + } + } + + internal Response Response + { + get + { + return _response; + } + } + + internal OwinWebListener Server + { + get + { + return _server; + } + } + + internal LoggerFunc Logger + { + get { return Server.Logger; } + } + + internal SafeHandle RequestQueueHandle + { + get + { + return _server.RequestQueueHandle; + } + } + + internal ulong RequestId + { + get + { + return Request.RequestId; + } + } + + private void PopulateEnvironment() + { + // General + _environment.OwinVersion = Constants.OwinVersion; + _environment.CallCancelled = _cts.Token; + + // Server + _environment.ServerCapabilities = _server.Capabilities; + _environment.Listener = _server; + + // Request + _environment.RequestProtocol = _request.Protocol; + _environment.RequestMethod = _request.HttpMethod; + _environment.RequestScheme = _request.RequestScheme; + _environment.RequestQueryString = _request.Query; + _environment.RequestHeaders = _request.Headers; + + SetPaths(); + + _environment.ConnectionId = _request.ConnectionId; + + if (_request.IsSecureConnection) + { + _environment.LoadClientCert = LoadClientCertificateAsync; + } + + if (_request.User != null) + { + _environment.User = _request.User; + } + + // Response + _environment.ResponseStatusCode = 200; + _environment.ResponseHeaders = _response.Headers; + _environment.ResponseBody = _response.OutputStream; + _environment.SendFileAsync = _response.SendFileAsync; + + _environment.OnSendingHeaders = _response.RegisterForOnSendingHeaders; + + Contract.Assert(!_environment.IsExtraDictionaryCreated, + "All server keys should have a reserved slot in the environment."); + } + + // Find the closest matching prefix and use it to separate the request path in to path and base path. + // Scheme and port must match. Path will use a longest match. Host names are more complicated due to + // wildcards, IP addresses, etc. + private void SetPaths() + { + Prefix prefix = _server.UriPrefixes[(int)Request.ContextId]; + string orriginalPath = _request.RequestPath; + + // These paths are both unescaped already. + if (orriginalPath.Length == prefix.Path.Length - 1) + { + // They matched exactly except for the trailing slash. + _environment.RequestPathBase = orriginalPath; + _environment.RequestPath = string.Empty; + } + else + { + // url: /base/path, prefix: /base/, base: /base, path: /path + // url: /, prefix: /, base: , path: / + _environment.RequestPathBase = orriginalPath.Substring(0, prefix.Path.Length - 1); + _environment.RequestPath = orriginalPath.Substring(prefix.Path.Length - 1); + } + } + + // Lazy environment init + + public Stream GetRequestBody() + { + return _request.InputStream; + } + + public string GetRemoteIpAddress() + { + return _request.RemoteEndPoint.GetIPAddressString(); + } + + public string GetRemotePort() + { + return _request.RemoteEndPoint.GetPort().ToString(CultureInfo.InvariantCulture.NumberFormat); + } + + public string GetLocalIpAddress() + { + return _request.LocalEndPoint.GetIPAddressString(); + } + + public string GetLocalPort() + { + return _request.LocalEndPoint.GetPort().ToString(CultureInfo.InvariantCulture.NumberFormat); + } + + public bool GetIsLocal() + { + return _request.IsLocal; + } + + public bool TryGetOpaqueUpgrade(ref Action, OpaqueFunc> value) + { + if (_request.IsUpgradable) + { + value = OpaqueUpgrade; + return true; + } + return false; + } + + public bool TryGetChannelBinding(ref ChannelBinding value) + { + value = Server.GetChannelBinding(Request.ConnectionId, Request.IsSecureConnection); + return value != null; + } + + public void Dispose() + { + if (_disposed) + { + return; + } + _disposed = true; + + // TODO: Verbose log + try + { + _response.Dispose(); + } + finally + { + _request.Dispose(); + _cts.Dispose(); + } + } + + internal void Abort() + { + // TODO: Verbose log + _disposed = true; + try + { + _cts.Cancel(); + } + catch (ObjectDisposedException) + { + } + catch (AggregateException) + { + } + ForceCancelRequest(RequestQueueHandle, _request.RequestId); + _request.Dispose(); + _cts.Dispose(); + } + + internal UnsafeNclNativeMethods.HttpApi.HTTP_VERB GetKnownMethod() + { + return UnsafeNclNativeMethods.HttpApi.GetKnownVerb(Request.RequestBuffer, Request.OriginalBlobAddress); + } + + // This is only called while processing incoming requests. We don't have to worry about cancelling + // any response writes. + [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = + "It is safe to ignore the return value on a cancel operation because the connection is being closed")] + internal static void CancelRequest(SafeHandle requestQueueHandle, ulong requestId) + { + UnsafeNclNativeMethods.HttpApi.HttpCancelHttpRequest(requestQueueHandle, requestId, + IntPtr.Zero); + } + + // The request is being aborted, but large writes may be in progress. Cancel them. + internal void ForceCancelRequest(SafeHandle requestQueueHandle, ulong requestId) + { + try + { + uint statusCode = UnsafeNclNativeMethods.HttpApi.HttpCancelHttpRequest(requestQueueHandle, requestId, + IntPtr.Zero); + + // Either the connection has already dropped, or the last write is in progress. + // The requestId becomes invalid as soon as the last Content-Length write starts. + // The only way to cancel now is with CancelIoEx. + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_CONNECTION_INVALID) + { + _response.CancelLastWrite(requestQueueHandle); + } + } + catch (ObjectDisposedException) + { + // RequestQueueHandle may have been closed + } + } + + // Populates the environment ClicentCertificate. The result may be null if there is no client cert. + // TODO: Does it make sense for this to be invoked multiple times (e.g. renegotiate)? Client and server code appear to + // enable this, but it's unclear what Http.Sys would do. + private async Task LoadClientCertificateAsync() + { + if (Request.SslStatus == SslStatus.Insecure) + { + // Non-SSL + return; + } + // TODO: Verbose log +#if NET45 + ClientCertLoader certLoader = new ClientCertLoader(this); + try + { + await certLoader.LoadClientCertificateAsync().SupressContext(); + // Populate the environment. + if (certLoader.ClientCert != null) + { + Environment.ClientCert = certLoader.ClientCert; + } + // TODO: Expose errors and exceptions? + } + catch (Exception) + { + if (certLoader != null) + { + certLoader.Dispose(); + } + throw; + } +#else + throw new NotImplementedException(); +#endif + } + + internal void OpaqueUpgrade(IDictionary parameters, OpaqueFunc callback) + { + // Parameters are ignored for now + if (Response.SentHeaders) + { + throw new InvalidOperationException(); + } + if (callback == null) + { + throw new ArgumentNullException("callback"); + } + + // Set the status code and reason phrase + Environment.ResponseStatusCode = (int)HttpStatusCode.SwitchingProtocols; + Environment.ResponseReasonPhrase = HttpReasonPhrase.Get(HttpStatusCode.SwitchingProtocols); + + // Store the callback and process it after the stack unwind. + _opaqueCallback = callback; + } + + // Called after the AppFunc completes for any necessary post-processing. + internal unsafe Task ProcessResponseAsync() + { + // If an upgrade was requested, perform it + if (!Response.SentHeaders && _opaqueCallback != null + && Environment.ResponseStatusCode == (int)HttpStatusCode.SwitchingProtocols) + { + Response.SendOpaqueUpgrade(); + + IDictionary opaqueEnv = CreateOpaqueEnvironment(); + return _opaqueCallback(opaqueEnv); + } + + return Helpers.CompletedTask(); + } + + private IDictionary CreateOpaqueEnvironment() + { + IDictionary opaqueEnv = new Dictionary(); + + opaqueEnv[Constants.OpaqueVersionKey] = Constants.OpaqueVersion; + // TODO: Separate CT? + opaqueEnv[Constants.OpaqueCallCancelledKey] = Environment.CallCancelled; + + Request.SwitchToOpaqueMode(); + Response.SwitchToOpaqueMode(); + opaqueEnv[Constants.OpaqueStreamKey] = new OpaqueStream(Request.InputStream, Response.OutputStream); + + return opaqueEnv; + } + + internal void SetFatalResponse() + { + Environment.ResponseStatusCode = 500; + Environment.ResponseReasonPhrase = string.Empty; + Environment.ResponseHeaders.Clear(); + Environment.ResponseHeaders.Add(HttpKnownHeaderNames.ContentLength, ZeroContentLength); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.cs new file mode 100644 index 0000000000..2e2df172c3 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.cs @@ -0,0 +1,2544 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Katana Contributors. All rights reserved. +// +//----------------------------------------------------------------------- +// + +using System; +using System.CodeDom.Compiler; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + [GeneratedCode("TextTemplatingFileGenerator", "")] + internal partial class RequestHeaders + { + // Tracks if individual fields have been read from native or set directly. + // Once read or set, their presence in the collection is marked by if their string[] is null or not. + private UInt32 _flag0, _flag1; + + private string[] _Accept; + private string[] _AcceptCharset; + private string[] _AcceptEncoding; + private string[] _AcceptLanguage; + private string[] _Allow; + private string[] _Authorization; + private string[] _CacheControl; + private string[] _Connection; + private string[] _ContentEncoding; + private string[] _ContentLanguage; + private string[] _ContentLength; + private string[] _ContentLocation; + private string[] _ContentMd5; + private string[] _ContentRange; + private string[] _ContentType; + private string[] _Cookie; + private string[] _Date; + private string[] _Expect; + private string[] _Expires; + private string[] _From; + private string[] _Host; + private string[] _IfMatch; + private string[] _IfModifiedSince; + private string[] _IfNoneMatch; + private string[] _IfRange; + private string[] _IfUnmodifiedSince; + private string[] _KeepAlive; + private string[] _LastModified; + private string[] _MaxForwards; + private string[] _Pragma; + private string[] _ProxyAuthorization; + private string[] _Range; + private string[] _Referer; + private string[] _Te; + private string[] _Trailer; + private string[] _TransferEncoding; + private string[] _Translate; + private string[] _Upgrade; + private string[] _UserAgent; + private string[] _Via; + private string[] _Warning; + + internal string[] Accept + { + get + { + if (!((_flag0 & 0x1u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Accept); + if (nativeValue != null) + { + _Accept = new string[] { nativeValue }; + } + _flag0 |= 0x1u; + } + return _Accept; + } + set + { + _flag0 |= 0x1u; + _Accept = value; + } + } + + internal string[] AcceptCharset + { + get + { + if (!((_flag0 & 0x2u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.AcceptCharset); + if (nativeValue != null) + { + _AcceptCharset = new string[] { nativeValue }; + } + _flag0 |= 0x2u; + } + return _AcceptCharset; + } + set + { + _flag0 |= 0x2u; + _AcceptCharset = value; + } + } + + internal string[] AcceptEncoding + { + get + { + if (!((_flag0 & 0x4u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.AcceptEncoding); + if (nativeValue != null) + { + _AcceptEncoding = new string[] { nativeValue }; + } + _flag0 |= 0x4u; + } + return _AcceptEncoding; + } + set + { + _flag0 |= 0x4u; + _AcceptEncoding = value; + } + } + + internal string[] AcceptLanguage + { + get + { + if (!((_flag0 & 0x8u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.AcceptLanguage); + if (nativeValue != null) + { + _AcceptLanguage = new string[] { nativeValue }; + } + _flag0 |= 0x8u; + } + return _AcceptLanguage; + } + set + { + _flag0 |= 0x8u; + _AcceptLanguage = value; + } + } + + internal string[] Allow + { + get + { + if (!((_flag0 & 0x10u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Allow); + if (nativeValue != null) + { + _Allow = new string[] { nativeValue }; + } + _flag0 |= 0x10u; + } + return _Allow; + } + set + { + _flag0 |= 0x10u; + _Allow = value; + } + } + + internal string[] Authorization + { + get + { + if (!((_flag0 & 0x20u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Authorization); + if (nativeValue != null) + { + _Authorization = new string[] { nativeValue }; + } + _flag0 |= 0x20u; + } + return _Authorization; + } + set + { + _flag0 |= 0x20u; + _Authorization = value; + } + } + + internal string[] CacheControl + { + get + { + if (!((_flag0 & 0x40u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.CacheControl); + if (nativeValue != null) + { + _CacheControl = new string[] { nativeValue }; + } + _flag0 |= 0x40u; + } + return _CacheControl; + } + set + { + _flag0 |= 0x40u; + _CacheControl = value; + } + } + + internal string[] Connection + { + get + { + if (!((_flag0 & 0x80u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Connection); + if (nativeValue != null) + { + _Connection = new string[] { nativeValue }; + } + _flag0 |= 0x80u; + } + return _Connection; + } + set + { + _flag0 |= 0x80u; + _Connection = value; + } + } + + internal string[] ContentEncoding + { + get + { + if (!((_flag0 & 0x100u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentEncoding); + if (nativeValue != null) + { + _ContentEncoding = new string[] { nativeValue }; + } + _flag0 |= 0x100u; + } + return _ContentEncoding; + } + set + { + _flag0 |= 0x100u; + _ContentEncoding = value; + } + } + + internal string[] ContentLanguage + { + get + { + if (!((_flag0 & 0x200u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentLanguage); + if (nativeValue != null) + { + _ContentLanguage = new string[] { nativeValue }; + } + _flag0 |= 0x200u; + } + return _ContentLanguage; + } + set + { + _flag0 |= 0x200u; + _ContentLanguage = value; + } + } + + internal string[] ContentLength + { + get + { + if (!((_flag0 & 0x400u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentLength); + if (nativeValue != null) + { + _ContentLength = new string[] { nativeValue }; + } + _flag0 |= 0x400u; + } + return _ContentLength; + } + set + { + _flag0 |= 0x400u; + _ContentLength = value; + } + } + + internal string[] ContentLocation + { + get + { + if (!((_flag0 & 0x800u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentLocation); + if (nativeValue != null) + { + _ContentLocation = new string[] { nativeValue }; + } + _flag0 |= 0x800u; + } + return _ContentLocation; + } + set + { + _flag0 |= 0x800u; + _ContentLocation = value; + } + } + + internal string[] ContentMd5 + { + get + { + if (!((_flag0 & 0x1000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentMd5); + if (nativeValue != null) + { + _ContentMd5 = new string[] { nativeValue }; + } + _flag0 |= 0x1000u; + } + return _ContentMd5; + } + set + { + _flag0 |= 0x1000u; + _ContentMd5 = value; + } + } + + internal string[] ContentRange + { + get + { + if (!((_flag0 & 0x2000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentRange); + if (nativeValue != null) + { + _ContentRange = new string[] { nativeValue }; + } + _flag0 |= 0x2000u; + } + return _ContentRange; + } + set + { + _flag0 |= 0x2000u; + _ContentRange = value; + } + } + + internal string[] ContentType + { + get + { + if (!((_flag0 & 0x4000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ContentType); + if (nativeValue != null) + { + _ContentType = new string[] { nativeValue }; + } + _flag0 |= 0x4000u; + } + return _ContentType; + } + set + { + _flag0 |= 0x4000u; + _ContentType = value; + } + } + + internal string[] Cookie + { + get + { + if (!((_flag0 & 0x8000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Cookie); + if (nativeValue != null) + { + _Cookie = new string[] { nativeValue }; + } + _flag0 |= 0x8000u; + } + return _Cookie; + } + set + { + _flag0 |= 0x8000u; + _Cookie = value; + } + } + + internal string[] Date + { + get + { + if (!((_flag0 & 0x10000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Date); + if (nativeValue != null) + { + _Date = new string[] { nativeValue }; + } + _flag0 |= 0x10000u; + } + return _Date; + } + set + { + _flag0 |= 0x10000u; + _Date = value; + } + } + + internal string[] Expect + { + get + { + if (!((_flag0 & 0x20000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Expect); + if (nativeValue != null) + { + _Expect = new string[] { nativeValue }; + } + _flag0 |= 0x20000u; + } + return _Expect; + } + set + { + _flag0 |= 0x20000u; + _Expect = value; + } + } + + internal string[] Expires + { + get + { + if (!((_flag0 & 0x40000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Expires); + if (nativeValue != null) + { + _Expires = new string[] { nativeValue }; + } + _flag0 |= 0x40000u; + } + return _Expires; + } + set + { + _flag0 |= 0x40000u; + _Expires = value; + } + } + + internal string[] From + { + get + { + if (!((_flag0 & 0x80000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.From); + if (nativeValue != null) + { + _From = new string[] { nativeValue }; + } + _flag0 |= 0x80000u; + } + return _From; + } + set + { + _flag0 |= 0x80000u; + _From = value; + } + } + + internal string[] Host + { + get + { + if (!((_flag0 & 0x100000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Host); + if (nativeValue != null) + { + _Host = new string[] { nativeValue }; + } + _flag0 |= 0x100000u; + } + return _Host; + } + set + { + _flag0 |= 0x100000u; + _Host = value; + } + } + + internal string[] IfMatch + { + get + { + if (!((_flag0 & 0x200000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfMatch); + if (nativeValue != null) + { + _IfMatch = new string[] { nativeValue }; + } + _flag0 |= 0x200000u; + } + return _IfMatch; + } + set + { + _flag0 |= 0x200000u; + _IfMatch = value; + } + } + + internal string[] IfModifiedSince + { + get + { + if (!((_flag0 & 0x400000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfModifiedSince); + if (nativeValue != null) + { + _IfModifiedSince = new string[] { nativeValue }; + } + _flag0 |= 0x400000u; + } + return _IfModifiedSince; + } + set + { + _flag0 |= 0x400000u; + _IfModifiedSince = value; + } + } + + internal string[] IfNoneMatch + { + get + { + if (!((_flag0 & 0x800000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfNoneMatch); + if (nativeValue != null) + { + _IfNoneMatch = new string[] { nativeValue }; + } + _flag0 |= 0x800000u; + } + return _IfNoneMatch; + } + set + { + _flag0 |= 0x800000u; + _IfNoneMatch = value; + } + } + + internal string[] IfRange + { + get + { + if (!((_flag0 & 0x1000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfRange); + if (nativeValue != null) + { + _IfRange = new string[] { nativeValue }; + } + _flag0 |= 0x1000000u; + } + return _IfRange; + } + set + { + _flag0 |= 0x1000000u; + _IfRange = value; + } + } + + internal string[] IfUnmodifiedSince + { + get + { + if (!((_flag0 & 0x2000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.IfUnmodifiedSince); + if (nativeValue != null) + { + _IfUnmodifiedSince = new string[] { nativeValue }; + } + _flag0 |= 0x2000000u; + } + return _IfUnmodifiedSince; + } + set + { + _flag0 |= 0x2000000u; + _IfUnmodifiedSince = value; + } + } + + internal string[] KeepAlive + { + get + { + if (!((_flag0 & 0x4000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.KeepAlive); + if (nativeValue != null) + { + _KeepAlive = new string[] { nativeValue }; + } + _flag0 |= 0x4000000u; + } + return _KeepAlive; + } + set + { + _flag0 |= 0x4000000u; + _KeepAlive = value; + } + } + + internal string[] LastModified + { + get + { + if (!((_flag0 & 0x8000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.LastModified); + if (nativeValue != null) + { + _LastModified = new string[] { nativeValue }; + } + _flag0 |= 0x8000000u; + } + return _LastModified; + } + set + { + _flag0 |= 0x8000000u; + _LastModified = value; + } + } + + internal string[] MaxForwards + { + get + { + if (!((_flag0 & 0x10000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.MaxForwards); + if (nativeValue != null) + { + _MaxForwards = new string[] { nativeValue }; + } + _flag0 |= 0x10000000u; + } + return _MaxForwards; + } + set + { + _flag0 |= 0x10000000u; + _MaxForwards = value; + } + } + + internal string[] Pragma + { + get + { + if (!((_flag0 & 0x20000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Pragma); + if (nativeValue != null) + { + _Pragma = new string[] { nativeValue }; + } + _flag0 |= 0x20000000u; + } + return _Pragma; + } + set + { + _flag0 |= 0x20000000u; + _Pragma = value; + } + } + + internal string[] ProxyAuthorization + { + get + { + if (!((_flag0 & 0x40000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.ProxyAuthorization); + if (nativeValue != null) + { + _ProxyAuthorization = new string[] { nativeValue }; + } + _flag0 |= 0x40000000u; + } + return _ProxyAuthorization; + } + set + { + _flag0 |= 0x40000000u; + _ProxyAuthorization = value; + } + } + + internal string[] Range + { + get + { + if (!((_flag0 & 0x80000000u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Range); + if (nativeValue != null) + { + _Range = new string[] { nativeValue }; + } + _flag0 |= 0x80000000u; + } + return _Range; + } + set + { + _flag0 |= 0x80000000u; + _Range = value; + } + } + + internal string[] Referer + { + get + { + if (!((_flag1 & 0x1u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Referer); + if (nativeValue != null) + { + _Referer = new string[] { nativeValue }; + } + _flag1 |= 0x1u; + } + return _Referer; + } + set + { + _flag1 |= 0x1u; + _Referer = value; + } + } + + internal string[] Te + { + get + { + if (!((_flag1 & 0x2u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Te); + if (nativeValue != null) + { + _Te = new string[] { nativeValue }; + } + _flag1 |= 0x2u; + } + return _Te; + } + set + { + _flag1 |= 0x2u; + _Te = value; + } + } + + internal string[] Trailer + { + get + { + if (!((_flag1 & 0x4u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Trailer); + if (nativeValue != null) + { + _Trailer = new string[] { nativeValue }; + } + _flag1 |= 0x4u; + } + return _Trailer; + } + set + { + _flag1 |= 0x4u; + _Trailer = value; + } + } + + internal string[] TransferEncoding + { + get + { + if (!((_flag1 & 0x8u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.TransferEncoding); + if (nativeValue != null) + { + _TransferEncoding = new string[] { nativeValue }; + } + _flag1 |= 0x8u; + } + return _TransferEncoding; + } + set + { + _flag1 |= 0x8u; + _TransferEncoding = value; + } + } + + internal string[] Translate + { + get + { + if (!((_flag1 & 0x10u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Translate); + if (nativeValue != null) + { + _Translate = new string[] { nativeValue }; + } + _flag1 |= 0x10u; + } + return _Translate; + } + set + { + _flag1 |= 0x10u; + _Translate = value; + } + } + + internal string[] Upgrade + { + get + { + if (!((_flag1 & 0x20u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Upgrade); + if (nativeValue != null) + { + _Upgrade = new string[] { nativeValue }; + } + _flag1 |= 0x20u; + } + return _Upgrade; + } + set + { + _flag1 |= 0x20u; + _Upgrade = value; + } + } + + internal string[] UserAgent + { + get + { + if (!((_flag1 & 0x40u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.UserAgent); + if (nativeValue != null) + { + _UserAgent = new string[] { nativeValue }; + } + _flag1 |= 0x40u; + } + return _UserAgent; + } + set + { + _flag1 |= 0x40u; + _UserAgent = value; + } + } + + internal string[] Via + { + get + { + if (!((_flag1 & 0x80u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Via); + if (nativeValue != null) + { + _Via = new string[] { nativeValue }; + } + _flag1 |= 0x80u; + } + return _Via; + } + set + { + _flag1 |= 0x80u; + _Via = value; + } + } + + internal string[] Warning + { + get + { + if (!((_flag1 & 0x100u) != 0)) + { + string nativeValue = GetKnownHeader(HttpSysRequestHeader.Warning); + if (nativeValue != null) + { + _Warning = new string[] { nativeValue }; + } + _flag1 |= 0x100u; + } + return _Warning; + } + set + { + _flag1 |= 0x100u; + _Warning = value; + } + } + + private bool PropertiesContainsKey(string key) + { + switch (key.Length) + { + case 2: + if (string.Equals(key, "Te", StringComparison.OrdinalIgnoreCase)) + { + return Te != null; + } + break; + case 3: + if (string.Equals(key, "Via", StringComparison.OrdinalIgnoreCase)) + { + return Via != null; + } + break; + case 4: + if (string.Equals(key, "Date", StringComparison.OrdinalIgnoreCase)) + { + return Date != null; + } + if (string.Equals(key, "From", StringComparison.OrdinalIgnoreCase)) + { + return From != null; + } + if (string.Equals(key, "Host", StringComparison.OrdinalIgnoreCase)) + { + return Host != null; + } + break; + case 5: + if (string.Equals(key, "Allow", StringComparison.OrdinalIgnoreCase)) + { + return Allow != null; + } + if (string.Equals(key, "Range", StringComparison.OrdinalIgnoreCase)) + { + return Range != null; + } + break; + case 6: + if (string.Equals(key, "Accept", StringComparison.OrdinalIgnoreCase)) + { + return Accept != null; + } + if (string.Equals(key, "Cookie", StringComparison.OrdinalIgnoreCase)) + { + return Cookie != null; + } + if (string.Equals(key, "Expect", StringComparison.OrdinalIgnoreCase)) + { + return Expect != null; + } + if (string.Equals(key, "Pragma", StringComparison.OrdinalIgnoreCase)) + { + return Pragma != null; + } + break; + case 7: + if (string.Equals(key, "Expires", StringComparison.OrdinalIgnoreCase)) + { + return Expires != null; + } + if (string.Equals(key, "Referer", StringComparison.OrdinalIgnoreCase)) + { + return Referer != null; + } + if (string.Equals(key, "Trailer", StringComparison.OrdinalIgnoreCase)) + { + return Trailer != null; + } + if (string.Equals(key, "Upgrade", StringComparison.OrdinalIgnoreCase)) + { + return Upgrade != null; + } + if (string.Equals(key, "Warning", StringComparison.OrdinalIgnoreCase)) + { + return Warning != null; + } + break; + case 8: + if (string.Equals(key, "If-Match", StringComparison.OrdinalIgnoreCase)) + { + return IfMatch != null; + } + if (string.Equals(key, "If-Range", StringComparison.OrdinalIgnoreCase)) + { + return IfRange != null; + } + break; + case 9: + if (string.Equals(key, "Translate", StringComparison.OrdinalIgnoreCase)) + { + return Translate != null; + } + break; + case 10: + if (string.Equals(key, "Connection", StringComparison.OrdinalIgnoreCase)) + { + return Connection != null; + } + if (string.Equals(key, "Keep-Alive", StringComparison.OrdinalIgnoreCase)) + { + return KeepAlive != null; + } + if (string.Equals(key, "User-Agent", StringComparison.OrdinalIgnoreCase)) + { + return UserAgent != null; + } + break; + case 11: + if (string.Equals(key, "Content-Md5", StringComparison.OrdinalIgnoreCase)) + { + return ContentMd5 != null; + } + break; + case 12: + if (string.Equals(key, "Content-Type", StringComparison.OrdinalIgnoreCase)) + { + return ContentType != null; + } + if (string.Equals(key, "Max-Forwards", StringComparison.OrdinalIgnoreCase)) + { + return MaxForwards != null; + } + break; + case 13: + if (string.Equals(key, "Authorization", StringComparison.OrdinalIgnoreCase)) + { + return Authorization != null; + } + if (string.Equals(key, "Cache-Control", StringComparison.OrdinalIgnoreCase)) + { + return CacheControl != null; + } + if (string.Equals(key, "Content-Range", StringComparison.OrdinalIgnoreCase)) + { + return ContentRange != null; + } + if (string.Equals(key, "If-None-Match", StringComparison.OrdinalIgnoreCase)) + { + return IfNoneMatch != null; + } + if (string.Equals(key, "Last-Modified", StringComparison.OrdinalIgnoreCase)) + { + return LastModified != null; + } + break; + case 14: + if (string.Equals(key, "Accept-Charset", StringComparison.OrdinalIgnoreCase)) + { + return AcceptCharset != null; + } + if (string.Equals(key, "Content-Length", StringComparison.OrdinalIgnoreCase)) + { + return ContentLength != null; + } + break; + case 15: + if (string.Equals(key, "Accept-Encoding", StringComparison.OrdinalIgnoreCase)) + { + return AcceptEncoding != null; + } + if (string.Equals(key, "Accept-Language", StringComparison.OrdinalIgnoreCase)) + { + return AcceptLanguage != null; + } + break; + case 16: + if (string.Equals(key, "Content-Encoding", StringComparison.OrdinalIgnoreCase)) + { + return ContentEncoding != null; + } + if (string.Equals(key, "Content-Language", StringComparison.OrdinalIgnoreCase)) + { + return ContentLanguage != null; + } + if (string.Equals(key, "Content-Location", StringComparison.OrdinalIgnoreCase)) + { + return ContentLocation != null; + } + break; + case 17: + if (string.Equals(key, "If-Modified-Since", StringComparison.OrdinalIgnoreCase)) + { + return IfModifiedSince != null; + } + if (string.Equals(key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)) + { + return TransferEncoding != null; + } + break; + case 19: + if (string.Equals(key, "If-Unmodified-Since", StringComparison.OrdinalIgnoreCase)) + { + return IfUnmodifiedSince != null; + } + if (string.Equals(key, "Proxy-Authorization", StringComparison.OrdinalIgnoreCase)) + { + return ProxyAuthorization != null; + } + break; + } + return false; + } + + private bool PropertiesTryGetValue(string key, out string[] value) + { + switch (key.Length) + { + case 2: + if (string.Equals(key, "Te", StringComparison.OrdinalIgnoreCase)) + { + value = Te; + return value != null; + } + break; + case 3: + if (string.Equals(key, "Via", StringComparison.OrdinalIgnoreCase)) + { + value = Via; + return value != null; + } + break; + case 4: + if (string.Equals(key, "Date", StringComparison.OrdinalIgnoreCase)) + { + value = Date; + return value != null; + } + if (string.Equals(key, "From", StringComparison.OrdinalIgnoreCase)) + { + value = From; + return value != null; + } + if (string.Equals(key, "Host", StringComparison.OrdinalIgnoreCase)) + { + value = Host; + return value != null; + } + break; + case 5: + if (string.Equals(key, "Allow", StringComparison.OrdinalIgnoreCase)) + { + value = Allow; + return value != null; + } + if (string.Equals(key, "Range", StringComparison.OrdinalIgnoreCase)) + { + value = Range; + return value != null; + } + break; + case 6: + if (string.Equals(key, "Accept", StringComparison.OrdinalIgnoreCase)) + { + value = Accept; + return value != null; + } + if (string.Equals(key, "Cookie", StringComparison.OrdinalIgnoreCase)) + { + value = Cookie; + return value != null; + } + if (string.Equals(key, "Expect", StringComparison.OrdinalIgnoreCase)) + { + value = Expect; + return value != null; + } + if (string.Equals(key, "Pragma", StringComparison.OrdinalIgnoreCase)) + { + value = Pragma; + return value != null; + } + break; + case 7: + if (string.Equals(key, "Expires", StringComparison.OrdinalIgnoreCase)) + { + value = Expires; + return value != null; + } + if (string.Equals(key, "Referer", StringComparison.OrdinalIgnoreCase)) + { + value = Referer; + return value != null; + } + if (string.Equals(key, "Trailer", StringComparison.OrdinalIgnoreCase)) + { + value = Trailer; + return value != null; + } + if (string.Equals(key, "Upgrade", StringComparison.OrdinalIgnoreCase)) + { + value = Upgrade; + return value != null; + } + if (string.Equals(key, "Warning", StringComparison.OrdinalIgnoreCase)) + { + value = Warning; + return value != null; + } + break; + case 8: + if (string.Equals(key, "If-Match", StringComparison.OrdinalIgnoreCase)) + { + value = IfMatch; + return value != null; + } + if (string.Equals(key, "If-Range", StringComparison.OrdinalIgnoreCase)) + { + value = IfRange; + return value != null; + } + break; + case 9: + if (string.Equals(key, "Translate", StringComparison.OrdinalIgnoreCase)) + { + value = Translate; + return value != null; + } + break; + case 10: + if (string.Equals(key, "Connection", StringComparison.OrdinalIgnoreCase)) + { + value = Connection; + return value != null; + } + if (string.Equals(key, "Keep-Alive", StringComparison.OrdinalIgnoreCase)) + { + value = KeepAlive; + return value != null; + } + if (string.Equals(key, "User-Agent", StringComparison.OrdinalIgnoreCase)) + { + value = UserAgent; + return value != null; + } + break; + case 11: + if (string.Equals(key, "Content-Md5", StringComparison.OrdinalIgnoreCase)) + { + value = ContentMd5; + return value != null; + } + break; + case 12: + if (string.Equals(key, "Content-Type", StringComparison.OrdinalIgnoreCase)) + { + value = ContentType; + return value != null; + } + if (string.Equals(key, "Max-Forwards", StringComparison.OrdinalIgnoreCase)) + { + value = MaxForwards; + return value != null; + } + break; + case 13: + if (string.Equals(key, "Authorization", StringComparison.OrdinalIgnoreCase)) + { + value = Authorization; + return value != null; + } + if (string.Equals(key, "Cache-Control", StringComparison.OrdinalIgnoreCase)) + { + value = CacheControl; + return value != null; + } + if (string.Equals(key, "Content-Range", StringComparison.OrdinalIgnoreCase)) + { + value = ContentRange; + return value != null; + } + if (string.Equals(key, "If-None-Match", StringComparison.OrdinalIgnoreCase)) + { + value = IfNoneMatch; + return value != null; + } + if (string.Equals(key, "Last-Modified", StringComparison.OrdinalIgnoreCase)) + { + value = LastModified; + return value != null; + } + break; + case 14: + if (string.Equals(key, "Accept-Charset", StringComparison.OrdinalIgnoreCase)) + { + value = AcceptCharset; + return value != null; + } + if (string.Equals(key, "Content-Length", StringComparison.OrdinalIgnoreCase)) + { + value = ContentLength; + return value != null; + } + break; + case 15: + if (string.Equals(key, "Accept-Encoding", StringComparison.OrdinalIgnoreCase)) + { + value = AcceptEncoding; + return value != null; + } + if (string.Equals(key, "Accept-Language", StringComparison.OrdinalIgnoreCase)) + { + value = AcceptLanguage; + return value != null; + } + break; + case 16: + if (string.Equals(key, "Content-Encoding", StringComparison.OrdinalIgnoreCase)) + { + value = ContentEncoding; + return value != null; + } + if (string.Equals(key, "Content-Language", StringComparison.OrdinalIgnoreCase)) + { + value = ContentLanguage; + return value != null; + } + if (string.Equals(key, "Content-Location", StringComparison.OrdinalIgnoreCase)) + { + value = ContentLocation; + return value != null; + } + break; + case 17: + if (string.Equals(key, "If-Modified-Since", StringComparison.OrdinalIgnoreCase)) + { + value = IfModifiedSince; + return value != null; + } + if (string.Equals(key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)) + { + value = TransferEncoding; + return value != null; + } + break; + case 19: + if (string.Equals(key, "If-Unmodified-Since", StringComparison.OrdinalIgnoreCase)) + { + value = IfUnmodifiedSince; + return value != null; + } + if (string.Equals(key, "Proxy-Authorization", StringComparison.OrdinalIgnoreCase)) + { + value = ProxyAuthorization; + return value != null; + } + break; + } + value = null; + return false; + } + + private bool PropertiesTrySetValue(string key, string[] value) + { + switch (key.Length) + { + case 2: + if (string.Equals(key, "Te", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x2u; + Te = value; + return true; + } + break; + case 3: + if (string.Equals(key, "Via", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x80u; + Via = value; + return true; + } + break; + case 4: + if (string.Equals(key, "Date", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x10000u; + Date = value; + return true; + } + if (string.Equals(key, "From", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x80000u; + From = value; + return true; + } + if (string.Equals(key, "Host", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x100000u; + Host = value; + return true; + } + break; + case 5: + if (string.Equals(key, "Allow", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x10u; + Allow = value; + return true; + } + if (string.Equals(key, "Range", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x80000000u; + Range = value; + return true; + } + break; + case 6: + if (string.Equals(key, "Accept", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x1u; + Accept = value; + return true; + } + if (string.Equals(key, "Cookie", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x8000u; + Cookie = value; + return true; + } + if (string.Equals(key, "Expect", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x20000u; + Expect = value; + return true; + } + if (string.Equals(key, "Pragma", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x20000000u; + Pragma = value; + return true; + } + break; + case 7: + if (string.Equals(key, "Expires", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x40000u; + Expires = value; + return true; + } + if (string.Equals(key, "Referer", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x1u; + Referer = value; + return true; + } + if (string.Equals(key, "Trailer", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x4u; + Trailer = value; + return true; + } + if (string.Equals(key, "Upgrade", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x20u; + Upgrade = value; + return true; + } + if (string.Equals(key, "Warning", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x100u; + Warning = value; + return true; + } + break; + case 8: + if (string.Equals(key, "If-Match", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x200000u; + IfMatch = value; + return true; + } + if (string.Equals(key, "If-Range", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x1000000u; + IfRange = value; + return true; + } + break; + case 9: + if (string.Equals(key, "Translate", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x10u; + Translate = value; + return true; + } + break; + case 10: + if (string.Equals(key, "Connection", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x80u; + Connection = value; + return true; + } + if (string.Equals(key, "Keep-Alive", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x4000000u; + KeepAlive = value; + return true; + } + if (string.Equals(key, "User-Agent", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x40u; + UserAgent = value; + return true; + } + break; + case 11: + if (string.Equals(key, "Content-Md5", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x1000u; + ContentMd5 = value; + return true; + } + break; + case 12: + if (string.Equals(key, "Content-Type", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x4000u; + ContentType = value; + return true; + } + if (string.Equals(key, "Max-Forwards", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x10000000u; + MaxForwards = value; + return true; + } + break; + case 13: + if (string.Equals(key, "Authorization", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x20u; + Authorization = value; + return true; + } + if (string.Equals(key, "Cache-Control", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x40u; + CacheControl = value; + return true; + } + if (string.Equals(key, "Content-Range", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x2000u; + ContentRange = value; + return true; + } + if (string.Equals(key, "If-None-Match", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x800000u; + IfNoneMatch = value; + return true; + } + if (string.Equals(key, "Last-Modified", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x8000000u; + LastModified = value; + return true; + } + break; + case 14: + if (string.Equals(key, "Accept-Charset", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x2u; + AcceptCharset = value; + return true; + } + if (string.Equals(key, "Content-Length", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x400u; + ContentLength = value; + return true; + } + break; + case 15: + if (string.Equals(key, "Accept-Encoding", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x4u; + AcceptEncoding = value; + return true; + } + if (string.Equals(key, "Accept-Language", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x8u; + AcceptLanguage = value; + return true; + } + break; + case 16: + if (string.Equals(key, "Content-Encoding", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x100u; + ContentEncoding = value; + return true; + } + if (string.Equals(key, "Content-Language", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x200u; + ContentLanguage = value; + return true; + } + if (string.Equals(key, "Content-Location", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x800u; + ContentLocation = value; + return true; + } + break; + case 17: + if (string.Equals(key, "If-Modified-Since", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x400000u; + IfModifiedSince = value; + return true; + } + if (string.Equals(key, "Transfer-Encoding", StringComparison.OrdinalIgnoreCase)) + { + _flag1 |= 0x8u; + TransferEncoding = value; + return true; + } + break; + case 19: + if (string.Equals(key, "If-Unmodified-Since", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x2000000u; + IfUnmodifiedSince = value; + return true; + } + if (string.Equals(key, "Proxy-Authorization", StringComparison.OrdinalIgnoreCase)) + { + _flag0 |= 0x40000000u; + ProxyAuthorization = value; + return true; + } + break; + } + return false; + } + + private bool PropertiesTryRemove(string key) + { + switch (key.Length) + { + case 2: + if (_Te != null + && string.Equals(key, "Te", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x2u) != 0); + Te = null; + return wasSet; + } + break; + case 3: + if (_Via != null + && string.Equals(key, "Via", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x80u) != 0); + Via = null; + return wasSet; + } + break; + case 4: + if (_Date != null + && string.Equals(key, "Date", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x10000u) != 0); + Date = null; + return wasSet; + } + if (_From != null + && string.Equals(key, "From", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x80000u) != 0); + From = null; + return wasSet; + } + if (_Host != null + && string.Equals(key, "Host", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x100000u) != 0); + Host = null; + return wasSet; + } + break; + case 5: + if (_Allow != null + && string.Equals(key, "Allow", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x10u) != 0); + Allow = null; + return wasSet; + } + if (_Range != null + && string.Equals(key, "Range", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x80000000u) != 0); + Range = null; + return wasSet; + } + break; + case 6: + if (_Accept != null + && string.Equals(key, "Accept", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x1u) != 0); + Accept = null; + return wasSet; + } + if (_Cookie != null + && string.Equals(key, "Cookie", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x8000u) != 0); + Cookie = null; + return wasSet; + } + if (_Expect != null + && string.Equals(key, "Expect", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x20000u) != 0); + Expect = null; + return wasSet; + } + if (_Pragma != null + && string.Equals(key, "Pragma", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x20000000u) != 0); + Pragma = null; + return wasSet; + } + break; + case 7: + if (_Expires != null + && string.Equals(key, "Expires", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x40000u) != 0); + Expires = null; + return wasSet; + } + if (_Referer != null + && string.Equals(key, "Referer", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x1u) != 0); + Referer = null; + return wasSet; + } + if (_Trailer != null + && string.Equals(key, "Trailer", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x4u) != 0); + Trailer = null; + return wasSet; + } + if (_Upgrade != null + && string.Equals(key, "Upgrade", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x20u) != 0); + Upgrade = null; + return wasSet; + } + if (_Warning != null + && string.Equals(key, "Warning", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x100u) != 0); + Warning = null; + return wasSet; + } + break; + case 8: + if (_IfMatch != null + && string.Equals(key, "If-Match", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x200000u) != 0); + IfMatch = null; + return wasSet; + } + if (_IfRange != null + && string.Equals(key, "If-Range", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x1000000u) != 0); + IfRange = null; + return wasSet; + } + break; + case 9: + if (_Translate != null + && string.Equals(key, "Translate", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x10u) != 0); + Translate = null; + return wasSet; + } + break; + case 10: + if (_Connection != null + && string.Equals(key, "Connection", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x80u) != 0); + Connection = null; + return wasSet; + } + if (_KeepAlive != null + && string.Equals(key, "Keep-Alive", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x4000000u) != 0); + KeepAlive = null; + return wasSet; + } + if (_UserAgent != null + && string.Equals(key, "User-Agent", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x40u) != 0); + UserAgent = null; + return wasSet; + } + break; + case 11: + if (_ContentMd5 != null + && string.Equals(key, "Content-Md5", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x1000u) != 0); + ContentMd5 = null; + return wasSet; + } + break; + case 12: + if (_ContentType != null + && string.Equals(key, "Content-Type", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x4000u) != 0); + ContentType = null; + return wasSet; + } + if (_MaxForwards != null + && string.Equals(key, "Max-Forwards", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x10000000u) != 0); + MaxForwards = null; + return wasSet; + } + break; + case 13: + if (_Authorization != null + && string.Equals(key, "Authorization", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x20u) != 0); + Authorization = null; + return wasSet; + } + if (_CacheControl != null + && string.Equals(key, "Cache-Control", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x40u) != 0); + CacheControl = null; + return wasSet; + } + if (_ContentRange != null + && string.Equals(key, "Content-Range", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x2000u) != 0); + ContentRange = null; + return wasSet; + } + if (_IfNoneMatch != null + && string.Equals(key, "If-None-Match", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x800000u) != 0); + IfNoneMatch = null; + return wasSet; + } + if (_LastModified != null + && string.Equals(key, "Last-Modified", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x8000000u) != 0); + LastModified = null; + return wasSet; + } + break; + case 14: + if (_AcceptCharset != null + && string.Equals(key, "Accept-Charset", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x2u) != 0); + AcceptCharset = null; + return wasSet; + } + if (_ContentLength != null + && string.Equals(key, "Content-Length", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x400u) != 0); + ContentLength = null; + return wasSet; + } + break; + case 15: + if (_AcceptEncoding != null + && string.Equals(key, "Accept-Encoding", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x4u) != 0); + AcceptEncoding = null; + return wasSet; + } + if (_AcceptLanguage != null + && string.Equals(key, "Accept-Language", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x8u) != 0); + AcceptLanguage = null; + return wasSet; + } + break; + case 16: + if (_ContentEncoding != null + && string.Equals(key, "Content-Encoding", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x100u) != 0); + ContentEncoding = null; + return wasSet; + } + if (_ContentLanguage != null + && string.Equals(key, "Content-Language", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x200u) != 0); + ContentLanguage = null; + return wasSet; + } + if (_ContentLocation != null + && string.Equals(key, "Content-Location", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x800u) != 0); + ContentLocation = null; + return wasSet; + } + break; + case 17: + if (_IfModifiedSince != null + && string.Equals(key, "If-Modified-Since", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x400000u) != 0); + IfModifiedSince = null; + return wasSet; + } + if (_TransferEncoding != null + && string.Equals(key, "Transfer-Encoding", StringComparison.Ordinal)) + { + bool wasSet = ((_flag1 & 0x8u) != 0); + TransferEncoding = null; + return wasSet; + } + break; + case 19: + if (_IfUnmodifiedSince != null + && string.Equals(key, "If-Unmodified-Since", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x2000000u) != 0); + IfUnmodifiedSince = null; + return wasSet; + } + if (_ProxyAuthorization != null + && string.Equals(key, "Proxy-Authorization", StringComparison.Ordinal)) + { + bool wasSet = ((_flag0 & 0x40000000u) != 0); + ProxyAuthorization = null; + return wasSet; + } + break; + } + return false; + } + + private IEnumerable PropertiesKeys() + { + if (Accept != null) + { + yield return "Accept"; + } + if (AcceptCharset != null) + { + yield return "Accept-Charset"; + } + if (AcceptEncoding != null) + { + yield return "Accept-Encoding"; + } + if (AcceptLanguage != null) + { + yield return "Accept-Language"; + } + if (Allow != null) + { + yield return "Allow"; + } + if (Authorization != null) + { + yield return "Authorization"; + } + if (CacheControl != null) + { + yield return "Cache-Control"; + } + if (Connection != null) + { + yield return "Connection"; + } + if (ContentEncoding != null) + { + yield return "Content-Encoding"; + } + if (ContentLanguage != null) + { + yield return "Content-Language"; + } + if (ContentLength != null) + { + yield return "Content-Length"; + } + if (ContentLocation != null) + { + yield return "Content-Location"; + } + if (ContentMd5 != null) + { + yield return "Content-Md5"; + } + if (ContentRange != null) + { + yield return "Content-Range"; + } + if (ContentType != null) + { + yield return "Content-Type"; + } + if (Cookie != null) + { + yield return "Cookie"; + } + if (Date != null) + { + yield return "Date"; + } + if (Expect != null) + { + yield return "Expect"; + } + if (Expires != null) + { + yield return "Expires"; + } + if (From != null) + { + yield return "From"; + } + if (Host != null) + { + yield return "Host"; + } + if (IfMatch != null) + { + yield return "If-Match"; + } + if (IfModifiedSince != null) + { + yield return "If-Modified-Since"; + } + if (IfNoneMatch != null) + { + yield return "If-None-Match"; + } + if (IfRange != null) + { + yield return "If-Range"; + } + if (IfUnmodifiedSince != null) + { + yield return "If-Unmodified-Since"; + } + if (KeepAlive != null) + { + yield return "Keep-Alive"; + } + if (LastModified != null) + { + yield return "Last-Modified"; + } + if (MaxForwards != null) + { + yield return "Max-Forwards"; + } + if (Pragma != null) + { + yield return "Pragma"; + } + if (ProxyAuthorization != null) + { + yield return "Proxy-Authorization"; + } + if (Range != null) + { + yield return "Range"; + } + if (Referer != null) + { + yield return "Referer"; + } + if (Te != null) + { + yield return "Te"; + } + if (Trailer != null) + { + yield return "Trailer"; + } + if (TransferEncoding != null) + { + yield return "Transfer-Encoding"; + } + if (Translate != null) + { + yield return "Translate"; + } + if (Upgrade != null) + { + yield return "Upgrade"; + } + if (UserAgent != null) + { + yield return "User-Agent"; + } + if (Via != null) + { + yield return "Via"; + } + if (Warning != null) + { + yield return "Warning"; + } + } + + private IEnumerable PropertiesValues() + { + if (Accept != null) + { + yield return Accept; + } + if (AcceptCharset != null) + { + yield return AcceptCharset; + } + if (AcceptEncoding != null) + { + yield return AcceptEncoding; + } + if (AcceptLanguage != null) + { + yield return AcceptLanguage; + } + if (Allow != null) + { + yield return Allow; + } + if (Authorization != null) + { + yield return Authorization; + } + if (CacheControl != null) + { + yield return CacheControl; + } + if (Connection != null) + { + yield return Connection; + } + if (ContentEncoding != null) + { + yield return ContentEncoding; + } + if (ContentLanguage != null) + { + yield return ContentLanguage; + } + if (ContentLength != null) + { + yield return ContentLength; + } + if (ContentLocation != null) + { + yield return ContentLocation; + } + if (ContentMd5 != null) + { + yield return ContentMd5; + } + if (ContentRange != null) + { + yield return ContentRange; + } + if (ContentType != null) + { + yield return ContentType; + } + if (Cookie != null) + { + yield return Cookie; + } + if (Date != null) + { + yield return Date; + } + if (Expect != null) + { + yield return Expect; + } + if (Expires != null) + { + yield return Expires; + } + if (From != null) + { + yield return From; + } + if (Host != null) + { + yield return Host; + } + if (IfMatch != null) + { + yield return IfMatch; + } + if (IfModifiedSince != null) + { + yield return IfModifiedSince; + } + if (IfNoneMatch != null) + { + yield return IfNoneMatch; + } + if (IfRange != null) + { + yield return IfRange; + } + if (IfUnmodifiedSince != null) + { + yield return IfUnmodifiedSince; + } + if (KeepAlive != null) + { + yield return KeepAlive; + } + if (LastModified != null) + { + yield return LastModified; + } + if (MaxForwards != null) + { + yield return MaxForwards; + } + if (Pragma != null) + { + yield return Pragma; + } + if (ProxyAuthorization != null) + { + yield return ProxyAuthorization; + } + if (Range != null) + { + yield return Range; + } + if (Referer != null) + { + yield return Referer; + } + if (Te != null) + { + yield return Te; + } + if (Trailer != null) + { + yield return Trailer; + } + if (TransferEncoding != null) + { + yield return TransferEncoding; + } + if (Translate != null) + { + yield return Translate; + } + if (Upgrade != null) + { + yield return Upgrade; + } + if (UserAgent != null) + { + yield return UserAgent; + } + if (Via != null) + { + yield return Via; + } + if (Warning != null) + { + yield return Warning; + } + } + + private IEnumerable> PropertiesEnumerable() + { + if (Accept != null) + { + yield return new KeyValuePair("Accept", Accept); + } + if (AcceptCharset != null) + { + yield return new KeyValuePair("Accept-Charset", AcceptCharset); + } + if (AcceptEncoding != null) + { + yield return new KeyValuePair("Accept-Encoding", AcceptEncoding); + } + if (AcceptLanguage != null) + { + yield return new KeyValuePair("Accept-Language", AcceptLanguage); + } + if (Allow != null) + { + yield return new KeyValuePair("Allow", Allow); + } + if (Authorization != null) + { + yield return new KeyValuePair("Authorization", Authorization); + } + if (CacheControl != null) + { + yield return new KeyValuePair("Cache-Control", CacheControl); + } + if (Connection != null) + { + yield return new KeyValuePair("Connection", Connection); + } + if (ContentEncoding != null) + { + yield return new KeyValuePair("Content-Encoding", ContentEncoding); + } + if (ContentLanguage != null) + { + yield return new KeyValuePair("Content-Language", ContentLanguage); + } + if (ContentLength != null) + { + yield return new KeyValuePair("Content-Length", ContentLength); + } + if (ContentLocation != null) + { + yield return new KeyValuePair("Content-Location", ContentLocation); + } + if (ContentMd5 != null) + { + yield return new KeyValuePair("Content-Md5", ContentMd5); + } + if (ContentRange != null) + { + yield return new KeyValuePair("Content-Range", ContentRange); + } + if (ContentType != null) + { + yield return new KeyValuePair("Content-Type", ContentType); + } + if (Cookie != null) + { + yield return new KeyValuePair("Cookie", Cookie); + } + if (Date != null) + { + yield return new KeyValuePair("Date", Date); + } + if (Expect != null) + { + yield return new KeyValuePair("Expect", Expect); + } + if (Expires != null) + { + yield return new KeyValuePair("Expires", Expires); + } + if (From != null) + { + yield return new KeyValuePair("From", From); + } + if (Host != null) + { + yield return new KeyValuePair("Host", Host); + } + if (IfMatch != null) + { + yield return new KeyValuePair("If-Match", IfMatch); + } + if (IfModifiedSince != null) + { + yield return new KeyValuePair("If-Modified-Since", IfModifiedSince); + } + if (IfNoneMatch != null) + { + yield return new KeyValuePair("If-None-Match", IfNoneMatch); + } + if (IfRange != null) + { + yield return new KeyValuePair("If-Range", IfRange); + } + if (IfUnmodifiedSince != null) + { + yield return new KeyValuePair("If-Unmodified-Since", IfUnmodifiedSince); + } + if (KeepAlive != null) + { + yield return new KeyValuePair("Keep-Alive", KeepAlive); + } + if (LastModified != null) + { + yield return new KeyValuePair("Last-Modified", LastModified); + } + if (MaxForwards != null) + { + yield return new KeyValuePair("Max-Forwards", MaxForwards); + } + if (Pragma != null) + { + yield return new KeyValuePair("Pragma", Pragma); + } + if (ProxyAuthorization != null) + { + yield return new KeyValuePair("Proxy-Authorization", ProxyAuthorization); + } + if (Range != null) + { + yield return new KeyValuePair("Range", Range); + } + if (Referer != null) + { + yield return new KeyValuePair("Referer", Referer); + } + if (Te != null) + { + yield return new KeyValuePair("Te", Te); + } + if (Trailer != null) + { + yield return new KeyValuePair("Trailer", Trailer); + } + if (TransferEncoding != null) + { + yield return new KeyValuePair("Transfer-Encoding", TransferEncoding); + } + if (Translate != null) + { + yield return new KeyValuePair("Translate", Translate); + } + if (Upgrade != null) + { + yield return new KeyValuePair("Upgrade", Upgrade); + } + if (UserAgent != null) + { + yield return new KeyValuePair("User-Agent", UserAgent); + } + if (Via != null) + { + yield return new KeyValuePair("Via", Via); + } + if (Warning != null) + { + yield return new KeyValuePair("Warning", Warning); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.tt b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.tt new file mode 100644 index 0000000000..2b339e4640 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.Generated.tt @@ -0,0 +1,218 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core.dll" #> +<#@ import namespace="System.Linq" #> +<# +var props = new[] +{ + new { Key = "Accept", Name = "Accept", ID = "HttpSysRequestHeader.Accept" }, + new { Key = "Accept-Charset", Name = "AcceptCharset", ID = "HttpSysRequestHeader.AcceptCharset" }, + new { Key = "Accept-Encoding", Name = "AcceptEncoding", ID = "HttpSysRequestHeader.AcceptEncoding" }, + new { Key = "Accept-Language", Name = "AcceptLanguage", ID = "HttpSysRequestHeader.AcceptLanguage" }, + new { Key = "Allow", Name = "Allow", ID = "HttpSysRequestHeader.Allow" }, + new { Key = "Authorization", Name = "Authorization", ID = "HttpSysRequestHeader.Authorization" }, + new { Key = "Cache-Control", Name = "CacheControl", ID = "HttpSysRequestHeader.CacheControl" }, + new { Key = "Connection", Name = "Connection", ID = "HttpSysRequestHeader.Connection" }, + new { Key = "Content-Encoding", Name = "ContentEncoding", ID = "HttpSysRequestHeader.ContentEncoding" }, + new { Key = "Content-Language", Name = "ContentLanguage", ID = "HttpSysRequestHeader.ContentLanguage" }, + new { Key = "Content-Length", Name = "ContentLength", ID = "HttpSysRequestHeader.ContentLength" }, + new { Key = "Content-Location", Name = "ContentLocation", ID = "HttpSysRequestHeader.ContentLocation" }, + new { Key = "Content-Md5", Name = "ContentMd5", ID = "HttpSysRequestHeader.ContentMd5" }, + new { Key = "Content-Range", Name = "ContentRange", ID = "HttpSysRequestHeader.ContentRange" }, + new { Key = "Content-Type", Name = "ContentType", ID = "HttpSysRequestHeader.ContentType" }, + new { Key = "Cookie", Name = "Cookie", ID = "HttpSysRequestHeader.Cookie" }, + new { Key = "Date", Name = "Date", ID = "HttpSysRequestHeader.Date" }, + new { Key = "Expect", Name = "Expect", ID = "HttpSysRequestHeader.Expect" }, + new { Key = "Expires", Name = "Expires", ID = "HttpSysRequestHeader.Expires" }, + new { Key = "From", Name = "From", ID = "HttpSysRequestHeader.From" }, + new { Key = "Host", Name = "Host", ID = "HttpSysRequestHeader.Host" }, + new { Key = "If-Match", Name = "IfMatch", ID = "HttpSysRequestHeader.IfMatch" }, + new { Key = "If-Modified-Since", Name = "IfModifiedSince", ID = "HttpSysRequestHeader.IfModifiedSince" }, + new { Key = "If-None-Match", Name = "IfNoneMatch", ID = "HttpSysRequestHeader.IfNoneMatch" }, + new { Key = "If-Range", Name = "IfRange", ID = "HttpSysRequestHeader.IfRange" }, + new { Key = "If-Unmodified-Since", Name = "IfUnmodifiedSince", ID = "HttpSysRequestHeader.IfUnmodifiedSince" }, + new { Key = "Keep-Alive", Name = "KeepAlive", ID = "HttpSysRequestHeader.KeepAlive" }, + new { Key = "Last-Modified", Name = "LastModified", ID = "HttpSysRequestHeader.LastModified" }, + new { Key = "Max-Forwards", Name = "MaxForwards", ID = "HttpSysRequestHeader.MaxForwards" }, + new { Key = "Pragma", Name = "Pragma", ID = "HttpSysRequestHeader.Pragma" }, + new { Key = "Proxy-Authorization", Name = "ProxyAuthorization", ID = "HttpSysRequestHeader.ProxyAuthorization" }, + new { Key = "Range", Name = "Range", ID = "HttpSysRequestHeader.Range" }, + new { Key = "Referer", Name = "Referer", ID = "HttpSysRequestHeader.Referer" }, + new { Key = "Te", Name = "Te", ID = "HttpSysRequestHeader.Te" }, + new { Key = "Trailer", Name = "Trailer", ID = "HttpSysRequestHeader.Trailer" }, + new { Key = "Transfer-Encoding", Name = "TransferEncoding", ID = "HttpSysRequestHeader.TransferEncoding" }, + new { Key = "Translate", Name = "Translate", ID = "HttpSysRequestHeader.Translate" }, + new { Key = "Upgrade", Name = "Upgrade", ID = "HttpSysRequestHeader.Upgrade" }, + new { Key = "User-Agent", Name = "UserAgent", ID = "HttpSysRequestHeader.UserAgent" }, + new { Key = "Via", Name = "Via", ID = "HttpSysRequestHeader.Via" }, + new { Key = "Warning", Name = "Warning", ID = "HttpSysRequestHeader.Warning" }, +}.Select((prop, Index)=>new {prop.Key, prop.Name, prop.ID, Index}); + +var lengths = props.GroupBy(prop=>prop.Key.Length).OrderBy(prop=>prop.Key); + + +Func IsRead = Index => "((_flag" + (Index / 32) + " & 0x" + (1<<(Index % 32)).ToString("x") + "u) != 0)"; +Func MarkRead = Index => "_flag" + (Index / 32) + " |= 0x" + (1<<(Index % 32)).ToString("x") + "u"; +Func Clear = Index => "_flag" + (Index / 32) + " &= ~0x" + (1<<(Index % 32)).ToString("x") + "u"; +#> +//----------------------------------------------------------------------- +// +// Copyright (c) Katana Contributors. All rights reserved. +// +//----------------------------------------------------------------------- +// + +using System; +using System.CodeDom.Compiler; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Owin.Host.WebListener +{ + [GeneratedCode("TextTemplatingFileGenerator", "")] + internal partial class RequestHeaders + { + // Tracks if individual fields have been read from native or set directly. + // Once read or set, their presence in the collection is marked by if their string[] is null or not. + private UInt32 _flag0, _flag1; + +<# foreach(var prop in props) { #> + private string[] _<#=prop.Name#>; +<# } #> + +<# foreach(var prop in props) { #> + internal string[] <#=prop.Name#> + { + get + { + if (!<#=IsRead(prop.Index)#>) + { + string nativeValue = GetKnownHeader(<#=prop.ID#>); + if (nativeValue != null) + { + _<#=prop.Name#> = new string[] { nativeValue }; + } + <#=MarkRead(prop.Index)#>; + } + return _<#=prop.Name#>; + } + set + { + <#=MarkRead(prop.Index)#>; + _<#=prop.Name#> = value; + } + } + +<# } #> + private bool PropertiesContainsKey(string key) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (string.Equals(key, "<#=prop.Key#>", StringComparison.OrdinalIgnoreCase)) + { + return <#=prop.Name#> != null; + } +<# } #> + break; +<# } #> + } + return false; + } + + private bool PropertiesTryGetValue(string key, out string[] value) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (string.Equals(key, "<#=prop.Key#>", StringComparison.OrdinalIgnoreCase)) + { + value = <#=prop.Name#>; + return value != null; + } +<# } #> + break; +<# } #> + } + value = null; + return false; + } + + private bool PropertiesTrySetValue(string key, string[] value) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (string.Equals(key, "<#=prop.Key#>", StringComparison.OrdinalIgnoreCase)) + { + <#=MarkRead(prop.Index)#>; + <#=prop.Name#> = value; + return true; + } +<# } #> + break; +<# } #> + } + return false; + } + + private bool PropertiesTryRemove(string key) + { + switch (key.Length) + { +<# foreach(var length in lengths) { #> + case <#=length.Key#>: +<# foreach(var prop in length) { #> + if (_<#=prop.Name#> != null + && string.Equals(key, "<#=prop.Key#>", StringComparison.Ordinal)) + { + bool wasSet = <#=IsRead(prop.Index)#>; + <#=prop.Name#> = null; + return wasSet; + } +<# } #> + break; +<# } #> + } + return false; + } + + private IEnumerable PropertiesKeys() + { +<# foreach(var prop in props) { #> + if (<#=prop.Name#> != null) + { + yield return "<#=prop.Key#>"; + } +<# } #> + } + + private IEnumerable PropertiesValues() + { +<# foreach(var prop in props) { #> + if (<#=prop.Name#> != null) + { + yield return <#=prop.Name#>; + } +<# } #> + } + + private IEnumerable> PropertiesEnumerable() + { +<# foreach(var prop in props) { #> + if (<#=prop.Name#> != null) + { + yield return new KeyValuePair("<#=prop.Key#>", <#=prop.Name#>); + } +<# } #> + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.cs new file mode 100644 index 0000000000..9e43ab86e5 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestHeaders.cs @@ -0,0 +1,167 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- +// Copyright 2011-2012 Katana contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Threading; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal partial class RequestHeaders : IDictionary + { + private IDictionary _extra; + private NativeRequestContext _requestMemoryBlob; + + internal RequestHeaders(NativeRequestContext requestMemoryBlob) + { + _requestMemoryBlob = requestMemoryBlob; + } + + private IDictionary Extra + { + get + { + if (_extra == null) + { + var newDict = new Dictionary(StringComparer.OrdinalIgnoreCase); + GetUnknownHeaders(newDict); + Interlocked.CompareExchange(ref _extra, newDict, null); + } + return _extra; + } + } + + string[] IDictionary.this[string key] + { + get + { + string[] value; + return PropertiesTryGetValue(key, out value) ? value : Extra[key]; + } + set + { + if (!PropertiesTrySetValue(key, value)) + { + Extra[key] = value; + } + } + } + + private string GetKnownHeader(HttpSysRequestHeader header) + { + return UnsafeNclNativeMethods.HttpApi.GetKnownHeader(_requestMemoryBlob.RequestBuffer, + _requestMemoryBlob.OriginalBlobAddress, (int)header); + } + + private void GetUnknownHeaders(IDictionary extra) + { + UnsafeNclNativeMethods.HttpApi.GetUnknownHeaders(extra, _requestMemoryBlob.RequestBuffer, + _requestMemoryBlob.OriginalBlobAddress); + } + + void IDictionary.Add(string key, string[] value) + { + if (!PropertiesTrySetValue(key, value)) + { + Extra.Add(key, value); + } + } + + bool IDictionary.ContainsKey(string key) + { + return PropertiesContainsKey(key) || Extra.ContainsKey(key); + } + + ICollection IDictionary.Keys + { + get { return PropertiesKeys().Concat(Extra.Keys).ToArray(); } + } + + bool IDictionary.Remove(string key) + { + // Although this is a mutating operation, Extra is used instead of StrongExtra, + // because if a real dictionary has not been allocated the default behavior of the + // nil dictionary is perfectly fine. + return PropertiesTryRemove(key) || Extra.Remove(key); + } + + bool IDictionary.TryGetValue(string key, out string[] value) + { + return PropertiesTryGetValue(key, out value) || Extra.TryGetValue(key, out value); + } + + ICollection IDictionary.Values + { + get { return PropertiesValues().Concat(Extra.Values).ToArray(); } + } + + void ICollection>.Add(KeyValuePair item) + { + ((IDictionary)this).Add(item.Key, item.Value); + } + + void ICollection>.Clear() + { + foreach (var key in PropertiesKeys()) + { + PropertiesTryRemove(key); + } + Extra.Clear(); + } + + bool ICollection>.Contains(KeyValuePair item) + { + object value; + return ((IDictionary)this).TryGetValue(item.Key, out value) && Object.Equals(value, item.Value); + } + + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) + { + PropertiesEnumerable().Concat(Extra).ToArray().CopyTo(array, arrayIndex); + } + + int ICollection>.Count + { + get { return PropertiesKeys().Count() + Extra.Count; } + } + + bool ICollection>.IsReadOnly + { + get { return false; } + } + + bool ICollection>.Remove(KeyValuePair item) + { + return ((IDictionary)this).Contains(item) && + ((IDictionary)this).Remove(item.Key); + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return PropertiesEnumerable().Concat(Extra).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IDictionary)this).GetEnumerator(); + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestStream.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestStream.cs new file mode 100644 index 0000000000..5c73f0d41c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestStream.cs @@ -0,0 +1,614 @@ +// ------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class RequestStream : Stream + { + private const int MaxReadSize = 0x20000; // http.sys recommends we limit reads to 128k + + private RequestContext _requestContext; + private uint _dataChunkOffset; + private int _dataChunkIndex; + private bool _closed; + + internal RequestStream(RequestContext httpContext) + { + _requestContext = httpContext; + } + + public override bool CanSeek + { + get + { + return false; + } + } + + public override bool CanWrite + { + get + { + return false; + } + } + + public override bool CanRead + { + get + { + return true; + } + } + + public override long Length + { + get + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + set + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override void Flush() + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + + public override unsafe int Read([In, Out] byte[] buffer, int offset, int size) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (size == 0 || _closed) + { + // TODO: zero sized buffer should be invalid. + return 0; + } + // TODO: Verbose log parameters + + uint dataRead = 0; + + if (_dataChunkIndex != -1) + { + dataRead = UnsafeNclNativeMethods.HttpApi.GetChunks(_requestContext.Request.RequestBuffer, _requestContext.Request.OriginalBlobAddress, ref _dataChunkIndex, ref _dataChunkOffset, buffer, offset, size); + } + + if (_dataChunkIndex == -1 && dataRead < size) + { + uint statusCode = 0; + uint extraDataRead = 0; + offset += (int)dataRead; + size -= (int)dataRead; + + // the http.sys team recommends that we limit the size to 128kb + if (size > MaxReadSize) + { + size = MaxReadSize; + } + + fixed (byte* pBuffer = buffer) + { + // issue unmanaged blocking call + + uint flags = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveRequestEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + flags, + (IntPtr)(pBuffer + offset), + (uint)size, + out extraDataRead, + SafeNativeOverlapped.Zero); + + dataRead += extraDataRead; + } + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "Read", exception); + throw exception; + } + UpdateAfterRead(statusCode, dataRead); + } + + // TODO: Verbose log dump data read + return (int)dataRead; + } + + private void UpdateAfterRead(uint statusCode, uint dataRead) + { + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF || dataRead == 0) + { + Dispose(); + } + } + +#if NET45 + public override unsafe IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#else + public unsafe IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#endif + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (size == 0 || _closed) + { + RequestStreamAsyncResult result = new RequestStreamAsyncResult(this, state, callback); + result.Complete(0); + return result; + } + // TODO: Verbose log parameters + + RequestStreamAsyncResult asyncResult = null; + + uint dataRead = 0; + if (_dataChunkIndex != -1) + { + dataRead = UnsafeNclNativeMethods.HttpApi.GetChunks(_requestContext.Request.RequestBuffer, _requestContext.Request.OriginalBlobAddress, ref _dataChunkIndex, ref _dataChunkOffset, buffer, offset, size); + if (_dataChunkIndex != -1 && dataRead == size) + { + asyncResult = new RequestStreamAsyncResult(this, state, callback, buffer, offset, 0); + asyncResult.Complete((int)dataRead); + } + } + + if (_dataChunkIndex == -1 && dataRead < size) + { + uint statusCode = 0; + offset += (int)dataRead; + size -= (int)dataRead; + + // the http.sys team recommends that we limit the size to 128kb + if (size > MaxReadSize) + { + size = MaxReadSize; + } + + asyncResult = new RequestStreamAsyncResult(this, state, callback, buffer, offset, dataRead); + uint bytesReturned; + + try + { + fixed (byte* pBuffer = buffer) + { + uint flags = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveRequestEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + flags, + asyncResult.PinnedBuffer, + (uint)size, + out bytesReturned, + asyncResult.NativeOverlapped); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "BeginRead", e); + asyncResult.Dispose(); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + asyncResult = new RequestStreamAsyncResult(this, state, callback, dataRead); + asyncResult.Complete((int)bytesReturned); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "BeginRead", exception); + throw exception; + } + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode, bytesReturned); + } + } + return asyncResult; + } + +#if NET45 + public override int EndRead(IAsyncResult asyncResult) +#else + public int EndRead(IAsyncResult asyncResult) +#endif + { + if (asyncResult == null) + { + throw new ArgumentNullException("asyncResult"); + } + RequestStreamAsyncResult castedAsyncResult = asyncResult as RequestStreamAsyncResult; + if (castedAsyncResult == null || castedAsyncResult.RequestStream != this) + { + throw new ArgumentException(Resources.Exception_WrongIAsyncResult, "asyncResult"); + } + if (castedAsyncResult.EndCalled) + { + throw new InvalidOperationException(Resources.Exception_EndCalledMultipleTimes); + } + castedAsyncResult.EndCalled = true; + // wait & then check for errors + // Throws on failure + int dataRead = castedAsyncResult.Task.Result; + // TODO: Verbose log #dataRead. + return dataRead; + } + + public override unsafe Task ReadAsync(byte[] buffer, int offset, int size, CancellationToken cancellationToken) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + if (size == 0) + { + return Task.FromResult(0); + } + // TODO: Needs full cancellation integration + cancellationToken.ThrowIfCancellationRequested(); + // TODO: Verbose log parameters + + RequestStreamAsyncResult asyncResult = null; + + uint dataRead = 0; + if (_dataChunkIndex != -1) + { + dataRead = UnsafeNclNativeMethods.HttpApi.GetChunks(_requestContext.Request.RequestBuffer, _requestContext.Request.OriginalBlobAddress, ref _dataChunkIndex, ref _dataChunkOffset, buffer, offset, size); + if (_dataChunkIndex != -1 && dataRead == size) + { + UpdateAfterRead(UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS, dataRead); + // TODO: Verbose log #dataRead + return Task.FromResult((int)dataRead); + } + } + + if (_dataChunkIndex == -1 && dataRead < size) + { + uint statusCode = 0; + offset += (int)dataRead; + size -= (int)dataRead; + + // the http.sys team recommends that we limit the size to 128kb + if (size > MaxReadSize) + { + size = MaxReadSize; + } + + asyncResult = new RequestStreamAsyncResult(this, null, null, buffer, offset, dataRead); + uint bytesReturned; + + try + { + uint flags = 0; + + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveRequestEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + flags, + asyncResult.PinnedBuffer, + (uint)size, + out bytesReturned, + asyncResult.NativeOverlapped); + } + catch (Exception e) + { + asyncResult.Dispose(); + LogHelper.LogException(_requestContext.Logger, "ReadAsync", e); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + uint totalRead = dataRead + bytesReturned; + UpdateAfterRead(statusCode, totalRead); + // TODO: Verbose log totalRead + return Task.FromResult((int)totalRead); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "ReadAsync", exception); + throw exception; + } + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.Dispose(); + uint totalRead = dataRead + bytesReturned; + UpdateAfterRead(statusCode, totalRead); + // TODO: Verbose log + return Task.FromResult((int)totalRead); + } + } + return asyncResult.Task; + } + + public override void Write(byte[] buffer, int offset, int size) + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + +#if NET45 + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#else + public IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#endif + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + +#if NET45 + public override void EndWrite(IAsyncResult asyncResult) +#else + public void EndWrite(IAsyncResult asyncResult) +#endif + { + throw new InvalidOperationException(Resources.Exception_ReadOnlyStream); + } + + protected override void Dispose(bool disposing) + { + try + { + _closed = true; + } + finally + { + base.Dispose(disposing); + } + } + + private unsafe class RequestStreamAsyncResult : IAsyncResult, IDisposable + { + private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(Callback); + + private SafeNativeOverlapped _overlapped; + private IntPtr _pinnedBuffer; + private uint _dataAlreadyRead = 0; + private TaskCompletionSource _tcs; + private RequestStream _requestStream; + private AsyncCallback _callback; + + internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback) + { + _requestStream = requestStream; + _tcs = new TaskCompletionSource(userState); + _callback = callback; + } + + internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback, uint dataAlreadyRead) + : this(requestStream, userState, callback) + { + _dataAlreadyRead = dataAlreadyRead; + } + + internal RequestStreamAsyncResult(RequestStream requestStream, object userState, AsyncCallback callback, byte[] buffer, int offset, uint dataAlreadyRead) + : this(requestStream, userState, callback) + { + _dataAlreadyRead = dataAlreadyRead; + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = this; + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, buffer)); + _pinnedBuffer = (Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset)); + } + + internal RequestStream RequestStream + { + get { return _requestStream; } + } + + internal SafeNativeOverlapped NativeOverlapped + { + get { return _overlapped; } + } + + internal IntPtr PinnedBuffer + { + get { return _pinnedBuffer; } + } + + internal uint DataAlreadyRead + { + get { return _dataAlreadyRead; } + } + + internal Task Task + { + get { return _tcs.Task; } + } + + internal bool EndCalled { get; set; } + + internal void IOCompleted(uint errorCode, uint numBytes) + { + IOCompleted(this, errorCode, numBytes); + } + + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirecting to callback")] + private static void IOCompleted(RequestStreamAsyncResult asyncResult, uint errorCode, uint numBytes) + { + try + { + if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + asyncResult.Fail(new WebListenerException((int)errorCode)); + } + else + { + // TODO: Verbose log dump data read + asyncResult.Complete((int)numBytes, errorCode); + } + } + catch (Exception e) + { + asyncResult.Fail(e); + } + } + + private static unsafe void Callback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + Overlapped callbackOverlapped = Overlapped.Unpack(nativeOverlapped); + RequestStreamAsyncResult asyncResult = callbackOverlapped.AsyncResult as RequestStreamAsyncResult; + + IOCompleted(asyncResult, errorCode, numBytes); + } + + internal void Complete(int read, uint errorCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + if (_tcs.TrySetResult(read + (int)DataAlreadyRead)) + { + RequestStream.UpdateAfterRead((uint)errorCode, (uint)(read + DataAlreadyRead)); + if (_callback != null) + { + try + { + _callback(this); + } + catch (Exception) + { + // TODO: Exception handling? This may be an IO callback thread and throwing here could crash the app. + } + } + } + } + + internal void Fail(Exception ex) + { + if (_tcs.TrySetException(ex) && _callback != null) + { + try + { + _callback(this); + } + catch (Exception) + { + // TODO: Exception handling? This may be an IO callback thread and throwing here could crash the app. + // TODO: Log + } + } + } + + [SuppressMessage("Microsoft.Usage", "CA2216:DisposableTypesShouldDeclareFinalizer", Justification = "The disposable resource referenced does have a finalizer.")] + public void Dispose() + { + Dispose(true); + } + + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + if (_overlapped != null) + { + _overlapped.Dispose(); + } + } + } + + public object AsyncState + { + get { return _tcs.Task.AsyncState; } + } + + public WaitHandle AsyncWaitHandle + { + get { return ((IAsyncResult)_tcs.Task).AsyncWaitHandle; } + } + + public bool CompletedSynchronously + { + get { return ((IAsyncResult)_tcs.Task).CompletedSynchronously; } + } + + public bool IsCompleted + { + get { return _tcs.Task.IsCompleted; } + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestUriBuilder.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestUriBuilder.cs new file mode 100644 index 0000000000..697934fc93 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/RequestUriBuilder.cs @@ -0,0 +1,569 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.Text; + +namespace Microsoft.AspNet.Server.WebListener +{ + // We don't use the cooked URL because http.sys unescapes all percent-encoded values. However, + // we also can't just use the raw Uri, since http.sys supports not only Utf-8, but also ANSI/DBCS and + // Unicode code points. System.Uri only supports Utf-8. + // The purpose of this class is to convert all ANSI, DBCS, and Unicode code points into percent encoded + // Utf-8 characters. + internal sealed class RequestUriBuilder + { + private static readonly bool UseCookedRequestUrl; + private static readonly Encoding Utf8Encoding; + private static readonly Encoding AnsiEncoding; + + private readonly string _rawUri; + private readonly string _cookedUriScheme; + private readonly string _cookedUriHost; + private readonly string _cookedUriPath; + private readonly string _cookedUriQuery; + + // This field is used to build the final request Uri string from the Uri parts passed to the ctor. + private StringBuilder _requestUriString; + + // The raw path is parsed by looping through all characters from left to right. 'rawOctets' + // is used to store consecutive percent encoded octets as actual byte values: e.g. for path /pa%C3%84th%2F/ + // rawOctets will be set to { 0xC3, 0x84 } when we reach character 't' and it will be { 0x2F } when + // we reach the final '/'. I.e. after a sequence of percent encoded octets ends, we use rawOctets as + // input to the encoding and percent encode the resulting string into UTF-8 octets. + // + // When parsing ANSI (Latin 1) encoded path '/pa%C4th/', %C4 will be added to rawOctets and when + // we reach 't', the content of rawOctets { 0xC4 } will be fed into the ANSI encoding. The resulting + // string 'Ä' will be percent encoded into UTF-8 octets and appended to requestUriString. The final + // path will be '/pa%C3%84th/', where '%C3%84' is the UTF-8 percent encoded character 'Ä'. + private List _rawOctets; + private string _rawPath; + + // Holds the final request Uri. + private Uri _requestUri; + + static RequestUriBuilder() + { + // TODO: False triggers more detailed/correct parsing, but it's rather slow. + UseCookedRequestUrl = true; // SettingsSectionInternal.Section.HttpListenerUnescapeRequestUrl; + Utf8Encoding = new UTF8Encoding(false, true); +#if NET45 + AnsiEncoding = Encoding.GetEncoding(0, new EncoderExceptionFallback(), new DecoderExceptionFallback()); +#else + AnsiEncoding = Utf8Encoding; +#endif + } + + private RequestUriBuilder(string rawUri, string cookedUriScheme, string cookedUriHost, + string cookedUriPath, string cookedUriQuery) + { + Debug.Assert(!string.IsNullOrEmpty(rawUri), "Empty raw URL."); + Debug.Assert(!string.IsNullOrEmpty(cookedUriScheme), "Empty cooked URL scheme."); + Debug.Assert(!string.IsNullOrEmpty(cookedUriHost), "Empty cooked URL host."); + Debug.Assert(!string.IsNullOrEmpty(cookedUriPath), "Empty cooked URL path."); + + this._rawUri = rawUri; + this._cookedUriScheme = cookedUriScheme; + this._cookedUriHost = cookedUriHost; + this._cookedUriPath = AddSlashToAsteriskOnlyPath(cookedUriPath); + this._cookedUriQuery = cookedUriQuery ?? string.Empty; + } + + private RequestUriBuilder(string rawUri, string cookedUriPath) + { + Debug.Assert(!string.IsNullOrEmpty(rawUri), "Empty raw URL."); + Debug.Assert(!string.IsNullOrEmpty(cookedUriPath), "Empty cooked URL path."); + + this._rawUri = rawUri; + this._cookedUriScheme = string.Empty; + this._cookedUriHost = string.Empty; + this._cookedUriPath = AddSlashToAsteriskOnlyPath(cookedUriPath); + this._cookedUriQuery = string.Empty; + } + + private enum ParsingResult + { + Success, + InvalidString, + EncodingError + } + + private enum EncodingType + { + Primary, + Secondary + } + + public static Uri GetRequestUri(string rawUri, string cookedUriScheme, string cookedUriHost, + string cookedUriPath, string cookedUriQuery) + { + RequestUriBuilder builder = new RequestUriBuilder(rawUri, + cookedUriScheme, cookedUriHost, cookedUriPath, cookedUriQuery); + + return builder.Build(); + } + + private Uri Build() + { + // if the user enabled the "use raw Uri" setting in section, we'll use the raw + // path rather than the cooked path. + if (UseCookedRequestUrl) + { + // corresponds to pre-4.0 behavior: use the cooked URI. + BuildRequestUriUsingCookedPath(); + + if (_requestUri == null) + { + BuildRequestUriUsingRawPath(); + } + } + else + { + BuildRequestUriUsingRawPath(); + + if (_requestUri == null) + { + BuildRequestUriUsingCookedPath(); + } + } + + return _requestUri; + } + + // Process only the path. + internal static string GetRequestPath(string rawUri, string cookedUriPath) + { + RequestUriBuilder builder = new RequestUriBuilder(rawUri, cookedUriPath); + + return builder.GetPath(); + } + + private string GetPath() + { + if (UseCookedRequestUrl) + { + return _cookedUriPath; + } + + // Initialize 'rawPath' only if really needed; i.e. if we build the request Uri from the raw Uri. + _rawPath = GetPath(_rawUri); + + // If HTTP.sys only parses Utf-8, we can safely use the raw path: it must be a valid Utf-8 string. + if (!HttpSysSettings.EnableNonUtf8 || string.IsNullOrEmpty(_rawPath)) + { + if (string.IsNullOrEmpty(_rawPath)) + { + _rawPath = "/"; + } + return _rawPath; + } + + // Try to check the raw path using first the primary encoding (according to http.sys settings); + // if it fails try the secondary encoding. + _rawOctets = new List(); + _requestUriString = new StringBuilder(); + ParsingResult result = ParseRawPath(GetEncoding(EncodingType.Primary)); + if (result == ParsingResult.EncodingError) + { + _rawOctets = new List(); + _requestUriString = new StringBuilder(); + result = ParseRawPath(GetEncoding(EncodingType.Secondary)); + } + + if (result == ParsingResult.Success) + { + return _requestUriString.ToString(); + } + + // Fallback + return _cookedUriPath; + } + + private void BuildRequestUriUsingCookedPath() + { + bool isValid = Uri.TryCreate(_cookedUriScheme + Constants.SchemeDelimiter + _cookedUriHost + _cookedUriPath + + _cookedUriQuery, UriKind.Absolute, out _requestUri); + + // Creating a Uri from the cooked Uri should really always work: If not, we log at least. + if (!isValid) + { + LogWarning("BuildRequestUriUsingCookedPath", "Unable to create URI: " + _cookedUriScheme + Constants.SchemeDelimiter + + _cookedUriHost + _cookedUriPath + _cookedUriQuery); + } + } + + private void BuildRequestUriUsingRawPath() + { + bool isValid = false; + + // Initialize 'rawPath' only if really needed; i.e. if we build the request Uri from the raw Uri. + _rawPath = GetPath(_rawUri); + + // If HTTP.sys only parses Utf-8, we can safely use the raw path: it must be a valid Utf-8 string. + if (!HttpSysSettings.EnableNonUtf8 || string.IsNullOrEmpty(_rawPath)) + { + string path = _rawPath; + if (string.IsNullOrEmpty(path)) + { + path = "/"; + Debug.Assert(string.IsNullOrEmpty(_cookedUriQuery), + "Query is only allowed if there is a non-empty path. At least '/' path required."); + } + + isValid = Uri.TryCreate(_cookedUriScheme + Constants.SchemeDelimiter + _cookedUriHost + path + _cookedUriQuery, + UriKind.Absolute, out _requestUri); + } + else + { + // Try to check the raw path using first the primary encoding (according to http.sys settings); + // if it fails try the secondary encoding. + ParsingResult result = BuildRequestUriUsingRawPath(GetEncoding(EncodingType.Primary)); + if (result == ParsingResult.EncodingError) + { + Encoding secondaryEncoding = GetEncoding(EncodingType.Secondary); + result = BuildRequestUriUsingRawPath(secondaryEncoding); + } + isValid = (result == ParsingResult.Success) ? true : false; + } + + // Log that we weren't able to create a Uri from the raw string. + if (!isValid) + { + LogWarning("BuildRequestUriUsingRawPath", "Unable to create Uri: " + _cookedUriScheme + Constants.SchemeDelimiter + + _cookedUriHost + _rawPath + _cookedUriQuery); + } + } + + private static Encoding GetEncoding(EncodingType type) + { + Debug.Assert(HttpSysSettings.EnableNonUtf8, + "If 'EnableNonUtf8' is false we shouldn't require an encoding. It's always Utf-8."); + /* This is mucking up the profiler for some reason. + Debug.Assert((type == EncodingType.Primary) || (type == EncodingType.Secondary), + "Unknown 'EncodingType' value: " + type.ToString()); + */ + if (((type == EncodingType.Primary) && (!HttpSysSettings.FavorUtf8)) || + ((type == EncodingType.Secondary) && (HttpSysSettings.FavorUtf8))) + { + return AnsiEncoding; + } + else + { + return Utf8Encoding; + } + } + + private ParsingResult BuildRequestUriUsingRawPath(Encoding encoding) + { + Debug.Assert(encoding != null, "'encoding' must be assigned."); + Debug.Assert(!string.IsNullOrEmpty(_rawPath), "'rawPath' must have at least one character."); + + _rawOctets = new List(); + _requestUriString = new StringBuilder(); + _requestUriString.Append(_cookedUriScheme); + _requestUriString.Append(Constants.SchemeDelimiter); + _requestUriString.Append(_cookedUriHost); + + ParsingResult result = ParseRawPath(encoding); + if (result == ParsingResult.Success) + { + _requestUriString.Append(_cookedUriQuery); + + Debug.Assert(_rawOctets.Count == 0, + "Still raw octets left. They must be added to the result path."); + + if (!Uri.TryCreate(_requestUriString.ToString(), UriKind.Absolute, out _requestUri)) + { + // If we can't create a Uri from the string, this is an invalid string and it doesn't make + // sense to try another encoding. + result = ParsingResult.InvalidString; + } + } + + if (result != ParsingResult.Success) + { + LogWarning("BuildRequestUriUsingRawPath", "Can't convert the raw path: " + _rawPath + " Encoding: " + encoding.WebName); + } + + return result; + } + + private ParsingResult ParseRawPath(Encoding encoding) + { + Debug.Assert(encoding != null, "'encoding' must be assigned."); + + int index = 0; + char current = '\0'; + while (index < _rawPath.Length) + { + current = _rawPath[index]; + if (current == '%') + { + // Assert is enough, since http.sys accepted the request string already. This should never happen. + Debug.Assert(index + 2 < _rawPath.Length, "Expected >=2 characters after '%' (e.g. %2F)"); + + index++; + current = _rawPath[index]; + if (current == 'u' || current == 'U') + { + // We found "%u" which means, we have a Unicode code point of the form "%uXXXX". + Debug.Assert(index + 4 < _rawPath.Length, "Expected >=4 characters after '%u' (e.g. %u0062)"); + + // Decode the content of rawOctets into percent encoded UTF-8 characters and append them + // to requestUriString. + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + if (!AppendUnicodeCodePointValuePercentEncoded(_rawPath.Substring(index + 1, 4))) + { + return ParsingResult.InvalidString; + } + index += 5; + } + else + { + // We found '%', but not followed by 'u', i.e. we have a percent encoded octed: %XX + if (!AddPercentEncodedOctetToRawOctetsList(encoding, _rawPath.Substring(index, 2))) + { + return ParsingResult.InvalidString; + } + index += 2; + } + } + else + { + // We found a non-'%' character: decode the content of rawOctets into percent encoded + // UTF-8 characters and append it to the result. + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + // Append the current character to the result. + _requestUriString.Append(current); + index++; + } + } + + // if the raw path ends with a sequence of percent encoded octets, make sure those get added to the + // result (requestUriString). + if (!EmptyDecodeAndAppendRawOctetsList(encoding)) + { + return ParsingResult.EncodingError; + } + + return ParsingResult.Success; + } + + private bool AppendUnicodeCodePointValuePercentEncoded(string codePoint) + { + // http.sys only supports %uXXXX (4 hex-digits), even though unicode code points could have up to + // 6 hex digits. Therefore we parse always 4 characters after %u and convert them to an int. + int codePointValue; + if (!int.TryParse(codePoint, NumberStyles.HexNumber, null, out codePointValue)) + { + LogWarning("AppendUnicodeCodePointValuePercentEncoded", "Can't convert code point: " + codePoint); + return false; + } + + string unicodeString = null; + try + { + unicodeString = char.ConvertFromUtf32(codePointValue); + AppendOctetsPercentEncoded(_requestUriString, Utf8Encoding.GetBytes(unicodeString)); + + return true; + } + catch (ArgumentOutOfRangeException) + { + LogWarning("AppendUnicodeCodePointValuePercentEncoded", "Can't convert code point: " + codePoint); + } + catch (EncoderFallbackException e) + { + // If utf8Encoding.GetBytes() fails + LogWarning("AppendUnicodeCodePointValuePercentEncoded", "Can't convert code point: " + unicodeString, e.Message); + } + + return false; + } + + private bool AddPercentEncodedOctetToRawOctetsList(Encoding encoding, string escapedCharacter) + { + byte encodedValue; + if (!byte.TryParse(escapedCharacter, NumberStyles.HexNumber, null, out encodedValue)) + { + LogWarning("AddPercentEncodedOctetToRawOctetsList", "Can't convert code point: " + escapedCharacter); + return false; + } + + _rawOctets.Add(encodedValue); + + return true; + } + + private bool EmptyDecodeAndAppendRawOctetsList(Encoding encoding) + { + if (_rawOctets.Count == 0) + { + return true; + } + + string decodedString = null; + try + { + // If the encoding can get a string out of the byte array, this is a valid string in the + // 'encoding' encoding. + byte[] bytes = _rawOctets.ToArray(); + decodedString = encoding.GetString(bytes, 0, bytes.Length); + + if (encoding == Utf8Encoding) + { + AppendOctetsPercentEncoded(_requestUriString, bytes); + } + else + { + AppendOctetsPercentEncoded(_requestUriString, Utf8Encoding.GetBytes(decodedString)); + } + + _rawOctets.Clear(); + + return true; + } + catch (DecoderFallbackException e) + { + LogWarning("EmptyDecodeAndAppendRawOctetsList", "Can't convert bytes: " + GetOctetsAsString(_rawOctets), e.Message); + } + catch (EncoderFallbackException e) + { + // If utf8Encoding.GetBytes() fails + LogWarning("EmptyDecodeAndAppendRawOctetsList", "Can't convert bytes: " + decodedString, e.Message); + } + + return false; + } + + private static void AppendOctetsPercentEncoded(StringBuilder target, IEnumerable octets) + { + foreach (byte octet in octets) + { + target.Append('%'); + target.Append(octet.ToString("X2", CultureInfo.InvariantCulture)); + } + } + + private static string GetOctetsAsString(IEnumerable octets) + { + StringBuilder octetString = new StringBuilder(); + + bool first = true; + foreach (byte octet in octets) + { + if (first) + { + first = false; + } + else + { + octetString.Append(" "); + } + octetString.Append(octet.ToString("X2", CultureInfo.InvariantCulture)); + } + + return octetString.ToString(); + } + + private static string GetPath(string uriString) + { + Debug.Assert(uriString != null, "uriString must not be null"); + Debug.Assert(uriString.Length > 0, "uriString must not be empty"); + + int pathStartIndex = 0; + + // Perf. improvement: nearly all strings are relative Uris. So just look if the + // string starts with '/'. If so, we have a relative Uri and the path starts at position 0. + // (http.sys already trimmed leading whitespaces) + if (uriString[0] != '/') + { + // We can't check against cookedUriScheme, since http.sys allows for request http://myserver/ to + // use a request line 'GET https://myserver/' (note http vs. https). Therefore check if the + // Uri starts with either http:// or https://. + int authorityStartIndex = 0; + if (uriString.StartsWith("http://", StringComparison.OrdinalIgnoreCase)) + { + authorityStartIndex = 7; + } + else if (uriString.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + authorityStartIndex = 8; + } + + if (authorityStartIndex > 0) + { + // we have an absolute Uri. Find out where the authority ends and the path begins. + // Note that Uris like "http://server?query=value/1/2" are invalid according to RFC2616 + // and http.sys behavior: If the Uri contains a query, there must be at least one '/' + // between the authority and the '?' character: It's safe to just look for the first + // '/' after the authority to determine the beginning of the path. + pathStartIndex = uriString.IndexOf('/', authorityStartIndex); + if (pathStartIndex == -1) + { + // e.g. for request lines like: 'GET http://myserver' (no final '/') + pathStartIndex = uriString.Length; + } + } + else + { + // RFC2616: Request-URI = "*" | absoluteURI | abs_path | authority + // 'authority' can only be used with CONNECT which is never received by HttpListener. + // I.e. if we don't have an absolute path (must start with '/') and we don't have + // an absolute Uri (must start with http:// or https://), then 'uriString' must be '*'. + Debug.Assert((uriString.Length == 1) && (uriString[0] == '*'), "Unknown request Uri string format; " + + "Request Uri string is not an absolute Uri, absolute path, or '*': " + uriString); + + // Should we ever get here, be consistent with 2.0/3.5 behavior: just add an initial + // slash to the string and treat it as a path: + uriString = "/" + uriString; + } + } + + // Find end of path: The path is terminated by + // - the first '?' character + // - the first '#' character: This is never the case here, since http.sys won't accept + // Uris containing fragments. Also, RFC2616 doesn't allow fragments in request Uris. + // - end of Uri string + int queryIndex = uriString.IndexOf('?'); + if (queryIndex == -1) + { + queryIndex = uriString.Length; + } + + // will always return a != null string. + return AddSlashToAsteriskOnlyPath(uriString.Substring(pathStartIndex, queryIndex - pathStartIndex)); + } + + private static string AddSlashToAsteriskOnlyPath(string path) + { + Debug.Assert(path != null, "'path' must not be null"); + + // If a request like "OPTIONS * HTTP/1.1" is sent to the listener, then the request Uri + // should be "http[s]://server[:port]/*" to be compatible with pre-4.0 behavior. + if ((path.Length == 1) && (path[0] == '*')) + { + return "/*"; + } + + return path; + } + + private void LogWarning(string methodName, string message, params object[] args) + { + // TODO: Verbose log + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Response.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Response.cs new file mode 100644 index 0000000000..d25af63ed7 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/Response.cs @@ -0,0 +1,756 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal sealed unsafe class Response : IDisposable + { + private ResponseState _responseState; + private IDictionary _headers; + private ResponseStream _responseStream; + private long _contentLength; + private BoundaryType _boundaryType; + private UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE _nativeResponse; + private IList, object>> _onSendingHeadersActions; + + private RequestContext _requestContext; + + internal Response(RequestContext httpContext) + { + // TODO: Verbose log + _requestContext = httpContext; + _nativeResponse = new UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE(); + _headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + _boundaryType = BoundaryType.None; + _nativeResponse.StatusCode = (ushort)HttpStatusCode.OK; + _nativeResponse.Version.MajorVersion = 1; + _nativeResponse.Version.MinorVersion = 1; + _responseState = ResponseState.Created; + _onSendingHeadersActions = new List, object>>(); + } + + private enum ResponseState + { + Created, + ComputedHeaders, + SentHeaders, + Closed, + } + + private RequestContext RequestContext + { + get + { + return _requestContext; + } + } + + private Request Request + { + get + { + return RequestContext.Request; + } + } + + internal Stream OutputStream + { + get + { + CheckDisposed(); + EnsureResponseStream(); + return _responseStream; + } + } + + internal int GetStatusCode() + { + int statusCode = _requestContext.Environment.ResponseStatusCode ?? 200; + if (statusCode <= 100 || statusCode > 999) + { + // TODO: Move this validation to the dictionary facade so it throws when the app sets it, rather than durring send? + throw new InvalidOperationException(string.Format(Resources.Exception_InvalidStatusCode, statusCode)); + } + return statusCode; + } + + internal string GetReasonPhrase(int statusCode) + { + // TODO: Validate user input for illegal chars, length limit, etc.? + string reasonPhrase = _requestContext.Environment.ResponseReasonPhrase; + if (string.IsNullOrWhiteSpace(reasonPhrase)) + { + // if the user hasn't set this, generated on the fly, if possible. + // We know this one is safe, no need to verify it as in the setter. + reasonPhrase = HttpReasonPhrase.Get(statusCode) ?? string.Empty; + } + return reasonPhrase; + } + + // We MUST NOT send message-body when we send responses with these Status codes + private static readonly int[] NoResponseBody = { 100, 101, 204, 205, 304 }; + + private static bool CanSendResponseBody(int responseCode) + { + for (int i = 0; i < NoResponseBody.Length; i++) + { + if (responseCode == NoResponseBody[i]) + { + return false; + } + } + return true; + } + + internal EntitySendFormat EntitySendFormat + { + get + { + return (EntitySendFormat)_boundaryType; + } + } + + internal IDictionary Headers + { + get + { + return _headers; + } + } + + internal long ContentLength64 + { + get + { + return _contentLength; + } + } + + private Version GetProtocolVersion() + { + Version requestVersion = Request.ProtocolVersion; + Version responseVersion = requestVersion; + string protocolVersion = RequestContext.Environment.Get(Constants.HttpResponseProtocolKey); + + // Optional + if (!string.IsNullOrWhiteSpace(protocolVersion)) + { + if (string.Equals("HTTP/1.1", protocolVersion, StringComparison.OrdinalIgnoreCase)) + { + responseVersion = Constants.V1_1; + } + if (string.Equals("HTTP/1.0", protocolVersion, StringComparison.OrdinalIgnoreCase)) + { + responseVersion = Constants.V1_0; + } + else + { + // TODO: Just log? It's too late to get this to user code. + throw new ArgumentException(string.Empty, Constants.HttpResponseProtocolKey); + } + } + + if (requestVersion == responseVersion) + { + return requestVersion; + } + + // Return the lesser of the two versions. There are only two, so it it will always be 1.0. + return Constants.V1_0; + } + + public void Dispose() + { + Dispose(true); + } + + private void Dispose(bool disposing) + { + if (disposing) + { + if (_responseState >= ResponseState.Closed) + { + return; + } + // TODO: Verbose log + EnsureResponseStream(); + _responseStream.Dispose(); + _responseState = ResponseState.Closed; + } + } + + // old API, now private, and helper methods + + internal BoundaryType BoundaryType + { + get + { + return _boundaryType; + } + } + + internal bool SentHeaders + { + get + { + return _responseState >= ResponseState.SentHeaders; + } + } + + internal bool ComputedHeaders + { + get + { + return _responseState >= ResponseState.ComputedHeaders; + } + } + + private void EnsureResponseStream() + { + if (_responseStream == null) + { + _responseStream = new ResponseStream(RequestContext); + } + } + + /* + 12.3 + HttpSendHttpResponse() and HttpSendResponseEntityBody() Flag Values. + The following flags can be used on calls to HttpSendHttpResponse() and HttpSendResponseEntityBody() API calls: + + #define HTTP_SEND_RESPONSE_FLAG_DISCONNECT 0x00000001 + #define HTTP_SEND_RESPONSE_FLAG_MORE_DATA 0x00000002 + #define HTTP_SEND_RESPONSE_FLAG_RAW_HEADER 0x00000004 + #define HTTP_SEND_RESPONSE_FLAG_VALID 0x00000007 + + HTTP_SEND_RESPONSE_FLAG_DISCONNECT: + specifies that the network connection should be disconnected immediately after + sending the response, overriding the HTTP protocol's persistent connection features. + HTTP_SEND_RESPONSE_FLAG_MORE_DATA: + specifies that additional entity body data will be sent by the caller. Thus, + the last call HttpSendResponseEntityBody for a RequestId, will have this flag reset. + HTTP_SEND_RESPONSE_RAW_HEADER: + specifies that a caller of HttpSendResponseEntityBody() is intentionally omitting + a call to HttpSendHttpResponse() in order to bypass normal header processing. The + actual HTTP header will be generated by the application and sent as entity body. + This flag should be passed on the first call to HttpSendResponseEntityBody, and + not after. Thus, flag is not applicable to HttpSendHttpResponse. + */ + + // TODO: Consider using HTTP_SEND_RESPONSE_RAW_HEADER with HttpSendResponseEntityBody instead of calling HttpSendHttpResponse. + // This will give us more control of the bytes that hit the wire, including encodings, HTTP 1.0, etc.. + // It may also be faster to do this work in managed code and then pass down only one buffer. + // What would we loose by bypassing HttpSendHttpResponse? + // + // TODO: Consider using the HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA flag for most/all responses rather than just Opaque. + internal unsafe uint SendHeaders(UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK* pDataChunk, + ResponseStreamAsyncResult asyncResult, + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags, + bool isOpaqueUpgrade) + { + Debug.Assert(!SentHeaders, "HttpListenerResponse::SendHeaders()|SentHeaders is true."); + + // TODO: Verbose log headers + _responseState = ResponseState.SentHeaders; + + _nativeResponse.StatusCode = (ushort)GetStatusCode(); + string reasonPhrase = GetReasonPhrase(_nativeResponse.StatusCode); + + /* + if (m_BoundaryType==BoundaryType.Raw) { + use HTTP_SEND_RESPONSE_FLAG_RAW_HEADER; + } + */ + uint statusCode; + uint bytesSent; + List pinnedHeaders = SerializeHeaders(ref _nativeResponse.Headers, isOpaqueUpgrade); + try + { + if (pDataChunk != null) + { + _nativeResponse.EntityChunkCount = 1; + _nativeResponse.pEntityChunks = pDataChunk; + } + else if (asyncResult != null && asyncResult.DataChunks != null) + { + _nativeResponse.EntityChunkCount = asyncResult.DataChunkCount; + _nativeResponse.pEntityChunks = asyncResult.DataChunks; + } + else + { + _nativeResponse.EntityChunkCount = 0; + _nativeResponse.pEntityChunks = null; + } + + if (reasonPhrase.Length > 0) + { + byte[] reasonPhraseBytes = new byte[HeaderEncoding.GetByteCount(reasonPhrase)]; + fixed (byte* pReasonPhrase = reasonPhraseBytes) + { + _nativeResponse.ReasonLength = (ushort)reasonPhraseBytes.Length; + HeaderEncoding.GetBytes(reasonPhrase, 0, reasonPhraseBytes.Length, reasonPhraseBytes, 0); + _nativeResponse.pReason = (sbyte*)pReasonPhrase; + fixed (UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE* pResponse = &_nativeResponse) + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendHttpResponse( + RequestContext.RequestQueueHandle, + Request.RequestId, + (uint)flags, + pResponse, + null, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult == null ? SafeNativeOverlapped.Zero : asyncResult.NativeOverlapped, + IntPtr.Zero); + + if (asyncResult != null && + statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + asyncResult.BytesSent = bytesSent; + // The caller will invoke IOCompleted + } + } + } + } + else + { + fixed (UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE* pResponse = &_nativeResponse) + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendHttpResponse( + RequestContext.RequestQueueHandle, + Request.RequestId, + (uint)flags, + pResponse, + null, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult == null ? SafeNativeOverlapped.Zero : asyncResult.NativeOverlapped, + IntPtr.Zero); + + if (asyncResult != null && + statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + OwinWebListener.SkipIOCPCallbackOnSuccess) + { + asyncResult.BytesSent = bytesSent; + // The caller will invoke IOCompleted + } + } + } + } + finally + { + FreePinnedHeaders(pinnedHeaders); + } + return statusCode; + } + + internal UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS ComputeHeaders(bool endOfRequest = false) + { + // Notify that this is absolutely the last chance to make changes. + NotifyOnSendingHeaders(); + + // 401 + if (GetStatusCode() == (ushort)HttpStatusCode.Unauthorized) + { + RequestContext.Server.AuthenticationManager.SetAuthenticationChallenge(this); + } + + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + Debug.Assert(!ComputedHeaders, "HttpListenerResponse::ComputeHeaders()|ComputedHeaders is true."); + _responseState = ResponseState.ComputedHeaders; + /* + // here we would check for BoundaryType.Raw, in this case we wouldn't need to do anything + if (m_BoundaryType==BoundaryType.Raw) { + return flags; + } + */ + + // Check the response headers to determine the correct keep alive and boundary type. + Version responseVersion = GetProtocolVersion(); + _nativeResponse.Version.MajorVersion = (ushort)responseVersion.Major; + _nativeResponse.Version.MinorVersion = (ushort)responseVersion.Minor; + bool keepAlive = responseVersion >= Constants.V1_1; + string connectionString = Headers.Get(HttpKnownHeaderNames.Connection); + string keepAliveString = Headers.Get(HttpKnownHeaderNames.KeepAlive); + bool closeSet = false; + bool keepAliveSet = false; + + if (!string.IsNullOrWhiteSpace(connectionString) && string.Equals("close", connectionString.Trim(), StringComparison.OrdinalIgnoreCase)) + { + keepAlive = false; + closeSet = true; + } + else if (!string.IsNullOrWhiteSpace(keepAliveString) && string.Equals("true", keepAliveString.Trim(), StringComparison.OrdinalIgnoreCase)) + { + keepAlive = true; + keepAliveSet = true; + } + + // Content-Length takes priority + string contentLengthString = Headers.Get(HttpKnownHeaderNames.ContentLength); + string transferEncodingString = Headers.Get(HttpKnownHeaderNames.TransferEncoding); + + if (responseVersion == Constants.V1_0 && !string.IsNullOrEmpty(transferEncodingString) + && string.Equals("chunked", transferEncodingString.Trim(), StringComparison.OrdinalIgnoreCase)) + { + // A 1.0 client can't process chunked responses. + Headers.Remove(HttpKnownHeaderNames.TransferEncoding); + transferEncodingString = null; + } + + if (!string.IsNullOrWhiteSpace(contentLengthString)) + { + contentLengthString = contentLengthString.Trim(); + if (string.Equals("0", contentLengthString, StringComparison.Ordinal)) + { + _boundaryType = BoundaryType.ContentLength; + _contentLength = 0; + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + } + else if (long.TryParse(contentLengthString, NumberStyles.None, CultureInfo.InvariantCulture.NumberFormat, out _contentLength)) + { + _boundaryType = BoundaryType.ContentLength; + if (_contentLength == 0) + { + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + } + } + else + { + _boundaryType = BoundaryType.Invalid; + } + } + else if (!string.IsNullOrWhiteSpace(transferEncodingString) + && string.Equals("chunked", transferEncodingString.Trim(), StringComparison.OrdinalIgnoreCase)) + { + // Then Transfer-Encoding: chunked + _boundaryType = BoundaryType.Chunked; + } + else if (endOfRequest) + { + // The request is ending without a body, add a Content-Length: 0 header. + Headers[HttpKnownHeaderNames.ContentLength] = new string[] { "0" }; + _boundaryType = BoundaryType.ContentLength; + _contentLength = 0; + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + } + else + { + // Then fall back to Connection:Close transparent mode. + _boundaryType = BoundaryType.None; + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; // seems like HTTP_SEND_RESPONSE_FLAG_MORE_DATA but this hangs the app; + if (responseVersion == Constants.V1_0) + { + keepAlive = false; + } + else + { + Headers[HttpKnownHeaderNames.TransferEncoding] = new string[] { "chunked" }; + _boundaryType = BoundaryType.Chunked; + } + + if (CanSendResponseBody(_requestContext.Response.GetStatusCode())) + { + _contentLength = -1; + } + else + { + Headers[HttpKnownHeaderNames.ContentLength] = new string[] { "0" }; + _contentLength = 0; + _boundaryType = BoundaryType.ContentLength; + } + } + + // Also, Keep-Alive vs Connection Close + + if (!keepAlive) + { + if (!closeSet) + { + Headers.Append(HttpKnownHeaderNames.Connection, "close"); + } + if (flags == UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE) + { + flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_DISCONNECT; + } + } + else + { + if (Request.ProtocolVersion.Minor == 0 && !keepAliveSet) + { + Headers[HttpKnownHeaderNames.KeepAlive] = new string[] { "true" }; + } + } + return flags; + } + + private List SerializeHeaders(ref UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE_HEADERS headers, + bool isOpaqueUpgrade) + { + UnsafeNclNativeMethods.HttpApi.HTTP_UNKNOWN_HEADER[] unknownHeaders = null; + List pinnedHeaders; + GCHandle gcHandle; + /* + // here we would check for BoundaryType.Raw, in this case we wouldn't need to do anything + if (m_BoundaryType==BoundaryType.Raw) { + return null; + } + */ + if (Headers.Count == 0) + { + return null; + } + string headerName; + string headerValue; + int lookup; + byte[] bytes = null; + pinnedHeaders = new List(); + + //--------------------------------------------------- + // DTS Issue: 609383: + // The Set-Cookie headers are being merged into one. + // There are two issues here. + // 1. When Set-Cookie headers are set through SetCookie method on the ListenerResponse, + // there is code in the SetCookie method and the methods it calls to flatten the Set-Cookie + // values. This blindly concatenates the cookies with a comma delimiter. There could be + // a cookie value that contains comma, but we don't escape it with %XX value + // + // As an alternative users can add the Set-Cookie header through the AddHeader method + // like ListenerResponse.Headers.Add("name", "value") + // That way they can add multiple headers - AND They can format the value like they want it. + // + // 2. Now that the header collection contains multiple Set-Cookie name, value pairs + // you would think the problem would go away. However here is an interesting thing. + // For NameValueCollection, when you add + // "Set-Cookie", "value1" + // "Set-Cookie", "value2" + // The NameValueCollection.Count == 1. Because there is only one key + // NameValueCollection.Get("Set-Cookie") would conveniently take these two values + // concatenate them with a comma like + // value1,value2. + // In order to get individual values, you need to use + // string[] values = NameValueCollection.GetValues("Set-Cookie"); + // + // ------------------------------------------------------------- + // So here is the proposed fix here. + // We must first to loop through all the NameValueCollection keys + // and if the name is a unknown header, we must compute the number of + // values it has. Then, we should allocate that many unknown header array + // elements. + // + // Note that a part of the fix here is to treat Set-Cookie as an unknown header + // + // + //----------------------------------------------------------- + int numUnknownHeaders = 0; + foreach (KeyValuePair headerPair in Headers) + { + // See if this is an unknown header + lookup = UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE_HEADER_ID.IndexOfKnownHeader(headerPair.Key); + + // TODO: WWW-Authentiate header + // TODO: HTTP_RESPONSE_V2 has a HTTP_MULTIPLE_KNOWN_HEADERS option where you can supply multiple values. + + // TODO: Consider any 'known' header that has multiple values as 'unknown'? + // Treat Set-Cookie as well as Connection header in opaque mode as unknown + if (lookup == (int)HttpSysResponseHeader.SetCookie || + (isOpaqueUpgrade && lookup == (int)HttpSysResponseHeader.Connection)) + { + lookup = -1; + } + + if (lookup == -1) + { + numUnknownHeaders += headerPair.Value.Length; + } + } + + try + { + fixed (UnsafeNclNativeMethods.HttpApi.HTTP_KNOWN_HEADER* pKnownHeaders = &headers.KnownHeaders) + { + foreach (KeyValuePair headerPair in Headers) + { + headerName = headerPair.Key; + lookup = UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE_HEADER_ID.IndexOfKnownHeader(headerName); + if (lookup == (int)HttpSysResponseHeader.SetCookie || + (isOpaqueUpgrade && lookup == (int)HttpSysResponseHeader.Connection)) + { + lookup = -1; + } + + if (lookup == -1) + { + if (unknownHeaders == null) + { + //---------------------------------------- + // *** This following comment is no longer true *** + // we waste some memory here (up to 32*41=1312 bytes) but we gain speed + // unknownHeaders = new UnsafeNclNativeMethods.HttpApi.HTTP_UNKNOWN_HEADER[Headers.Count-index]; + //-------------------------------------------- + unknownHeaders = new UnsafeNclNativeMethods.HttpApi.HTTP_UNKNOWN_HEADER[numUnknownHeaders]; + gcHandle = GCHandle.Alloc(unknownHeaders, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + headers.pUnknownHeaders = (UnsafeNclNativeMethods.HttpApi.HTTP_UNKNOWN_HEADER*)gcHandle.AddrOfPinnedObject(); + } + + //---------------------------------------- + // FOR UNKNOWN HEADERS + // ALLOW MULTIPLE HEADERS to be added + //--------------------------------------- + string[] headerValues = headerPair.Value; + for (int headerValueIndex = 0; headerValueIndex < headerValues.Length; headerValueIndex++) + { + // Add Name + bytes = new byte[HeaderEncoding.GetByteCount(headerName)]; + unknownHeaders[headers.UnknownHeaderCount].NameLength = (ushort)bytes.Length; + HeaderEncoding.GetBytes(headerName, 0, bytes.Length, bytes, 0); + gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + unknownHeaders[headers.UnknownHeaderCount].pName = (sbyte*)gcHandle.AddrOfPinnedObject(); + + // Add Value + headerValue = headerValues[headerValueIndex]; + bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + unknownHeaders[headers.UnknownHeaderCount].RawValueLength = (ushort)bytes.Length; + HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); + gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + unknownHeaders[headers.UnknownHeaderCount].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); + headers.UnknownHeaderCount++; + } + } + else + { + string[] headerValues = headerPair.Value; + headerValue = headerValues.Length == 1 ? headerValues[0] : string.Join(", ", headerValues); + if (headerValue != null) + { + bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + pKnownHeaders[lookup].RawValueLength = (ushort)bytes.Length; + HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); + gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + pKnownHeaders[lookup].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); + } + } + } + } + } + catch + { + FreePinnedHeaders(pinnedHeaders); + throw; + } + return pinnedHeaders; + } + + private static void FreePinnedHeaders(List pinnedHeaders) + { + if (pinnedHeaders != null) + { + foreach (GCHandle gcHandle in pinnedHeaders) + { + if (gcHandle.IsAllocated) + { + gcHandle.Free(); + } + } + } + } + + // Subset of ComputeHeaders + internal void SendOpaqueUpgrade() + { + // TODO: Should we do this notification earlier when you still have a chance to change the status code to avoid an upgrade? + // Notify that this is absolutely the last chance to make changes. + NotifyOnSendingHeaders(); + + // TODO: Send headers async? + ulong errorCode = SendHeaders(null, null, + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_OPAQUE | + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA | + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA, + true); + + if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS) + { + throw new WebListenerException((int)errorCode); + } + } + + private void CheckDisposed() + { + if (_responseState >= ResponseState.Closed) + { + throw new ObjectDisposedException(this.GetType().FullName); + } + } + + internal void CancelLastWrite(SafeHandle requestQueueHandle) + { + if (_responseStream != null) + { + _responseStream.CancelLastWrite(requestQueueHandle); + } + } + + internal Task SendFileAsync(string fileName, long offset, long? count, CancellationToken cancel) + { + EnsureResponseStream(); + return _responseStream.SendFileAsync(fileName, offset, count, cancel); + } + + internal void SwitchToOpaqueMode() + { + EnsureResponseStream(); + _responseStream.SwitchToOpaqueMode(); + } + + internal void RegisterForOnSendingHeaders(Action callback, object state) + { + IList, object>> actions = _onSendingHeadersActions; + if (actions == null) + { + throw new InvalidOperationException("Headers already sent"); + } + + actions.Add(new Tuple, object>(callback, state)); + } + + private void NotifyOnSendingHeaders() + { + var actions = Interlocked.Exchange(ref _onSendingHeadersActions, null); + if (actions == null) + { + // Something threw the first time, do not try again. + return; + } + + // Execute last to first. This mimics a stack unwind. + foreach (var actionPair in actions.Reverse()) + { + actionPair.Item1(actionPair.Item2); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStream.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStream.cs new file mode 100644 index 0000000000..0177312901 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStream.cs @@ -0,0 +1,826 @@ +// ------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal class ResponseStream : Stream + { + private static readonly byte[] ChunkTerminator = new byte[] { (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' }; + + private RequestContext _requestContext; + private long _leftToWrite = long.MinValue; + private bool _closed; + private bool _inOpaqueMode; + // The last write needs special handling to cancel. + private ResponseStreamAsyncResult _lastWrite; + + internal ResponseStream(RequestContext requestContext) + { + _requestContext = requestContext; + } + + public override bool CanSeek + { + get + { + return false; + } + } + + public override bool CanWrite + { + get + { + return true; + } + } + + public override bool CanRead + { + get + { + return false; + } + } + + public override long Length + { + get + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + set + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + } + + // Send headers + public override void Flush() + { + if (_closed || _requestContext.Response.SentHeaders) + { + return; + } + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + // TODO: Verbose log + + try + { + uint statusCode; + unsafe + { + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + statusCode = _requestContext.Response.SendHeaders(null, null, flags, false); + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + throw new WebListenerException((int)statusCode); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "Flush", e); + _closed = true; + _requestContext.Abort(); + throw; + } + } + + // Send headers + public override Task FlushAsync(CancellationToken cancellationToken) + { + if (_closed || _requestContext.Response.SentHeaders) + { + return Helpers.CompletedTask(); + } + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + // TODO: Verbose log + + // TODO: Real cancellation + cancellationToken.ThrowIfCancellationRequested(); + + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, null, 0, 0, _requestContext.Response.BoundaryType == BoundaryType.Chunked, false); + + try + { + uint statusCode; + unsafe + { + statusCode = _requestContext.Response.SendHeaders(null, asyncResult, flags, false); + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode); + } + else if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + throw new WebListenerException((int)statusCode); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "FlushAsync", e); + asyncResult.Dispose(); + _closed = true; + _requestContext.Abort(); + throw; + } + + return asyncResult.Task; + } + + #region NotSupported Read/Seek + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(Resources.Exception_NoSeek); + } + + public override int Read([In, Out] byte[] buffer, int offset, int size) + { + throw new InvalidOperationException(Resources.Exception_WriteOnlyStream); + } + +#if NET45 + public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) + { + throw new InvalidOperationException(Resources.Exception_WriteOnlyStream); + } + + public override int EndRead(IAsyncResult asyncResult) + { + throw new InvalidOperationException(Resources.Exception_WriteOnlyStream); + } +#endif + + #endregion + + private UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS ComputeLeftToWrite(bool endOfRequest = false) + { + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + if (!_requestContext.Response.ComputedHeaders) + { + flags = _requestContext.Response.ComputeHeaders(endOfRequest: endOfRequest); + } + if (_leftToWrite == long.MinValue) + { + UnsafeNclNativeMethods.HttpApi.HTTP_VERB method = _requestContext.GetKnownMethod(); + if (method == UnsafeNclNativeMethods.HttpApi.HTTP_VERB.HttpVerbHEAD) + { + _leftToWrite = 0; + } + else if (_requestContext.Response.EntitySendFormat == EntitySendFormat.ContentLength) + { + _leftToWrite = _requestContext.Response.ContentLength64; + } + else + { + _leftToWrite = -1; // unlimited + } + } + return flags; + } + + public override unsafe void Write(byte[] buffer, int offset, int size) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + // TODO: Verbose log parameters + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + if (size == 0 && _leftToWrite != 0) + { + return; + } + if (_leftToWrite >= 0 && size > _leftToWrite) + { + throw new InvalidOperationException(Resources.Exception_TooMuchWritten); + } + // TODO: Verbose log + + uint statusCode; + uint dataToWrite = (uint)size; + SafeLocalFree bufferAsIntPtr = null; + IntPtr pBufferAsIntPtr = IntPtr.Zero; + bool sentHeaders = _requestContext.Response.SentHeaders; + try + { + if (size == 0) + { + // TODO: Is this code path accessible? Is this like a Flush? + statusCode = _requestContext.Response.SendHeaders(null, null, flags, false); + } + else + { + fixed (byte* pDataBuffer = buffer) + { + byte* pBuffer = pDataBuffer; + if (_requestContext.Response.BoundaryType == BoundaryType.Chunked) + { + // TODO: + // here we need some heuristics, some time it is definitely better to split this in 3 write calls + // but for small writes it is probably good enough to just copy the data internally. + string chunkHeader = size.ToString("x", CultureInfo.InvariantCulture); + dataToWrite = dataToWrite + (uint)(chunkHeader.Length + 4); + bufferAsIntPtr = SafeLocalFree.LocalAlloc((int)dataToWrite); + pBufferAsIntPtr = bufferAsIntPtr.DangerousGetHandle(); + for (int i = 0; i < chunkHeader.Length; i++) + { + Marshal.WriteByte(pBufferAsIntPtr, i, (byte)chunkHeader[i]); + } + Marshal.WriteInt16(pBufferAsIntPtr, chunkHeader.Length, 0x0A0D); + Marshal.Copy(buffer, offset, IntPtrHelper.Add(pBufferAsIntPtr, chunkHeader.Length + 2), size); + Marshal.WriteInt16(pBufferAsIntPtr, (int)(dataToWrite - 2), 0x0A0D); + pBuffer = (byte*)pBufferAsIntPtr; + offset = 0; + } + UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK dataChunk = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + dataChunk.DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + dataChunk.fromMemory.pBuffer = (IntPtr)(pBuffer + offset); + dataChunk.fromMemory.BufferLength = dataToWrite; + + flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(&dataChunk, null, flags, false); + } + else + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + 1, + &dataChunk, + null, + SafeLocalFree.Zero, + 0, + SafeNativeOverlapped.Zero, + IntPtr.Zero); + + if (_requestContext.Server.IgnoreWriteExceptions) + { + statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + } + } + } + } + } + finally + { + if (bufferAsIntPtr != null) + { + // free unmanaged buffer + bufferAsIntPtr.Dispose(); + } + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "Write", exception); + _closed = true; + _requestContext.Abort(); + throw exception; + } + UpdateWritenCount(dataToWrite); + + // TODO: Verbose log data written + } +#if NET45 + public override unsafe IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#else + public unsafe IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) +#endif + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + if (size == 0 && _leftToWrite != 0) + { + ResponseStreamAsyncResult result = new ResponseStreamAsyncResult(this, state, callback); + result.Complete(); + return result; + } + if (_leftToWrite >= 0 && size > _leftToWrite) + { + throw new InvalidOperationException(Resources.Exception_TooMuchWritten); + } + // TODO: Verbose log parameters + + uint statusCode; + uint bytesSent = 0; + flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + bool sentHeaders = _requestContext.Response.SentHeaders; + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, state, callback, buffer, offset, size, _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders); + + // Update m_LeftToWrite now so we can queue up additional BeginWrite's without waiting for EndWrite. + UpdateWritenCount((uint)((_requestContext.Response.BoundaryType == BoundaryType.Chunked) ? 0 : size)); + + try + { + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(null, asyncResult, flags, false); + bytesSent = asyncResult.BytesSent; + } + else + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + asyncResult.DataChunkCount, + asyncResult.DataChunks, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult.NativeOverlapped, + IntPtr.Zero); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "BeginWrite", e); + asyncResult.Dispose(); + _closed = true; + _requestContext.Abort(); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (_requestContext.Server.IgnoreWriteExceptions && sentHeaders) + { + asyncResult.Complete(); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "BeginWrite", exception); + _closed = true; + _requestContext.Abort(); + throw exception; + } + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode, bytesSent); + } + + // Last write, cache it for special cancelation handling. + if ((flags & UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA) == 0) + { + _lastWrite = asyncResult; + } + + return asyncResult; + } +#if NET45 + public override void EndWrite(IAsyncResult asyncResult) +#else + public void EndWrite(IAsyncResult asyncResult) +#endif + { + if (asyncResult == null) + { + throw new ArgumentNullException("asyncResult"); + } + ResponseStreamAsyncResult castedAsyncResult = asyncResult as ResponseStreamAsyncResult; + if (castedAsyncResult == null || castedAsyncResult.ResponseStream != this) + { + throw new ArgumentException(Resources.Exception_WrongIAsyncResult, "asyncResult"); + } + if (castedAsyncResult.EndCalled) + { + throw new InvalidOperationException(Resources.Exception_EndCalledMultipleTimes); + } + castedAsyncResult.EndCalled = true; + + try + { + // wait & then check for errors + // TODO: Gracefull re-throw + castedAsyncResult.Task.Wait(); + } + catch (Exception exception) + { + LogHelper.LogException(_requestContext.Logger, "EndWrite", exception); + _closed = true; + _requestContext.Abort(); + throw; + } + } + + public override unsafe Task WriteAsync(byte[] buffer, int offset, int size, CancellationToken cancel) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) + { + throw new ArgumentOutOfRangeException("size"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + if (size == 0 && _leftToWrite != 0) + { + return Helpers.CompletedTask(); + } + if (_leftToWrite >= 0 && size > _leftToWrite) + { + throw new InvalidOperationException(Resources.Exception_TooMuchWritten); + } + // TODO: Verbose log + + // TODO: Real cancelation + cancel.ThrowIfCancellationRequested(); + + uint statusCode; + uint bytesSent = 0; + flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + bool sentHeaders = _requestContext.Response.SentHeaders; + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, buffer, offset, size, _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders); + + // Update m_LeftToWrite now so we can queue up additional BeginWrite's without waiting for EndWrite. + UpdateWritenCount((uint)((_requestContext.Response.BoundaryType == BoundaryType.Chunked) ? 0 : size)); + + try + { + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(null, asyncResult, flags, false); + bytesSent = asyncResult.BytesSent; + } + else + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + asyncResult.DataChunkCount, + asyncResult.DataChunks, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult.NativeOverlapped, + IntPtr.Zero); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "WriteAsync", e); + asyncResult.Dispose(); + _closed = true; + _requestContext.Abort(); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (_requestContext.Server.IgnoreWriteExceptions && sentHeaders) + { + asyncResult.Complete(); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "WriteAsync", exception); + _closed = true; + _requestContext.Abort(); + throw exception; + } + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode, bytesSent); + } + + // Last write, cache it for special cancelation handling. + if ((flags & UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA) == 0) + { + _lastWrite = asyncResult; + } + + return asyncResult.Task; + } + + internal unsafe Task SendFileAsync(string fileName, long offset, long? size, CancellationToken cancel) + { + // It's too expensive to validate the file attributes before opening the file. Open the file and then check the lengths. + // This all happens inside of ResponseStreamAsyncResult. + if (string.IsNullOrWhiteSpace(fileName)) + { + throw new ArgumentNullException("fileName"); + } + if (_closed) + { + throw new ObjectDisposedException(GetType().FullName); + } + + // TODO: Real cancellation + cancel.ThrowIfCancellationRequested(); + + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(); + if (size == 0 && _leftToWrite != 0) + { + return Helpers.CompletedTask(); + } + if (_leftToWrite >= 0 && size > _leftToWrite) + { + throw new InvalidOperationException(Resources.Exception_TooMuchWritten); + } + // TODO: Verbose log + + uint statusCode; + uint bytesSent = 0; + flags |= _leftToWrite == size ? UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE : UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + bool sentHeaders = _requestContext.Response.SentHeaders; + ResponseStreamAsyncResult asyncResult = new ResponseStreamAsyncResult(this, null, null, fileName, offset, size, + _requestContext.Response.BoundaryType == BoundaryType.Chunked, sentHeaders); + + long bytesWritten; + if (_requestContext.Response.BoundaryType == BoundaryType.Chunked) + { + bytesWritten = 0; + } + else if (size.HasValue) + { + bytesWritten = size.Value; + } + else + { + bytesWritten = asyncResult.FileLength - offset; + } + // Update m_LeftToWrite now so we can queue up additional calls to SendFileAsync. + UpdateWritenCount((uint)bytesWritten); + + try + { + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(null, asyncResult, flags, false); + bytesSent = asyncResult.BytesSent; + } + else + { + // TODO: If opaque then include the buffer data flag. + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + asyncResult.DataChunkCount, + asyncResult.DataChunks, + &bytesSent, + SafeLocalFree.Zero, + 0, + asyncResult.NativeOverlapped, + IntPtr.Zero); + } + } + catch (Exception e) + { + LogHelper.LogException(_requestContext.Logger, "SendFileAsync", e); + asyncResult.Dispose(); + _closed = true; + _requestContext.Abort(); + throw; + } + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + asyncResult.Dispose(); + if (_requestContext.Server.IgnoreWriteExceptions && sentHeaders) + { + asyncResult.Complete(); + } + else + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "SendFileAsync", exception); + _closed = true; + _requestContext.Abort(); + throw exception; + } + } + + if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && OwinWebListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + asyncResult.IOCompleted(statusCode, bytesSent); + } + + // Last write, cache it for special cancellation handling. + if ((flags & UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA) == 0) + { + _lastWrite = asyncResult; + } + + return asyncResult.Task; + } + + private void UpdateWritenCount(uint dataWritten) + { + if (!_inOpaqueMode) + { + if (_leftToWrite > 0) + { + // keep track of the data transferred + _leftToWrite -= dataWritten; + } + if (_leftToWrite == 0) + { + // in this case we already passed 0 as the flag, so we don't need to call HttpSendResponseEntityBody() when we Close() + _closed = true; + } + } + } + + protected override unsafe void Dispose(bool disposing) + { + try + { + if (disposing) + { + if (_closed) + { + return; + } + _closed = true; + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = ComputeLeftToWrite(endOfRequest: true); + if (_leftToWrite > 0 && !_inOpaqueMode) + { + _requestContext.Abort(); + // TODO: Reduce this to a logged warning, it is thrown too late to be visible in user code. + LogHelper.LogError(_requestContext.Logger, "ResponseStream::Dispose", "Fewer bytes were written than were specified in the Content-Length."); + return; + } + bool sentHeaders = _requestContext.Response.SentHeaders; + if (sentHeaders && _leftToWrite == 0) + { + return; + } + + uint statusCode = 0; + if ((_requestContext.Response.BoundaryType == BoundaryType.Chunked || _requestContext.Response.BoundaryType == BoundaryType.None) && (String.Compare(_requestContext.Request.HttpMethod, "HEAD", StringComparison.OrdinalIgnoreCase) != 0)) + { + if (_requestContext.Response.BoundaryType == BoundaryType.None) + { + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_DISCONNECT; + } + fixed (void* pBuffer = ChunkTerminator) + { + UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK* pDataChunk = null; + if (_requestContext.Response.BoundaryType == BoundaryType.Chunked) + { + UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK dataChunk = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + dataChunk.DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + dataChunk.fromMemory.pBuffer = (IntPtr)pBuffer; + dataChunk.fromMemory.BufferLength = (uint)ChunkTerminator.Length; + pDataChunk = &dataChunk; + } + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(pDataChunk, null, flags, false); + } + else + { + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody( + _requestContext.RequestQueueHandle, + _requestContext.RequestId, + (uint)flags, + pDataChunk != null ? (ushort)1 : (ushort)0, + pDataChunk, + null, + SafeLocalFree.Zero, + 0, + SafeNativeOverlapped.Zero, + IntPtr.Zero); + + if (_requestContext.Server.IgnoreWriteExceptions) + { + statusCode = UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS; + } + } + } + } + else + { + if (!sentHeaders) + { + statusCode = _requestContext.Response.SendHeaders(null, null, flags, false); + } + } + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF + // Don't throw for disconnects, we were already finished with the response. + && statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_CONNECTION_INVALID) + { + Exception exception = new WebListenerException((int)statusCode); + LogHelper.LogException(_requestContext.Logger, "Dispose", exception); + _requestContext.Abort(); + throw exception; + } + _leftToWrite = 0; + } + } + finally + { + base.Dispose(disposing); + } + } + + internal void SwitchToOpaqueMode() + { + _inOpaqueMode = true; + _leftToWrite = long.MaxValue; + } + + // The final Content-Length async write can only be cancelled by CancelIoEx. + // Sync can only be cancelled by CancelSynchronousIo, but we don't attempt this right now. + [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = + "It is safe to ignore the return value on a cancel operation because the connection is being closed")] + internal unsafe void CancelLastWrite(SafeHandle requestQueueHandle) + { + ResponseStreamAsyncResult asyncState = _lastWrite; + if (asyncState != null && !asyncState.IsCompleted) + { + UnsafeNclNativeMethods.CancelIoEx(requestQueueHandle, asyncState.NativeOverlapped); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStreamAsyncResult.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStreamAsyncResult.cs new file mode 100644 index 0000000000..ca38ba6355 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/ResponseStreamAsyncResult.cs @@ -0,0 +1,441 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal unsafe class ResponseStreamAsyncResult : IAsyncResult, IDisposable + { + private static readonly byte[] CRLF = new byte[] { (byte)'\r', (byte)'\n' }; + private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(Callback); + + private SafeNativeOverlapped _overlapped; + private UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[] _dataChunks; + private bool _sentHeaders; + private FileStream _fileStream; + private ResponseStream _responseStream; + private TaskCompletionSource _tcs; + private AsyncCallback _callback; + private uint _bytesSent; + + internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback) + { + _responseStream = responseStream; + _tcs = new TaskCompletionSource(userState); + _callback = callback; + } + + internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback, + byte[] buffer, int offset, int size, bool chunked, bool sentHeaders) + : this(responseStream, userState, callback) + { + _sentHeaders = sentHeaders; + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = this; + + if (size == 0) + { + _dataChunks = null; + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, null)); + } + else + { + _dataChunks = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[chunked ? 3 : 1]; + + object[] objectsToPin = new object[1 + _dataChunks.Length]; + objectsToPin[_dataChunks.Length] = _dataChunks; + + int chunkHeaderOffset = 0; + byte[] chunkHeaderBuffer = null; + if (chunked) + { + chunkHeaderBuffer = GetChunkHeader(size, out chunkHeaderOffset); + + _dataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[0].fromMemory.BufferLength = (uint)(chunkHeaderBuffer.Length - chunkHeaderOffset); + + objectsToPin[0] = chunkHeaderBuffer; + + _dataChunks[1] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[1].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[1].fromMemory.BufferLength = (uint)size; + + objectsToPin[1] = buffer; + + _dataChunks[2] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[2].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[2].fromMemory.BufferLength = (uint)CRLF.Length; + + objectsToPin[2] = CRLF; + } + else + { + _dataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[0].fromMemory.BufferLength = (uint)size; + + objectsToPin[0] = buffer; + } + + // This call will pin needed memory + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, objectsToPin)); + + if (chunked) + { + _dataChunks[0].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(chunkHeaderBuffer, chunkHeaderOffset); + _dataChunks[1].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset); + _dataChunks[2].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(CRLF, 0); + } + else + { + _dataChunks[0].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset); + } + } + } + + internal ResponseStreamAsyncResult(ResponseStream responseStream, object userState, AsyncCallback callback, + string fileName, long offset, long? size, bool chunked, bool sentHeaders) + : this(responseStream, userState, callback) + { + _sentHeaders = sentHeaders; + Overlapped overlapped = new Overlapped(); + overlapped.AsyncResult = this; + + int bufferSize = 1024 * 64; // TODO: Validate buffer size choice. +#if NET45 + // It's too expensive to validate anything before opening the file. Open the file and then check the lengths. + _fileStream = new FileStream(fileName, FileMode.Open, FileAccess.Read, FileShare.ReadWrite, bufferSize, + FileOptions.Asynchronous | FileOptions.SequentialScan); // Extremely expensive. +#else + _fileStream = new FileStream(fileName, FileMode.Open, FileAccess.Read, FileShare.ReadWrite, bufferSize, useAsync: true); // Extremely expensive. +#endif +#if !NET45 + throw new NotImplementedException(); +#else + long length = _fileStream.Length; // Expensive + if (offset < 0 || offset > length) + { + _fileStream.Dispose(); + throw new ArgumentOutOfRangeException("offset", offset, string.Empty); + } + if (size.HasValue && (size < 0 || size > length - offset)) + { + _fileStream.Dispose(); + throw new ArgumentOutOfRangeException("size", size, string.Empty); + } + + if (size == 0 || (!size.HasValue && _fileStream.Length == 0)) + { + _dataChunks = null; + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, null)); + } + else + { + _dataChunks = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[chunked ? 3 : 1]; + + object[] objectsToPin = new object[_dataChunks.Length]; + objectsToPin[_dataChunks.Length - 1] = _dataChunks; + + int chunkHeaderOffset = 0; + byte[] chunkHeaderBuffer = null; + if (chunked) + { + chunkHeaderBuffer = GetChunkHeader((int)(size ?? _fileStream.Length - offset), out chunkHeaderOffset); + + _dataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[0].fromMemory.BufferLength = (uint)(chunkHeaderBuffer.Length - chunkHeaderOffset); + + objectsToPin[0] = chunkHeaderBuffer; + + _dataChunks[1] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[1].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromFileHandle; + _dataChunks[1].fromFile.offset = (ulong)offset; + _dataChunks[1].fromFile.count = (ulong)(size ?? -1); + _dataChunks[1].fromFile.fileHandle = _fileStream.SafeFileHandle.DangerousGetHandle(); + // Nothing to pin for the file handle. + + _dataChunks[2] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[2].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + _dataChunks[2].fromMemory.BufferLength = (uint)CRLF.Length; + + objectsToPin[1] = CRLF; + } + else + { + _dataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + _dataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromFileHandle; + _dataChunks[0].fromFile.offset = (ulong)offset; + _dataChunks[0].fromFile.count = (ulong)(size ?? -1); + _dataChunks[0].fromFile.fileHandle = _fileStream.SafeFileHandle.DangerousGetHandle(); + } + + // This call will pin needed memory + _overlapped = new SafeNativeOverlapped(overlapped.Pack(IOCallback, objectsToPin)); + + if (chunked) + { + _dataChunks[0].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(chunkHeaderBuffer, chunkHeaderOffset); + _dataChunks[2].fromMemory.pBuffer = Marshal.UnsafeAddrOfPinnedArrayElement(CRLF, 0); + } + } +#endif + } + + internal ResponseStream ResponseStream + { + get { return _responseStream; } + } + + internal SafeNativeOverlapped NativeOverlapped + { + get { return _overlapped; } + } + + internal Task Task + { + get { return _tcs.Task; } + } + + internal uint BytesSent + { + get { return _bytesSent; } + set { _bytesSent = value; } + } + + internal ushort DataChunkCount + { + get + { + if (_dataChunks == null) + { + return 0; + } + else + { + return (ushort)_dataChunks.Length; + } + } + } + + internal UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK* DataChunks + { + get + { + if (_dataChunks == null) + { + return null; + } + else + { + return (UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK*)(Marshal.UnsafeAddrOfPinnedArrayElement(_dataChunks, 0)); + } + } + } + + internal long FileLength + { + get { return _fileStream == null ? 0 : _fileStream.Length; } + } + + internal bool EndCalled { get; set; } + + internal void IOCompleted(uint errorCode) + { + IOCompleted(this, errorCode, BytesSent); + } + + internal void IOCompleted(uint errorCode, uint numBytes) + { + IOCompleted(this, errorCode, numBytes); + } + + [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirecting to callback")] + private static void IOCompleted(ResponseStreamAsyncResult asyncResult, uint errorCode, uint numBytes) + { + try + { + if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + asyncResult.Fail(new WebListenerException((int)errorCode)); + } + else + { + if (asyncResult._dataChunks == null) + { + // TODO: Verbose log data written + } + else + { + // TODO: Verbose log + // for (int i = 0; i < asyncResult._dataChunks.Length; i++) + // { + // Logging.Dump(Logging.HttpListener, asyncResult, "Callback", (IntPtr)asyncResult._dataChunks[0].fromMemory.pBuffer, (int)asyncResult._dataChunks[0].fromMemory.BufferLength); + // } + } + asyncResult.Complete(); + } + } + catch (Exception e) + { + asyncResult.Fail(e); + } + } + + private static unsafe void Callback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + Overlapped callbackOverlapped = Overlapped.Unpack(nativeOverlapped); + ResponseStreamAsyncResult asyncResult = callbackOverlapped.AsyncResult as ResponseStreamAsyncResult; + + IOCompleted(asyncResult, errorCode, numBytes); + } + + internal void Complete() + { + if (_tcs.TrySetResult(null) && _callback != null) + { + try + { + _callback(this); + } + catch (Exception) + { + // TODO: Exception handling? This may be an IO callback thread and throwing here could crash the app. + // TODO: Log + } + } + Dispose(); + } + + internal void Fail(Exception ex) + { + if (_tcs.TrySetException(ex) && _callback != null) + { + try + { + _callback(this); + } + catch (Exception) + { + // TODO: Exception handling? This may be an IO callback thread and throwing here could crash the app. + } + } + Dispose(); + } + + /*++ + + GetChunkHeader + + A private utility routine to convert an integer to a chunk header, + which is an ASCII hex number followed by a CRLF. The header is retuned + as a byte array. + + Input: + + size - Chunk size to be encoded + offset - Out parameter where we store offset into buffer. + + Returns: + + A byte array with the header in int. + + --*/ + + private static byte[] GetChunkHeader(int size, out int offset) + { + uint mask = 0xf0000000; + byte[] header = new byte[10]; + int i; + offset = -1; + + // Loop through the size, looking at each nibble. If it's not 0 + // convert it to hex. Save the index of the first non-zero + // byte. + + for (i = 0; i < 8; i++, size <<= 4) + { + // offset == -1 means that we haven't found a non-zero nibble + // yet. If we haven't found one, and the current one is zero, + // don't do anything. + + if (offset == -1) + { + if ((size & mask) == 0) + { + continue; + } + } + + // Either we have a non-zero nibble or we're no longer skipping + // leading zeros. Convert this nibble to ASCII and save it. + + uint temp = (uint)size >> 28; + + if (temp < 10) + { + header[i] = (byte)(temp + '0'); + } + else + { + header[i] = (byte)((temp - 10) + 'A'); + } + + // If we haven't found a non-zero nibble yet, we've found one + // now, so remember that. + + if (offset == -1) + { + offset = i; + } + } + + header[8] = (byte)'\r'; + header[9] = (byte)'\n'; + + return header; + } + + public object AsyncState + { + get { return _tcs.Task.AsyncState; } + } + + public WaitHandle AsyncWaitHandle + { + get { return ((IAsyncResult)_tcs.Task).AsyncWaitHandle; } + } + + public bool CompletedSynchronously + { + get { return ((IAsyncResult)_tcs.Task).CompletedSynchronously; } + } + + public bool IsCompleted + { + get { return _tcs.Task.IsCompleted; } + } + + public void Dispose() + { + if (_overlapped != null) + { + _overlapped.Dispose(); + } + if (_fileStream != null) + { + _fileStream.Dispose(); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/SslStatus.cs b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/SslStatus.cs new file mode 100644 index 0000000000..72118c2d46 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/RequestProcessing/SslStatus.cs @@ -0,0 +1,15 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace Microsoft.AspNet.Server.WebListener +{ + internal enum SslStatus : byte + { + Insecure, + NoClientCert, + ClientCert + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Resources.Designer.cs b/src/Microsoft.AspNet.Server.WebListener/Resources.Designer.cs new file mode 100644 index 0000000000..aafe87fb8d --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Resources.Designer.cs @@ -0,0 +1,153 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.34006 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.Server.WebListener { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class Resources { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal Resources() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.AspNet.Server.WebListener.Resources", System.Reflection.IntrospectionExtensions.GetTypeInfo(typeof(Resources)).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to The destination array is too small.. + /// + internal static string Exception_ArrayTooSmall { + get { + return ResourceManager.GetString("Exception_ArrayTooSmall", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to End has already been called.. + /// + internal static string Exception_EndCalledMultipleTimes { + get { + return ResourceManager.GetString("Exception_EndCalledMultipleTimes", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The status code '{0}' is not supported.. + /// + internal static string Exception_InvalidStatusCode { + get { + return ResourceManager.GetString("Exception_InvalidStatusCode", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The stream is not seekable.. + /// + internal static string Exception_NoSeek { + get { + return ResourceManager.GetString("Exception_NoSeek", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The prefix '{0}' is already registered.. + /// + internal static string Exception_PrefixAlreadyRegistered { + get { + return ResourceManager.GetString("Exception_PrefixAlreadyRegistered", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to This stream only supports read operations.. + /// + internal static string Exception_ReadOnlyStream { + get { + return ResourceManager.GetString("Exception_ReadOnlyStream", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to More data written than specified in the Content-Length header.. + /// + internal static string Exception_TooMuchWritten { + get { + return ResourceManager.GetString("Exception_TooMuchWritten", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Only the http and https schemes are supported.. + /// + internal static string Exception_UnsupportedScheme { + get { + return ResourceManager.GetString("Exception_UnsupportedScheme", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to This stream only supports write operations.. + /// + internal static string Exception_WriteOnlyStream { + get { + return ResourceManager.GetString("Exception_WriteOnlyStream", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to The given IAsyncResult does not match this opperation.. + /// + internal static string Exception_WrongIAsyncResult { + get { + return ResourceManager.GetString("Exception_WrongIAsyncResult", resourceCulture); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/Resources.resx b/src/Microsoft.AspNet.Server.WebListener/Resources.resx new file mode 100644 index 0000000000..005bafca2f --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/Resources.resx @@ -0,0 +1,150 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + The destination array is too small. + + + End has already been called. + + + The status code '{0}' is not supported. + + + The stream is not seekable. + + + The prefix '{0}' is already registered. + + + This stream only supports read operations. + + + More data written than specified in the Content-Length header. + + + Only the http and https schemes are supported. + + + This stream only supports write operations. + + + The given IAsyncResult does not match this opperation. + + \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/TimeoutManager.cs b/src/Microsoft.AspNet.Server.WebListener/TimeoutManager.cs new file mode 100644 index 0000000000..9e74f6563a --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/TimeoutManager.cs @@ -0,0 +1,267 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ + // See the native HTTP_TIMEOUT_LIMIT_INFO structure documentation for additional information. + // http://msdn.microsoft.com/en-us/library/aa364661.aspx + + /// + /// Exposes the Http.Sys timeout configurations. These may also be configured in the registry. + /// + public sealed class TimeoutManager + { +#if NET45 + private static readonly int TimeoutLimitSize = + Marshal.SizeOf(typeof(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_LIMIT_INFO)); +#else + private static readonly int TimeoutLimitSize = + Marshal.SizeOf(); +#endif + private OwinWebListener _server; + private int[] _timeouts; + private uint _minSendBytesPerSecond; + + internal TimeoutManager(OwinWebListener context) + { + _server = context; + + // We have to maintain local state since we allow applications to set individual timeouts. Native Http + // API for setting timeouts expects all timeout values in every call so we have remember timeout values + // to fill in the blanks. Except MinSendBytesPerSecond, local state for remaining five timeouts is + // maintained in timeouts array. + // + // No initialization is required because a value of zero indicates that system defaults should be used. + _timeouts = new int[5]; + + LoadConfigurationSettings(); + } + + #region Properties + + /// + /// The time, in seconds, allowed for the request entity body to arrive. The default timer is 2 minutes. + /// + /// The HTTP Server API turns on this timer when the request has an entity body. The timer expiration is + /// initially set to the configured value. When the HTTP Server API receives additional data indications on the + /// request, it resets the timer to give the connection another interval. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan EntityBody + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.EntityBody); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.EntityBody, value); + } + } + + /// + /// The time, in seconds, allowed for the HTTP Server API to drain the entity body on a Keep-Alive connection. + /// The default timer is 2 minutes. + /// + /// On a Keep-Alive connection, after the application has sent a response for a request and before the request + /// entity body has completely arrived, the HTTP Server API starts draining the remainder of the entity body to + /// reach another potentially pipelined request from the client. If the time to drain the remaining entity body + /// exceeds the allowed period the connection is timed out. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan DrainEntityBody + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.DrainEntityBody); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.DrainEntityBody, value); + } + } + + /// + /// The time, in seconds, allowed for the request to remain in the request queue before the application picks + /// it up. The default timer is 2 minutes. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan RequestQueue + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.RequestQueue); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.RequestQueue, value); + } + } + + /// + /// The time, in seconds, allowed for an idle connection. The default timer is 2 minutes. + /// + /// This timeout is only enforced after the first request on the connection is routed to the application. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan IdleConnection + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.IdleConnection); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.IdleConnection, value); + } + } + + /// + /// The time, in seconds, allowed for the HTTP Server API to parse the request header. The default timer is + /// 2 minutes. + /// + /// This timeout is only enforced after the first request on the connection is routed to the application. + /// + /// Use TimeSpan.Zero to indicate that system defaults should be used. + /// + public TimeSpan HeaderWait + { + get + { + return GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.HeaderWait); + } + set + { + SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.HeaderWait, value); + } + } + + /// + /// The minimum send rate, in bytes-per-second, for the response. The default response send rate is 150 + /// bytes-per-second. + /// + /// To disable this timer set it to UInt32.MaxValue + /// + public long MinSendBytesPerSecond + { + get + { + // Since we maintain local state, GET is local. + return _minSendBytesPerSecond; + } + set + { + // MinSendRate value is ULONG in native layer. + if (value < 0 || value > uint.MaxValue) + { + throw new ArgumentOutOfRangeException("value"); + } + + SetServerTimeout(_timeouts, (uint)value); + _minSendBytesPerSecond = (uint)value; + } + } + + #endregion Properties + + // Initial values come from the config. The values can then be overridden using this public API. + private void LoadConfigurationSettings() + { + long[] configTimeouts = new long[_timeouts.Length + 1]; // SettingsSectionInternal.Section.HttpListenerTimeouts; + Debug.Assert(configTimeouts != null); + Debug.Assert(configTimeouts.Length == (_timeouts.Length + 1)); + + bool setNonDefaults = false; + for (int i = 0; i < _timeouts.Length; i++) + { + if (configTimeouts[i] != 0) + { + Debug.Assert(configTimeouts[i] <= ushort.MaxValue, "Timeout out of range: " + configTimeouts[i]); + _timeouts[i] = (int)configTimeouts[i]; + setNonDefaults = true; + } + } + + if (configTimeouts[5] != 0) + { + Debug.Assert(configTimeouts[5] <= uint.MaxValue, "Timeout out of range: " + configTimeouts[5]); + _minSendBytesPerSecond = (uint)configTimeouts[5]; + setNonDefaults = true; + } + + if (setNonDefaults) + { + SetServerTimeout(_timeouts, _minSendBytesPerSecond); + } + } + + #region Helpers + + private TimeSpan GetTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE type) + { + // Since we maintain local state, GET is local. + return new TimeSpan(0, 0, (int)_timeouts[(int)type]); + } + + private void SetTimespanTimeout(UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE type, TimeSpan value) + { + Int64 timeoutValue; + + // All timeouts are defined as USHORT in native layer (except MinSendRate, which is ULONG). Make sure that + // timeout value is within range. + + timeoutValue = Convert.ToInt64(value.TotalSeconds); + + if (timeoutValue < 0 || timeoutValue > ushort.MaxValue) + { + throw new ArgumentOutOfRangeException("value"); + } + + // Use local state to get values for other timeouts. Call into the native layer and if that + // call succeeds, update local state. + + int[] currentTimeouts = _timeouts; + currentTimeouts[(int)type] = (int)timeoutValue; + SetServerTimeout(currentTimeouts, _minSendBytesPerSecond); + _timeouts[(int)type] = (int)timeoutValue; + } + + private unsafe void SetServerTimeout(int[] timeouts, uint minSendBytesPerSecond) + { + UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_LIMIT_INFO timeoutinfo = + new UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_LIMIT_INFO(); + + timeoutinfo.Flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_PROPERTY_FLAG_PRESENT; + timeoutinfo.DrainEntityBody = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.DrainEntityBody]; + timeoutinfo.EntityBody = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.EntityBody]; + timeoutinfo.RequestQueue = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.RequestQueue]; + timeoutinfo.IdleConnection = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.IdleConnection]; + timeoutinfo.HeaderWait = + (ushort)timeouts[(int)UnsafeNclNativeMethods.HttpApi.HTTP_TIMEOUT_TYPE.HeaderWait]; + timeoutinfo.MinSendRate = minSendBytesPerSecond; + + IntPtr infoptr = new IntPtr(&timeoutinfo); + + _server.SetUrlGroupProperty( + UnsafeNclNativeMethods.HttpApi.HTTP_SERVER_PROPERTY.HttpServerTimeoutsProperty, + infoptr, (uint)TimeoutLimitSize); + } + + #endregion Helpers + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/ValidationHelper.cs b/src/Microsoft.AspNet.Server.WebListener/ValidationHelper.cs new file mode 100644 index 0000000000..104e49e320 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/ValidationHelper.cs @@ -0,0 +1,67 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Globalization; + +namespace Microsoft.AspNet.Server.WebListener +{ + internal static class ValidationHelper + { + public static string ExceptionMessage(Exception exception) + { + if (exception == null) + { + return string.Empty; + } + if (exception.InnerException == null) + { + return exception.Message; + } + return exception.Message + " (" + ExceptionMessage(exception.InnerException) + ")"; + } + + public static string ToString(object objectValue) + { + if (objectValue == null) + { + return "(null)"; + } + else if (objectValue is string && ((string)objectValue).Length == 0) + { + return "(string.empty)"; + } + else if (objectValue is Exception) + { + return ExceptionMessage(objectValue as Exception); + } + else if (objectValue is IntPtr) + { + return "0x" + ((IntPtr)objectValue).ToString("x"); + } + else + { + return objectValue.ToString(); + } + } + + public static string HashString(object objectValue) + { + if (objectValue == null) + { + return "(null)"; + } + else if (objectValue is string && ((string)objectValue).Length == 0) + { + return "(string.empty)"; + } + else + { + return objectValue.GetHashCode().ToString(NumberFormatInfo.InvariantInfo); + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/WebListenerException.cs b/src/Microsoft.AspNet.Server.WebListener/WebListenerException.cs new file mode 100644 index 0000000000..19a7b78a6e --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/WebListenerException.cs @@ -0,0 +1,44 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.Server.WebListener +{ +#if NET45 + [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2237:MarkISerializableTypesWithSerializable")] +#endif + internal class WebListenerException : Win32Exception + { + internal WebListenerException() + : base(Marshal.GetLastWin32Error()) + { + } + + internal WebListenerException(int errorCode) + : base(errorCode) + { + } + + internal WebListenerException(int errorCode, string message) + : base(errorCode, message) + { + } + + public override int ErrorCode + { + // the base class returns the HResult with this property + // we need the Win32 Error Code, hence the override. + + get + { + return NativeErrorCode; + } + } + } +} diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/CriticalHandleZeroOrMinusOneIsInvalid.cs b/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/CriticalHandleZeroOrMinusOneIsInvalid.cs new file mode 100644 index 0000000000..43897343aa --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/CriticalHandleZeroOrMinusOneIsInvalid.cs @@ -0,0 +1,32 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== + +#if !NET45 + +namespace Microsoft.Win32.SafeHandles +{ + using System; + using System.Runtime.InteropServices; + using System.Runtime.CompilerServices; + + // Class of critical handle which uses 0 or -1 as an invalid handle. + [System.Security.SecurityCritical] // auto-generated_required + internal abstract class CriticalHandleZeroOrMinusOneIsInvalid : CriticalHandle + { + protected CriticalHandleZeroOrMinusOneIsInvalid() + : base(IntPtr.Zero) + { + } + + public override bool IsInvalid + { + [System.Security.SecurityCritical] + get { return handle == new IntPtr(0) || handle == new IntPtr(-1); } + } + } +} + +#endif diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs b/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs new file mode 100644 index 0000000000..861c9e31ab --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs @@ -0,0 +1,31 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== + +#if !NET45 + +namespace Microsoft.Win32.SafeHandles +{ + using System; + using System.Runtime.InteropServices; + using System.Runtime.CompilerServices; + + // Class of safe handle which uses 0 or -1 as an invalid handle. + [System.Security.SecurityCritical] // auto-generated_required + internal abstract class SafeHandleZeroOrMinusOneIsInvalid : SafeHandle + { + protected SafeHandleZeroOrMinusOneIsInvalid(bool ownsHandle) + : base(IntPtr.Zero, ownsHandle) + { + } + + public override bool IsInvalid + { + [System.Security.SecurityCritical] + get { return handle == new IntPtr(0) || handle == new IntPtr(-1); } + } + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/ComponentModel/Win32Exception.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/ComponentModel/Win32Exception.cs new file mode 100644 index 0000000000..108303ee2c --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/ComponentModel/Win32Exception.cs @@ -0,0 +1,112 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +using System.Runtime.InteropServices; +using System.Text; + +namespace System.ComponentModel +{ + internal class Win32Exception : ExternalException + { + /// + /// Represents the Win32 error code associated with this exception. This + /// field is read-only. + /// + private readonly int nativeErrorCode; + + /// + /// Initializes a new instance of the class with the last Win32 error + /// that occured. + /// + public Win32Exception() + : this(Marshal.GetLastWin32Error()) + { + } + /// + /// Initializes a new instance of the class with the specified error. + /// + public Win32Exception(int error) + : this(error, GetErrorMessage(error)) + { + } + /// + /// Initializes a new instance of the class with the specified error and the + /// specified detailed description. + /// + public Win32Exception(int error, string message) + : base(message) + { + nativeErrorCode = error; + } + + /// + /// Initializes a new instance of the Exception class with a specified error message. + /// FxCop CA1032: Multiple constructors are required to correctly implement a custom exception. + /// + public Win32Exception(string message) + : this(Marshal.GetLastWin32Error(), message) + { + } + + /// + /// Initializes a new instance of the Exception class with a specified error message and a + /// reference to the inner exception that is the cause of this exception. + /// FxCop CA1032: Multiple constructors are required to correctly implement a custom exception. + /// + public Win32Exception(string message, Exception innerException) + : base(message, innerException) + { + nativeErrorCode = Marshal.GetLastWin32Error(); + } + + + /// + /// Represents the Win32 error code associated with this exception. This + /// field is read-only. + /// + public int NativeErrorCode + { + get + { + return nativeErrorCode; + } + } + + private static string GetErrorMessage(int error) + { + //get the system error message... + string errorMsg = ""; + StringBuilder sb = new StringBuilder(256); + int result = SafeNativeMethods.FormatMessage( + SafeNativeMethods.FORMAT_MESSAGE_IGNORE_INSERTS | + SafeNativeMethods.FORMAT_MESSAGE_FROM_SYSTEM | + SafeNativeMethods.FORMAT_MESSAGE_ARGUMENT_ARRAY, + IntPtr.Zero, (uint)error, 0, sb, sb.Capacity + 1, + null); + if (result != 0) + { + int i = sb.Length; + while (i > 0) + { + char ch = sb[i - 1]; + if (ch > 32 && ch != '.') break; + i--; + } + errorMsg = sb.ToString(0, i); + } + else + { + errorMsg = "Unknown error (0x" + Convert.ToString(error, 16) + ")"; + } + + return errorMsg; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/Diagnostics/TraceEventType.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/Diagnostics/TraceEventType.cs new file mode 100644 index 0000000000..36d0b93d25 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/Diagnostics/TraceEventType.cs @@ -0,0 +1,35 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +using System; +using System.ComponentModel; + +namespace System.Diagnostics +{ + internal enum TraceEventType + { + Critical = 0x01, + Error = 0x02, + Warning = 0x04, + Information = 0x08, + Verbose = 0x10, + + [EditorBrowsable(EditorBrowsableState.Advanced)] + Start = 0x0100, + [EditorBrowsable(EditorBrowsableState.Advanced)] + Stop = 0x0200, + [EditorBrowsable(EditorBrowsableState.Advanced)] + Suspend = 0x0400, + [EditorBrowsable(EditorBrowsableState.Advanced)] + Resume = 0x0800, + [EditorBrowsable(EditorBrowsableState.Advanced)] + Transfer = 0x1000, + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/ExternDll.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/ExternDll.cs new file mode 100644 index 0000000000..98303d48b8 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/ExternDll.cs @@ -0,0 +1,16 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +namespace System +{ + internal static class ExternDll + { + public const string Kernel32 = "kernel32.dll"; + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/Runtime/InteropServices/ExternalException.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/Runtime/InteropServices/ExternalException.cs new file mode 100644 index 0000000000..21d82b1de4 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/Runtime/InteropServices/ExternalException.cs @@ -0,0 +1,98 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== +/*============================================================================= +** +** Class: ExternalException +** +** +** Purpose: Exception base class for all errors from Interop or Structured +** Exception Handling code. +** +** +=============================================================================*/ + +#if !NET45 + +namespace System.Runtime.InteropServices +{ + using System; + using System.Globalization; + + // Base exception for COM Interop errors &; Structured Exception Handler + // exceptions. + // + internal class ExternalException : Exception + { + public ExternalException() + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message) + : base(message) + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message, Exception inner) + : base(message, inner) + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message, int errorCode) + : base(message) + { + SetErrorCode(errorCode); + } + + private void SetErrorCode(int errorCode) + { + HResult = ErrorCode; + } + + private static class __HResults + { + internal const int E_FAIL = unchecked((int)0x80004005); + } + + public virtual int ErrorCode + { + get + { + return HResult; + } + } + + public override String ToString() + { + String message = Message; + String s; + String _className = GetType().ToString(); + s = _className + " (0x" + HResult.ToString("X8", CultureInfo.InvariantCulture) + ")"; + + if (!(String.IsNullOrEmpty(message))) + { + s = s + ": " + message; + } + + Exception _innerException = InnerException; + + if (_innerException != null) + { + s = s + " ---> " + _innerException.ToString(); + } + + + if (StackTrace != null) + s += Environment.NewLine + StackTrace; + + return s; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/SafeNativeMethods.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/SafeNativeMethods.cs new file mode 100644 index 0000000000..969d273fa8 --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/SafeNativeMethods.cs @@ -0,0 +1,28 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 +using System.Runtime.InteropServices; +using System.Text; + +namespace System +{ + internal static class SafeNativeMethods + { + public const int + FORMAT_MESSAGE_ALLOCATE_BUFFER = 0x00000100, + FORMAT_MESSAGE_IGNORE_INSERTS = 0x00000200, + FORMAT_MESSAGE_FROM_STRING = 0x00000400, + FORMAT_MESSAGE_FROM_SYSTEM = 0x00001000, + FORMAT_MESSAGE_ARGUMENT_ARRAY = 0x00002000; + + [DllImport(ExternDll.Kernel32, CharSet = System.Runtime.InteropServices.CharSet.Unicode, SetLastError = true, BestFitMapping = true)] + public static unsafe extern int FormatMessage(int dwFlags, IntPtr lpSource_mustBeNull, uint dwMessageId, + int dwLanguageId, StringBuilder lpBuffer, int nSize, IntPtr[] arguments); + + } +} +#endif diff --git a/src/Microsoft.AspNet.Server.WebListener/fx/System/Security/Authentication/ExtendedProtection/ChannelBinding.cs b/src/Microsoft.AspNet.Server.WebListener/fx/System/Security/Authentication/ExtendedProtection/ChannelBinding.cs new file mode 100644 index 0000000000..3c31d34a1b --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/fx/System/Security/Authentication/ExtendedProtection/ChannelBinding.cs @@ -0,0 +1,33 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System.Runtime.InteropServices; +using Microsoft.Win32.SafeHandles; + +#if !NET45 + +namespace System.Security.Authentication.ExtendedProtection +{ + internal abstract class ChannelBinding : SafeHandleZeroOrMinusOneIsInvalid + { + protected ChannelBinding() + : base(true) + { + } + + protected ChannelBinding(bool ownsHandle) + : base(ownsHandle) + { + } + + public abstract int Size + { + get; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.Server.WebListener/project.json b/src/Microsoft.AspNet.Server.WebListener/project.json new file mode 100644 index 0000000000..0ba47798dd --- /dev/null +++ b/src/Microsoft.AspNet.Server.WebListener/project.json @@ -0,0 +1,8 @@ +{ + "version": "0.1-alpha-*", + "compilationOptions" : { "allowUnsafe": true }, + "configurations": { + "net45" : { }, + "k10" : { } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/Constants.cs b/src/Microsoft.AspNet.WebSockets/Constants.cs new file mode 100644 index 0000000000..6e9cb51ad3 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Constants.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. + +namespace Microsoft.AspNet.WebSockets +{ + /// + /// Standard keys and values for use within the OWIN interfaces + /// + internal static class Constants + { + internal const string WebSocketAcceptKey = "websocket.Accept"; + internal const string WebSocketSubProtocolKey = "websocket.SubProtocol"; + internal const string WebSocketSendAsyncKey = "websocket.SendAsync"; + internal const string WebSocketReceiveAyncKey = "websocket.ReceiveAsync"; + internal const string WebSocketCloseAsyncKey = "websocket.CloseAsync"; + internal const string WebSocketCallCancelledKey = "websocket.CallCancelled"; + internal const string WebSocketVersionKey = "websocket.Version"; + internal const string WebSocketVersion = "1.0"; + internal const string WebSocketCloseStatusKey = "websocket.ClientCloseStatus"; + internal const string WebSocketCloseDescriptionKey = "websocket.ClientCloseDescription"; + } +} diff --git a/src/Microsoft.AspNet.WebSockets/HttpKnownHeaderNames.cs b/src/Microsoft.AspNet.WebSockets/HttpKnownHeaderNames.cs new file mode 100644 index 0000000000..f4887cc569 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/HttpKnownHeaderNames.cs @@ -0,0 +1,76 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.WebSockets +{ + // this class contains known header names + internal static class HttpKnownHeaderNames + { + public const string CacheControl = "Cache-Control"; + public const string Connection = "Connection"; + public const string Date = "Date"; + public const string KeepAlive = "Keep-Alive"; + public const string Pragma = "Pragma"; + public const string ProxyConnection = "Proxy-Connection"; + public const string Trailer = "Trailer"; + public const string TransferEncoding = "Transfer-Encoding"; + public const string Upgrade = "Upgrade"; + public const string Via = "Via"; + public const string Warning = "Warning"; + public const string ContentLength = "Content-Length"; + public const string ContentType = "Content-Type"; + public const string ContentDisposition = "Content-Disposition"; + public const string ContentEncoding = "Content-Encoding"; + public const string ContentLanguage = "Content-Language"; + public const string ContentLocation = "Content-Location"; + public const string ContentRange = "Content-Range"; + public const string Expires = "Expires"; + public const string LastModified = "Last-Modified"; + public const string Age = "Age"; + public const string Location = "Location"; + public const string ProxyAuthenticate = "Proxy-Authenticate"; + public const string RetryAfter = "Retry-After"; + public const string Server = "Server"; + public const string SetCookie = "Set-Cookie"; + public const string SetCookie2 = "Set-Cookie2"; + public const string Vary = "Vary"; + public const string WWWAuthenticate = "WWW-Authenticate"; + public const string Accept = "Accept"; + public const string AcceptCharset = "Accept-Charset"; + public const string AcceptEncoding = "Accept-Encoding"; + public const string AcceptLanguage = "Accept-Language"; + public const string Authorization = "Authorization"; + public const string Cookie = "Cookie"; + public const string Cookie2 = "Cookie2"; + public const string Expect = "Expect"; + public const string From = "From"; + public const string Host = "Host"; + public const string IfMatch = "If-Match"; + public const string IfModifiedSince = "If-Modified-Since"; + public const string IfNoneMatch = "If-None-Match"; + public const string IfRange = "If-Range"; + public const string IfUnmodifiedSince = "If-Unmodified-Since"; + public const string MaxForwards = "Max-Forwards"; + public const string ProxyAuthorization = "Proxy-Authorization"; + public const string Referer = "Referer"; + public const string Range = "Range"; + public const string UserAgent = "User-Agent"; + public const string ContentMD5 = "Content-MD5"; + public const string ETag = "ETag"; + public const string TE = "TE"; + public const string Allow = "Allow"; + public const string AcceptRanges = "Accept-Ranges"; + public const string P3P = "P3P"; + public const string XPoweredBy = "X-Powered-By"; + public const string XAspNetVersion = "X-AspNet-Version"; + public const string SecWebSocketKey = "Sec-WebSocket-Key"; + public const string SecWebSocketExtensions = "Sec-WebSocket-Extensions"; + public const string SecWebSocketAccept = "Sec-WebSocket-Accept"; + public const string Origin = "Origin"; + public const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; + public const string SecWebSocketVersion = "Sec-WebSocket-Version"; + } +} diff --git a/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerContext.cs b/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerContext.cs new file mode 100644 index 0000000000..506c830ea2 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerContext.cs @@ -0,0 +1,58 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ +/* +namespace Microsoft.Net +{ + using Microsoft.AspNet.WebSockets; + using System; + using System.ComponentModel; + using System.Threading.Tasks; + + public sealed unsafe class HttpListenerContext { + private HttpListenerRequest m_Request; + + public Task AcceptWebSocketAsync(string subProtocol) + { + return this.AcceptWebSocketAsync(subProtocol, + WebSocketHelpers.DefaultReceiveBufferSize, + WebSocket.DefaultKeepAliveInterval); + } + + public Task AcceptWebSocketAsync(string subProtocol, TimeSpan keepAliveInterval) + { + return this.AcceptWebSocketAsync(subProtocol, + WebSocketHelpers.DefaultReceiveBufferSize, + keepAliveInterval); + } + + public Task AcceptWebSocketAsync(string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval) + { + WebSocketHelpers.ValidateOptions(subProtocol, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, keepAliveInterval); + + ArraySegment internalBuffer = WebSocketBuffer.CreateInternalBufferArraySegment(receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + return this.AcceptWebSocketAsync(subProtocol, + receiveBufferSize, + keepAliveInterval, + internalBuffer); + } + + [EditorBrowsable(EditorBrowsableState.Never)] + public Task AcceptWebSocketAsync(string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + { + return WebSocketHelpers.AcceptWebSocketAsync(this, + subProtocol, + receiveBufferSize, + keepAliveInterval, + internalBuffer); + } + } +} +*/ \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerRequest.cs b/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerRequest.cs new file mode 100644 index 0000000000..d0c02c1842 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Legacy/HttpListenerRequest.cs @@ -0,0 +1,66 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +/* +namespace Microsoft.Net +{ + using System; + using System.Collections; + using System.Collections.Specialized; + using System.IO; + using System.Runtime.InteropServices; + using System.Globalization; + using System.Text; + using System.Security.Principal; + using System.Security.Cryptography.X509Certificates; + using System.Net; + using Microsoft.AspNet.WebSockets; + + public sealed unsafe class HttpListenerRequest { + + public bool IsWebSocketRequest + { + get + { + if (!WebSocketProtocolComponent.IsSupported) + { + return false; + } + + bool foundConnectionUpgradeHeader = false; + if (string.IsNullOrEmpty(this.Headers[HttpKnownHeaderNames.Connection]) || string.IsNullOrEmpty(this.Headers[HttpKnownHeaderNames.Upgrade])) + { + return false; + } + + foreach (string connection in this.Headers.GetValues(HttpKnownHeaderNames.Connection)) + { + if (string.Compare(connection, HttpKnownHeaderNames.Upgrade, StringComparison.OrdinalIgnoreCase) == 0) + { + foundConnectionUpgradeHeader = true; + break; + } + } + + if (!foundConnectionUpgradeHeader) + { + return false; + } + + foreach (string upgrade in this.Headers.GetValues(HttpKnownHeaderNames.Upgrade)) + { + if (string.Compare(upgrade, WebSocketHelpers.WebSocketUpgradeToken, StringComparison.OrdinalIgnoreCase) == 0) + { + return true; + } + } + + return false; + } + } + } +} +*/ \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/Legacy/SR.cs b/src/Microsoft.AspNet.WebSockets/Legacy/SR.cs new file mode 100644 index 0000000000..a8aeb2a30e --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Legacy/SR.cs @@ -0,0 +1,66 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All Rights Reserved. +// Information Contained Herein is Proprietary and Confidential. +// +//------------------------------------------------------------------------------ + +namespace System +{ + internal sealed class SR + { + internal const string net_servicePointAddressNotSupportedInHostMode = "net_servicePointAddressNotSupportedInHostMode"; + internal const string net_Websockets_AlreadyOneOutstandingOperation = "net_Websockets_AlreadyOneOutstandingOperation"; + internal const string net_Websockets_WebSocketBaseFaulted = "net_Websockets_WebSocketBaseFaulted"; + internal const string net_WebSockets_NativeSendResponseHeaders = "net_WebSockets_NativeSendResponseHeaders"; + internal const string net_WebSockets_Generic = "net_WebSockets_Generic"; + internal const string net_WebSockets_NotAWebSocket_Generic = "net_WebSockets_NotAWebSocket_Generic"; + internal const string net_WebSockets_UnsupportedWebSocketVersion_Generic = "net_WebSockets_UnsupportedWebSocketVersion_Generic"; + internal const string net_WebSockets_HeaderError_Generic = "net_WebSockets_HeaderError_Generic"; + internal const string net_WebSockets_UnsupportedProtocol_Generic = "net_WebSockets_UnsupportedProtocol_Generic"; + internal const string net_WebSockets_UnsupportedPlatform = "net_WebSockets_UnsupportedPlatform"; + internal const string net_WebSockets_AcceptNotAWebSocket = "net_WebSockets_AcceptNotAWebSocket"; + internal const string net_WebSockets_AcceptUnsupportedWebSocketVersion = "net_WebSockets_AcceptUnsupportedWebSocketVersion"; + internal const string net_WebSockets_AcceptHeaderNotFound = "net_WebSockets_AcceptHeaderNotFound"; + internal const string net_WebSockets_AcceptUnsupportedProtocol = "net_WebSockets_AcceptUnsupportedProtocol"; + internal const string net_WebSockets_ClientAcceptingNoProtocols = "net_WebSockets_ClientAcceptingNoProtocols"; + internal const string net_WebSockets_ClientSecWebSocketProtocolsBlank = "net_WebSockets_ClientSecWebSocketProtocolsBlank"; + internal const string net_WebSockets_ArgumentOutOfRange_TooSmall = "net_WebSockets_ArgumentOutOfRange_TooSmall"; + internal const string net_WebSockets_ArgumentOutOfRange_InternalBuffer = "net_WebSockets_ArgumentOutOfRange_InternalBuffer"; + internal const string net_WebSockets_ArgumentOutOfRange_TooBig = "net_WebSockets_ArgumentOutOfRange_TooBig"; + internal const string net_WebSockets_InvalidState_Generic = "net_WebSockets_InvalidState_Generic"; + internal const string net_WebSockets_InvalidState_ClosedOrAborted = "net_WebSockets_InvalidState_ClosedOrAborted"; + internal const string net_WebSockets_InvalidState = "net_WebSockets_InvalidState"; + internal const string net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync = "net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync"; + internal const string net_WebSockets_InvalidMessageType = "net_WebSockets_InvalidMessageType"; + internal const string net_WebSockets_InvalidBufferType = "net_WebSockets_InvalidBufferType"; + internal const string net_WebSockets_InvalidMessageType_Generic = "net_WebSockets_InvalidMessageType_Generic"; + internal const string net_WebSockets_Argument_InvalidMessageType = "net_WebSockets_Argument_InvalidMessageType"; + internal const string net_WebSockets_ConnectionClosedPrematurely_Generic = "net_WebSockets_ConnectionClosedPrematurely_Generic"; + internal const string net_WebSockets_InvalidCharInProtocolString = "net_WebSockets_InvalidCharInProtocolString"; + internal const string net_WebSockets_InvalidEmptySubProtocol = "net_WebSockets_InvalidEmptySubProtocol"; + internal const string net_WebSockets_ReasonNotNull = "net_WebSockets_ReasonNotNull"; + internal const string net_WebSockets_InvalidCloseStatusCode = "net_WebSockets_InvalidCloseStatusCode"; + internal const string net_WebSockets_InvalidCloseStatusDescription = "net_WebSockets_InvalidCloseStatusDescription"; + internal const string net_WebSockets_Scheme = "net_WebSockets_Scheme"; + internal const string net_WebSockets_AlreadyStarted = "net_WebSockets_AlreadyStarted"; + internal const string net_WebSockets_Connect101Expected = "net_WebSockets_Connect101Expected"; + internal const string net_WebSockets_InvalidResponseHeader = "net_WebSockets_InvalidResponseHeader"; + internal const string net_WebSockets_NotConnected = "net_WebSockets_NotConnected"; + internal const string net_WebSockets_InvalidRegistration = "net_WebSockets_InvalidRegistration"; + internal const string net_WebSockets_NoDuplicateProtocol = "net_WebSockets_NoDuplicateProtocol"; + + internal const string NotReadableStream = "NotReadableStream"; + internal const string NotWriteableStream = "NotWriteableStream"; + + public static string GetString(string name, params object[] args) + { + return name; + } + + public static string GetString(string name) + { + return name; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/Legacy/WebSocketHttpListenerDuplexStream.cs b/src/Microsoft.AspNet.WebSockets/Legacy/WebSocketHttpListenerDuplexStream.cs new file mode 100644 index 0000000000..54abff91bd --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Legacy/WebSocketHttpListenerDuplexStream.cs @@ -0,0 +1,1262 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ +/* +namespace Microsoft.AspNet.WebSockets +{ + using Microsoft.Net; + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.ComponentModel; + using System.Diagnostics; + using System.Diagnostics.Contracts; + using System.Globalization; + using System.IO; + using System.Runtime.InteropServices; + using System.Security; + using System.Threading; + using System.Threading.Tasks; + + internal class WebSocketHttpListenerDuplexStream : Stream, WebSocketBase.IWebSocketStream + { + private static readonly EventHandler s_OnReadCompleted = + new EventHandler(OnReadCompleted); + private static readonly EventHandler s_OnWriteCompleted = + new EventHandler(OnWriteCompleted); + private static readonly Func s_CanHandleException = new Func(CanHandleException); + private static readonly Action s_OnCancel = new Action(OnCancel); + // private readonly HttpRequestStream m_InputStream; + // private readonly HttpResponseStream m_OutputStream; + private HttpListenerContext m_Context; + private bool m_InOpaqueMode; + private WebSocketBase m_WebSocket; + private HttpListenerAsyncEventArgs m_WriteEventArgs; + private HttpListenerAsyncEventArgs m_ReadEventArgs; + private TaskCompletionSource m_WriteTaskCompletionSource; + private TaskCompletionSource m_ReadTaskCompletionSource; + private int m_CleanedUp; + +#if DEBUG + private class OutstandingOperations + { + internal int m_Reads; + internal int m_Writes; + } + + private readonly OutstandingOperations m_OutstandingOperations = new OutstandingOperations(); +#endif //DEBUG + + public WebSocketHttpListenerDuplexStream( + // HttpRequestStream inputStream, + // HttpResponseStream outputStream, + HttpListenerContext context) + { + Contract.Assert(inputStream != null, "'inputStream' MUST NOT be NULL."); + Contract.Assert(outputStream != null, "'outputStream' MUST NOT be NULL."); + Contract.Assert(context != null, "'context' MUST NOT be NULL."); + Contract.Assert(inputStream.CanRead, "'inputStream' MUST support read operations."); + Contract.Assert(outputStream.CanWrite, "'outputStream' MUST support write operations."); + + m_InputStream = inputStream; + m_OutputStream = outputStream; + m_Context = context; + + if (WebSocketBase.LoggingEnabled) + { + Logging.Associate(Logging.WebSockets, inputStream, this); + Logging.Associate(Logging.WebSockets, outputStream, this); + } + } + + public override bool CanRead + { + get + { + return m_InputStream.CanRead; + } + } + + public override bool CanSeek + { + get + { + return false; + } + } + + public override bool CanTimeout + { + get + { + return m_InputStream.CanTimeout && m_OutputStream.CanTimeout; + } + } + + public override bool CanWrite + { + get + { + return m_OutputStream.CanWrite; + } + } + + public override long Length + { + get + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + } + + public override long Position + { + get + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + set + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + } + + public override int Read(byte[] buffer, int offset, int count) + { + return m_InputStream.Read(buffer, offset, count); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateBuffer(buffer, offset, count); + + return ReadAsyncCore(buffer, offset, count, cancellationToken); + } + + private async Task ReadAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.ReadAsyncCore, + WebSocketHelpers.GetTraceMsgForParameters(offset, count, cancellationToken)); + } + + CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration(); + + int bytesRead = 0; + try + { + if (cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false); + } + + if (!m_InOpaqueMode) + { + bytesRead = await m_InputStream.ReadAsync(buffer, offset, count, cancellationToken).SuppressContextFlow(); + } + else + { +#if DEBUG + // When using fast path only one outstanding read is permitted. By switching into opaque mode + // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition) + // caller takes responsibility for enforcing this constraint. + Contract.Assert(Interlocked.Increment(ref m_OutstandingOperations.m_Reads) == 1, + "Only one outstanding read allowed at any given time."); +#endif + m_ReadTaskCompletionSource = new TaskCompletionSource(); + m_ReadEventArgs.SetBuffer(buffer, offset, count); + if (!ReadAsyncFast(m_ReadEventArgs)) + { + if (m_ReadEventArgs.Exception != null) + { + throw m_ReadEventArgs.Exception; + } + + bytesRead = m_ReadEventArgs.BytesTransferred; + } + else + { + bytesRead = await m_ReadTaskCompletionSource.Task.SuppressContextFlow(); + } + } + } + catch (Exception error) + { + if (s_CanHandleException(error)) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + throw; + } + finally + { + cancellationTokenRegistration.Dispose(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.ReadAsyncCore, bytesRead); + } + } + + return bytesRead; + } + + // return value indicates sync vs async completion + // false: sync completion + // true: async completion + private unsafe bool ReadAsyncFast(HttpListenerAsyncEventArgs eventArgs) + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.ReadAsyncFast, string.Empty); + } + + eventArgs.StartOperationCommon(this); + eventArgs.StartOperationReceive(); + + uint statusCode = 0; + bool completedAsynchronously = false; + try + { + Contract.Assert(eventArgs.Buffer != null, "'BufferList' is not supported for read operations."); + if (eventArgs.Count == 0 || m_InputStream.Closed) + { + eventArgs.FinishOperationSuccess(0, true); + return false; + } + + uint dataRead = 0; + int offset = eventArgs.Offset; + int remainingCount = eventArgs.Count; + + if (m_InputStream.BufferedDataChunksAvailable) + { + dataRead = m_InputStream.GetChunks(eventArgs.Buffer, eventArgs.Offset, eventArgs.Count); + if (m_InputStream.BufferedDataChunksAvailable && dataRead == eventArgs.Count) + { + eventArgs.FinishOperationSuccess(eventArgs.Count, true); + return false; + } + } + + Contract.Assert(!m_InputStream.BufferedDataChunksAvailable, "'m_InputStream.BufferedDataChunksAvailable' MUST BE 'FALSE' at this point."); + Contract.Assert(dataRead <= eventArgs.Count, "'dataRead' MUST NOT be bigger than 'eventArgs.Count'."); + + if (dataRead != 0) + { + offset += (int)dataRead; + remainingCount -= (int)dataRead; + //the http.sys team recommends that we limit the size to 128kb + if (remainingCount > HttpRequestStream.MaxReadSize) + { + remainingCount = HttpRequestStream.MaxReadSize; + } + + eventArgs.SetBuffer(eventArgs.Buffer, offset, remainingCount); + } + else if (remainingCount > HttpRequestStream.MaxReadSize) + { + remainingCount = HttpRequestStream.MaxReadSize; + eventArgs.SetBuffer(eventArgs.Buffer, offset, remainingCount); + } + + // m_InputStream.InternalHttpContext.EnsureBoundHandle(); + uint flags = 0; + uint bytesReturned = 0; + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpReceiveRequestEntityBody2( + m_InputStream.InternalHttpContext.RequestQueueHandle, + m_InputStream.InternalHttpContext.RequestId, + flags, + (byte*)m_WebSocket.InternalBuffer.ToIntPtr(eventArgs.Offset), + (uint)eventArgs.Count, + out bytesReturned, + eventArgs.NativeOverlapped); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + throw new HttpListenerException((int)statusCode); + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + HttpListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously. No IO completion port callback is used because + // it was disabled in SwitchToOpaqueMode() + eventArgs.FinishOperationSuccess((int)bytesReturned, true); + completedAsynchronously = false; + } + else + { + completedAsynchronously = true; + } + } + catch (Exception e) + { + m_ReadEventArgs.FinishOperationFailure(e, true); + m_OutputStream.SetClosedFlag(); + m_OutputStream.InternalHttpContext.Abort(); + + throw; + } + finally + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.ReadAsyncFast, completedAsynchronously); + } + } + + return completedAsynchronously; + } + + public override int ReadByte() + { + return m_InputStream.ReadByte(); + } + + public bool SupportsMultipleWrite + { + get + { + return true; + } + } + + public override IAsyncResult BeginRead(byte[] buffer, + int offset, + int count, + AsyncCallback callback, + object state) + { + return m_InputStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return m_InputStream.EndRead(asyncResult); + } + + public Task MultipleWriteAsync(IList> sendBuffers, CancellationToken cancellationToken) + { + Contract.Assert(m_InOpaqueMode, "The stream MUST be in opaque mode at this point."); + Contract.Assert(sendBuffers != null, "'sendBuffers' MUST NOT be NULL."); + Contract.Assert(sendBuffers.Count == 1 || sendBuffers.Count == 2, + "'sendBuffers.Count' MUST be either '1' or '2'."); + + if (sendBuffers.Count == 1) + { + ArraySegment buffer = sendBuffers[0]; + return WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); + } + + return MultipleWriteAsyncCore(sendBuffers, cancellationToken); + } + + private async Task MultipleWriteAsyncCore(IList> sendBuffers, CancellationToken cancellationToken) + { + Contract.Assert(sendBuffers != null, "'sendBuffers' MUST NOT be NULL."); + Contract.Assert(sendBuffers.Count == 2, "'sendBuffers.Count' MUST be '2' at this point."); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.MultipleWriteAsyncCore, string.Empty); + } + + CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration(); + + try + { + if (cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false); + } +#if DEBUG + // When using fast path only one outstanding read is permitted. By switching into opaque mode + // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition) + // caller takes responsibility for enforcing this constraint. + Contract.Assert(Interlocked.Increment(ref m_OutstandingOperations.m_Writes) == 1, + "Only one outstanding write allowed at any given time."); +#endif + m_WriteTaskCompletionSource = new TaskCompletionSource(); + m_WriteEventArgs.SetBuffer(null, 0, 0); + m_WriteEventArgs.BufferList = sendBuffers; + if (WriteAsyncFast(m_WriteEventArgs)) + { + await m_WriteTaskCompletionSource.Task.SuppressContextFlow(); + } + } + catch (Exception error) + { + if (s_CanHandleException(error)) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + throw; + } + finally + { + cancellationTokenRegistration.Dispose(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.MultipleWriteAsyncCore, string.Empty); + } + } + } + + public override void Write(byte[] buffer, int offset, int count) + { + m_OutputStream.Write(buffer, offset, count); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateBuffer(buffer, offset, count); + + return WriteAsyncCore(buffer, offset, count, cancellationToken); + } + + private async Task WriteAsyncCore(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.WriteAsyncCore, + WebSocketHelpers.GetTraceMsgForParameters(offset, count, cancellationToken)); + } + + CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration(); + + try + { + if (cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false); + } + + if (!m_InOpaqueMode) + { + await m_OutputStream.WriteAsync(buffer, offset, count, cancellationToken).SuppressContextFlow(); + } + else + { +#if DEBUG + // When using fast path only one outstanding read is permitted. By switching into opaque mode + // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition) + // caller takes responsibility for enforcing this constraint. + Contract.Assert(Interlocked.Increment(ref m_OutstandingOperations.m_Writes) == 1, + "Only one outstanding write allowed at any given time."); +#endif + m_WriteTaskCompletionSource = new TaskCompletionSource(); + m_WriteEventArgs.BufferList = null; + m_WriteEventArgs.SetBuffer(buffer, offset, count); + if (WriteAsyncFast(m_WriteEventArgs)) + { + await m_WriteTaskCompletionSource.Task.SuppressContextFlow(); + } + } + } + catch (Exception error) + { + if (s_CanHandleException(error)) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + throw; + } + finally + { + cancellationTokenRegistration.Dispose(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.WriteAsyncCore, string.Empty); + } + } + } + + // return value indicates sync vs async completion + // false: sync completion + // true: async completion + private bool WriteAsyncFast(HttpListenerAsyncEventArgs eventArgs) + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.WriteAsyncFast, string.Empty); + } + + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; + + eventArgs.StartOperationCommon(this); + eventArgs.StartOperationSend(); + + uint statusCode; + bool completedAsynchronously = false; + try + { + if (m_OutputStream.Closed || + (eventArgs.Buffer != null && eventArgs.Count == 0)) + { + eventArgs.FinishOperationSuccess(eventArgs.Count, true); + return false; + } + + if (eventArgs.ShouldCloseOutput) + { + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_DISCONNECT; + } + else + { + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA; + // When using HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA HTTP.SYS will copy the payload to + // kernel memory (Non-Paged Pool). Http.Sys will buffer up to + // Math.Min(16 MB, current TCP window size) + flags |= UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA; + } + + m_OutputStream.InternalHttpContext.EnsureBoundHandle(); + uint bytesSent; + statusCode = + UnsafeNclNativeMethods.HttpApi.HttpSendResponseEntityBody2( + m_OutputStream.InternalHttpContext.RequestQueueHandle, + m_OutputStream.InternalHttpContext.RequestId, + (uint)flags, + eventArgs.EntityChunkCount, + eventArgs.EntityChunks, + out bytesSent, + SafeLocalFree.Zero, + 0, + eventArgs.NativeOverlapped, + IntPtr.Zero); + + if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING) + { + throw new HttpListenerException((int)statusCode); + } + else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && + HttpListener.SkipIOCPCallbackOnSuccess) + { + // IO operation completed synchronously - callback won't be called to signal completion. + eventArgs.FinishOperationSuccess((int)bytesSent, true); + completedAsynchronously = false; + } + else + { + completedAsynchronously = true; + } + } + catch (Exception e) + { + m_WriteEventArgs.FinishOperationFailure(e, true); + m_OutputStream.SetClosedFlag(); + m_OutputStream.InternalHttpContext.Abort(); + + throw; + } + finally + { + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.WriteAsyncFast, completedAsynchronously); + } + } + + return completedAsynchronously; + } + + public override void WriteByte(byte value) + { + m_OutputStream.WriteByte(value); + } + + public override IAsyncResult BeginWrite(byte[] buffer, + int offset, + int count, + AsyncCallback callback, + object state) + { + return m_OutputStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + m_OutputStream.EndWrite(asyncResult); + } + + public override void Flush() + { + m_OutputStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return m_OutputStream.FlushAsync(cancellationToken); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(SR.GetString(SR.net_noseek)); + } + + public async Task CloseNetworkConnectionAsync(CancellationToken cancellationToken) + { + // need to yield here to make sure that we don't get any exception synchronously + await Task.Yield(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.CloseNetworkConnectionAsync, string.Empty); + } + + CancellationTokenRegistration cancellationTokenRegistration = new CancellationTokenRegistration(); + + try + { + if (cancellationToken.CanBeCanceled) + { + cancellationTokenRegistration = cancellationToken.Register(s_OnCancel, this, false); + } +#if DEBUG + // When using fast path only one outstanding read is permitted. By switching into opaque mode + // via IWebSocketStream.SwitchToOpaqueMode (see more detailed comments in interface definition) + // caller takes responsibility for enforcing this constraint. + Contract.Assert(Interlocked.Increment(ref m_OutstandingOperations.m_Writes) == 1, + "Only one outstanding write allowed at any given time."); +#endif + m_WriteTaskCompletionSource = new TaskCompletionSource(); + m_WriteEventArgs.SetShouldCloseOutput(); + if (WriteAsyncFast(m_WriteEventArgs)) + { + await m_WriteTaskCompletionSource.Task.SuppressContextFlow(); + } + } + catch (Exception error) + { + if (!s_CanHandleException(error)) + { + throw; + } + + // throw OperationCancelledException when canceled by the caller + // otherwise swallow the exception + cancellationToken.ThrowIfCancellationRequested(); + } + finally + { + cancellationTokenRegistration.Dispose(); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.CloseNetworkConnectionAsync, string.Empty); + } + } + } + + protected override void Dispose(bool disposing) + { + if (disposing && Interlocked.Exchange(ref m_CleanedUp, 1) == 0) + { + if (m_ReadTaskCompletionSource != null) + { + m_ReadTaskCompletionSource.TrySetCanceled(); + } + + if (m_WriteTaskCompletionSource != null) + { + m_WriteTaskCompletionSource.TrySetCanceled(); + } + + if (m_ReadEventArgs != null) + { + m_ReadEventArgs.Dispose(); + } + + if (m_WriteEventArgs != null) + { + m_WriteEventArgs.Dispose(); + } + + try + { + m_InputStream.Close(); + } + finally + { + m_OutputStream.Close(); + } + } + } + + public void Abort() + { + OnCancel(this); + } + + private static bool CanHandleException(Exception error) + { + return error is HttpListenerException || + error is ObjectDisposedException || + error is IOException; + } + + private static void OnCancel(object state) + { + Contract.Assert(state != null, "'state' MUST NOT be NULL."); + WebSocketHttpListenerDuplexStream thisPtr = state as WebSocketHttpListenerDuplexStream; + Contract.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL."); + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, state, Methods.OnCancel, string.Empty); + } + + try + { + thisPtr.m_OutputStream.SetClosedFlag(); + // thisPtr.m_Context.Abort(); + } + catch { } + + TaskCompletionSource readTaskCompletionSourceSnapshot = thisPtr.m_ReadTaskCompletionSource; + + if (readTaskCompletionSourceSnapshot != null) + { + readTaskCompletionSourceSnapshot.TrySetCanceled(); + } + + TaskCompletionSource writeTaskCompletionSourceSnapshot = thisPtr.m_WriteTaskCompletionSource; + + if (writeTaskCompletionSourceSnapshot != null) + { + writeTaskCompletionSourceSnapshot.TrySetCanceled(); + } + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, state, Methods.OnCancel, string.Empty); + } + } + + public void SwitchToOpaqueMode(WebSocketBase webSocket) + { + Contract.Assert(webSocket != null, "'webSocket' MUST NOT be NULL."); + Contract.Assert(m_OutputStream != null, "'m_OutputStream' MUST NOT be NULL."); + Contract.Assert(m_OutputStream.InternalHttpContext != null, + "'m_OutputStream.InternalHttpContext' MUST NOT be NULL."); + Contract.Assert(m_OutputStream.InternalHttpContext.Response != null, + "'m_OutputStream.InternalHttpContext.Response' MUST NOT be NULL."); + Contract.Assert(m_OutputStream.InternalHttpContext.Response.SentHeaders, + "Headers MUST have been sent at this point."); + Contract.Assert(!m_InOpaqueMode, "SwitchToOpaqueMode MUST NOT be called multiple times."); + + if (m_InOpaqueMode) + { + throw new InvalidOperationException(); + } + + m_WebSocket = webSocket; + m_InOpaqueMode = true; + m_ReadEventArgs = new HttpListenerAsyncEventArgs(webSocket, this); + m_ReadEventArgs.Completed += s_OnReadCompleted; + m_WriteEventArgs = new HttpListenerAsyncEventArgs(webSocket, this); + m_WriteEventArgs.Completed += s_OnWriteCompleted; + + if (WebSocketBase.LoggingEnabled) + { + Logging.Associate(Logging.WebSockets, this, webSocket); + } + } + + private static void OnWriteCompleted(object sender, HttpListenerAsyncEventArgs eventArgs) + { + Contract.Assert(eventArgs != null, "'eventArgs' MUST NOT be NULL."); + WebSocketHttpListenerDuplexStream thisPtr = eventArgs.CurrentStream; + Contract.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL."); +#if DEBUG + Contract.Assert(Interlocked.Decrement(ref thisPtr.m_OutstandingOperations.m_Writes) >= 0, + "'thisPtr.m_OutstandingOperations.m_Writes' MUST NOT be negative."); +#endif + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, thisPtr, Methods.OnWriteCompleted, string.Empty); + } + + if (eventArgs.Exception != null) + { + thisPtr.m_WriteTaskCompletionSource.TrySetException(eventArgs.Exception); + } + else + { + thisPtr.m_WriteTaskCompletionSource.TrySetResult(null); + } + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, thisPtr, Methods.OnWriteCompleted, string.Empty); + } + } + + private static void OnReadCompleted(object sender, HttpListenerAsyncEventArgs eventArgs) + { + Contract.Assert(eventArgs != null, "'eventArgs' MUST NOT be NULL."); + WebSocketHttpListenerDuplexStream thisPtr = eventArgs.CurrentStream; + Contract.Assert(thisPtr != null, "'thisPtr' MUST NOT be NULL."); +#if DEBUG + Contract.Assert(Interlocked.Decrement(ref thisPtr.m_OutstandingOperations.m_Reads) >= 0, + "'thisPtr.m_OutstandingOperations.m_Reads' MUST NOT be negative."); +#endif + + if (WebSocketBase.LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, thisPtr, Methods.OnReadCompleted, string.Empty); + } + + if (eventArgs.Exception != null) + { + thisPtr.m_ReadTaskCompletionSource.TrySetException(eventArgs.Exception); + } + else + { + thisPtr.m_ReadTaskCompletionSource.TrySetResult(eventArgs.BytesTransferred); + } + + if (WebSocketBase.LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, thisPtr, Methods.OnReadCompleted, string.Empty); + } + } + + internal class HttpListenerAsyncEventArgs : EventArgs, IDisposable + { + private const int Free = 0; + private const int InProgress = 1; + private const int Disposed = 2; + private int m_Operating; + + private bool m_DisposeCalled; + private SafeNativeOverlapped m_PtrNativeOverlapped; + private Overlapped m_Overlapped; + private event EventHandler m_Completed; + private byte[] m_Buffer; + private IList> m_BufferList; + private int m_Count; + private int m_Offset; + private int m_BytesTransferred; + private HttpListenerAsyncOperation m_CompletedOperation; + private UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[] m_DataChunks; + private GCHandle m_DataChunksGCHandle; + private ushort m_DataChunkCount; + private Exception m_Exception; + private bool m_ShouldCloseOutput; + private readonly WebSocketBase m_WebSocket; + private readonly WebSocketHttpListenerDuplexStream m_CurrentStream; + + public HttpListenerAsyncEventArgs(WebSocketBase webSocket, WebSocketHttpListenerDuplexStream stream) + : base() + { + m_WebSocket = webSocket; + m_CurrentStream = stream; + InitializeOverlapped(); + } + + public int BytesTransferred + { + get { return m_BytesTransferred; } + } + + public byte[] Buffer + { + get { return m_Buffer; } + } + + // BufferList property. + // Mutually exclusive with Buffer. + // Setting this property with an existing non-null Buffer will cause an assert. + public IList> BufferList + { + get { return m_BufferList; } + set + { + Contract.Assert(!m_ShouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'false' at this point."); + Contract.Assert(value == null || m_Buffer == null, + "Either 'm_Buffer' or 'm_BufferList' MUST be NULL."); + Contract.Assert(m_Operating == Free, + "This property can only be modified if no IO operation is outstanding."); + Contract.Assert(value == null || value.Count == 2, + "This list can only be 'NULL' or MUST have exactly '2' items."); + m_BufferList = value; + } + } + + public bool ShouldCloseOutput + { + get { return m_ShouldCloseOutput; } + } + + public int Offset + { + get { return m_Offset; } + } + + public int Count + { + get { return m_Count; } + } + + public Exception Exception + { + get { return m_Exception; } + } + + public ushort EntityChunkCount + { + get + { + if (m_DataChunks == null) + { + return 0; + } + + return m_DataChunkCount; + } + } + + public SafeNativeOverlapped NativeOverlapped + { + get { return m_PtrNativeOverlapped; } + } + + public IntPtr EntityChunks + { + get + { + if (m_DataChunks == null) + { + return IntPtr.Zero; + } + + return Marshal.UnsafeAddrOfPinnedArrayElement(m_DataChunks, 0); + } + } + + public WebSocketHttpListenerDuplexStream CurrentStream + { + get { return m_CurrentStream; } + } + + public event EventHandler Completed + { + add + { + m_Completed += value; + } + remove + { + m_Completed -= value; + } + } + + protected virtual void OnCompleted(HttpListenerAsyncEventArgs e) + { + EventHandler handler = m_Completed; + if (handler != null) + { + handler(e.m_CurrentStream, e); + } + } + + public void SetShouldCloseOutput() + { + m_BufferList = null; + m_Buffer = null; + m_ShouldCloseOutput = true; + } + + public void Dispose() + { + // Remember that Dispose was called. + m_DisposeCalled = true; + + // Check if this object is in-use for an async socket operation. + if (Interlocked.CompareExchange(ref m_Operating, Disposed, Free) != Free) + { + // Either already disposed or will be disposed when current operation completes. + return; + } + + // OK to dispose now. + // Free native overlapped data. + FreeOverlapped(false); + + // Don't bother finalizing later. + GC.SuppressFinalize(this); + } + + // Finalizer + ~HttpListenerAsyncEventArgs() + { + FreeOverlapped(true); + } + + private unsafe void InitializeOverlapped() + { + m_Overlapped = new Overlapped(); + m_PtrNativeOverlapped = new SafeNativeOverlapped(m_Overlapped.UnsafePack(CompletionPortCallback, null)); + } + + // Method to clean up any existing Overlapped object and related state variables. + private void FreeOverlapped(bool checkForShutdown) + { + if (!checkForShutdown || !NclUtilities.HasShutdownStarted) + { + // Free the overlapped object + if (m_PtrNativeOverlapped != null && !m_PtrNativeOverlapped.IsInvalid) + { + m_PtrNativeOverlapped.Dispose(); + } + + if (m_DataChunksGCHandle.IsAllocated) + { + m_DataChunksGCHandle.Free(); + } + } + } + + // Method called to prepare for a native async http.sys call. + // This method performs the tasks common to all http.sys operations. + internal void StartOperationCommon(WebSocketHttpListenerDuplexStream currentStream) + { + // Change status to "in-use". + if(Interlocked.CompareExchange(ref m_Operating, InProgress, Free) != Free) + { + // If it was already "in-use" check if Dispose was called. + if (m_DisposeCalled) + { + // Dispose was called - throw ObjectDisposed. + throw new ObjectDisposedException(GetType().FullName); + } + + Contract.Assert(false, "Only one outstanding async operation is allowed per HttpListenerAsyncEventArgs instance."); + // Only one at a time. + throw new InvalidOperationException(); + } + + // HttpSendResponseEntityBody can return ERROR_INVALID_PARAMETER if the InternalHigh field of the overlapped + // is not IntPtr.Zero, so we have to reset this field because we are reusing the Overlapped. + // When using the IAsyncResult based approach of HttpListenerResponseStream the Overlapped is reinitialized + // for each operation by the CLR when returned from the OverlappedDataCache. + NativeOverlapped.ReinitializeNativeOverlapped(); + m_Exception = null; + m_BytesTransferred = 0; + } + + internal void StartOperationReceive() + { + // Remember the operation type. + m_CompletedOperation = HttpListenerAsyncOperation.Receive; + } + + internal void StartOperationSend() + { + UpdateDataChunk(); + + // Remember the operation type. + m_CompletedOperation = HttpListenerAsyncOperation.Send; + } + + public void SetBuffer(byte[] buffer, int offset, int count) + { + Contract.Assert(!m_ShouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'false' at this point."); + Contract.Assert(buffer == null || m_BufferList == null, "Either 'm_Buffer' or 'm_BufferList' MUST be NULL."); + m_Buffer = buffer; + m_Offset = offset; + m_Count = count; + } + + private unsafe void UpdateDataChunk() + { + if (m_DataChunks == null) + { + m_DataChunks = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK[2]; + m_DataChunksGCHandle = GCHandle.Alloc(m_DataChunks); + m_DataChunks[0] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + m_DataChunks[0].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + m_DataChunks[1] = new UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK(); + m_DataChunks[1].DataChunkType = UnsafeNclNativeMethods.HttpApi.HTTP_DATA_CHUNK_TYPE.HttpDataChunkFromMemory; + } + + Contract.Assert(m_Buffer == null || m_BufferList == null, "Either 'm_Buffer' or 'm_BufferList' MUST be NULL."); + Contract.Assert(m_ShouldCloseOutput || m_Buffer != null || m_BufferList != null, "Either 'm_Buffer' or 'm_BufferList' MUST NOT be NULL."); + + // The underlying byte[] m_Buffer or each m_BufferList[].Array are pinned already + if (m_Buffer != null) + { + UpdateDataChunk(0, m_Buffer, m_Offset, m_Count); + UpdateDataChunk(1, null, 0, 0); + m_DataChunkCount = 1; + } + else if (m_BufferList != null) + { + Contract.Assert(m_BufferList != null && m_BufferList.Count == 2, + "'m_BufferList' MUST NOT be NULL and have exactly '2' items at this point."); + UpdateDataChunk(0, m_BufferList[0].Array, m_BufferList[0].Offset, m_BufferList[0].Count); + UpdateDataChunk(1, m_BufferList[1].Array, m_BufferList[1].Offset, m_BufferList[1].Count); + m_DataChunkCount = 2; + } + else + { + Contract.Assert(m_ShouldCloseOutput, "'m_ShouldCloseOutput' MUST be 'true' at this point."); + m_DataChunks = null; + } + } + + private unsafe void UpdateDataChunk(int index, byte[] buffer, int offset, int count) + { + if (buffer == null) + { + m_DataChunks[index].pBuffer = null; + m_DataChunks[index].BufferLength = 0; + return; + } + + if (m_WebSocket.InternalBuffer.IsInternalBuffer(buffer, offset, count)) + { + m_DataChunks[index].pBuffer = (byte*)(m_WebSocket.InternalBuffer.ToIntPtr(offset)); + } + else + { + m_DataChunks[index].pBuffer = + (byte*)m_WebSocket.InternalBuffer.ConvertPinnedSendPayloadToNative(buffer, offset, count); + } + + m_DataChunks[index].BufferLength = (uint)count; + } + + // Method to mark this object as no longer "in-use". + // Will also execute a Dispose deferred because I/O was in progress. + internal void Complete() + { + // Mark as not in-use + m_Operating = Free; + + // Check for deferred Dispose(). + // The deferred Dispose is not guaranteed if Dispose is called while an operation is in progress. + // The m_DisposeCalled variable is not managed in a thread-safe manner on purpose for performance. + if (m_DisposeCalled) + { + Dispose(); + } + } + + // Method to update internal state after sync or async completion. + private void SetResults(Exception exception, int bytesTransferred) + { + m_Exception = exception; + m_BytesTransferred = bytesTransferred; + } + + internal void FinishOperationFailure(Exception exception, bool syncCompletion) + { + SetResults(exception, 0); + + if (WebSocketBase.LoggingEnabled) + { + Logging.PrintError(Logging.WebSockets, m_CurrentStream, + m_CompletedOperation == HttpListenerAsyncOperation.Receive ? Methods.ReadAsyncFast : Methods.WriteAsyncFast, + exception.ToString()); + } + + Complete(); + OnCompleted(this); + } + + internal void FinishOperationSuccess(int bytesTransferred, bool syncCompletion) + { + SetResults(null, bytesTransferred); + + if (WebSocketBase.LoggingEnabled) + { + if (m_Buffer != null) + { + Logging.Dump(Logging.WebSockets, m_CurrentStream, + m_CompletedOperation == HttpListenerAsyncOperation.Receive ? Methods.ReadAsyncFast : Methods.WriteAsyncFast, + m_Buffer, m_Offset, bytesTransferred); + } + else if (m_BufferList != null) + { + Contract.Assert(m_CompletedOperation == HttpListenerAsyncOperation.Send, + "'BufferList' is only supported for send operations."); + + foreach (ArraySegment buffer in BufferList) + { + Logging.Dump(Logging.WebSockets, this, Methods.WriteAsyncFast, buffer.Array, buffer.Offset, buffer.Count); + } + } + else + { + Logging.PrintLine(Logging.WebSockets, TraceEventType.Verbose, 0, + string.Format(CultureInfo.InvariantCulture, "Output channel closed for {0}#{1}", + m_CurrentStream.GetType().Name, ValidationHelper.HashString(m_CurrentStream))); + } + } + + if (m_ShouldCloseOutput) + { + m_CurrentStream.m_OutputStream.SetClosedFlag(); + } + + // Complete the operation and raise completion event. + Complete(); + OnCompleted(this); + } + + private unsafe void CompletionPortCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped) + { + if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS || + errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_HANDLE_EOF) + { + FinishOperationSuccess((int)numBytes, false); + } + else + { + FinishOperationFailure(new HttpListenerException((int)errorCode), false); + } + } + + public enum HttpListenerAsyncOperation + { + None, + Receive, + Send + } + } + + private static class Methods + { + public const string CloseNetworkConnectionAsync = "CloseNetworkConnectionAsync"; + public const string OnCancel = "OnCancel"; + public const string OnReadCompleted = "OnReadCompleted"; + public const string OnWriteCompleted = "OnWriteCompleted"; + public const string ReadAsyncFast = "ReadAsyncFast"; + public const string ReadAsyncCore = "ReadAsyncCore"; + public const string WriteAsyncFast = "WriteAsyncFast"; + public const string WriteAsyncCore = "WriteAsyncCore"; + public const string MultipleWriteAsyncCore = "MultipleWriteAsyncCore"; + } + } +} +*/ \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeLoadLibrary.cs b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeLoadLibrary.cs new file mode 100644 index 0000000000..43276eefe5 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeLoadLibrary.cs @@ -0,0 +1,40 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.WebSockets +{ + internal sealed class SafeLoadLibrary : SafeHandleZeroOrMinusOneIsInvalid + { + private const string KERNEL32 = "kernel32.dll"; + + public static readonly SafeLoadLibrary Zero = new SafeLoadLibrary(false); + + private SafeLoadLibrary() : base(true) + { + } + + private SafeLoadLibrary(bool ownsHandle) : base(ownsHandle) + { + } + + public static unsafe SafeLoadLibrary LoadLibraryEx(string library) + { + SafeLoadLibrary result = UnsafeNativeMethods.SafeNetHandles.LoadLibraryExW(library, null, 0); + if (result.IsInvalid) + { + result.SetHandleAsInvalid(); + } + return result; + } + + protected override bool ReleaseHandle() + { + return UnsafeNativeMethods.SafeNetHandles.FreeLibrary(handle); + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeNativeOverlapped.cs b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeNativeOverlapped.cs new file mode 100644 index 0000000000..8682c89918 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeNativeOverlapped.cs @@ -0,0 +1,80 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNet.WebSockets +{ + internal class SafeNativeOverlapped : SafeHandle + { + internal static readonly SafeNativeOverlapped Zero = new SafeNativeOverlapped(); + + internal SafeNativeOverlapped() + : this(IntPtr.Zero) + { + } + + internal unsafe SafeNativeOverlapped(NativeOverlapped* handle) + : this((IntPtr)handle) + { + } + + internal SafeNativeOverlapped(IntPtr handle) + : base(IntPtr.Zero, true) + { + SetHandle(handle); + } + + public override bool IsInvalid + { + get { return handle == IntPtr.Zero; } + } + + public void ReinitializeNativeOverlapped() + { + IntPtr handleSnapshot = handle; + + if (handleSnapshot != IntPtr.Zero) + { + unsafe + { + ((NativeOverlapped*)handleSnapshot)->InternalHigh = IntPtr.Zero; + ((NativeOverlapped*)handleSnapshot)->InternalLow = IntPtr.Zero; + ((NativeOverlapped*)handleSnapshot)->EventHandle = IntPtr.Zero; + } + } + } + + protected override bool ReleaseHandle() + { + IntPtr oldHandle = Interlocked.Exchange(ref handle, IntPtr.Zero); + // Do not call free durring AppDomain shutdown, there may be an outstanding operation. + // Overlapped will take care calling free when the native callback completes. + if (oldHandle != IntPtr.Zero && !HasShutdownStarted) + { + unsafe + { + Overlapped.Free((NativeOverlapped*)oldHandle); + } + } + return true; + } + + internal static bool HasShutdownStarted + { + get + { + return Environment.HasShutdownStarted +#if NET45 + || AppDomain.CurrentDomain.IsFinalizingForUnload() +#endif + ; + } + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeWebSocketHandle.cs b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeWebSocketHandle.cs new file mode 100644 index 0000000000..60c8341561 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/NativeInterop/SafeWebSocketHandle.cs @@ -0,0 +1,32 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using Microsoft.AspNet.WebSockets; +using Microsoft.Win32.SafeHandles; + +namespace Microsoft.AspNet.WebSockets +{ + // This class is a wrapper for a WSPC (WebSocket protocol component) session. WebSocketCreateClientHandle and WebSocketCreateServerHandle return a PVOID and not a real handle + // but we use a SafeHandle because it provides us the guarantee that WebSocketDeleteHandle will always get called. + internal sealed class SafeWebSocketHandle : SafeHandleZeroOrMinusOneIsInvalid + { + internal SafeWebSocketHandle() + : base(true) + { + } + + protected override bool ReleaseHandle() + { + if (this.IsInvalid) + { + return true; + } + + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketDeleteHandle(this.handle); + return true; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/NativeInterop/UnsafeNativeMethods.cs b/src/Microsoft.AspNet.WebSockets/NativeInterop/UnsafeNativeMethods.cs new file mode 100644 index 0000000000..86355cdee4 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/NativeInterop/UnsafeNativeMethods.cs @@ -0,0 +1,842 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.Contracts; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNet.WebSockets +{ + internal static class UnsafeNativeMethods + { + private const string KERNEL32 = "kernel32.dll"; + private const string WEBSOCKET = "websocket.dll"; + + internal static class SafeNetHandles + { + [DllImport(KERNEL32, ExactSpelling = true, CharSet=CharSet.Unicode, SetLastError = true)] + internal static extern unsafe SafeLoadLibrary LoadLibraryExW([In] string lpwLibFileName, [In] void* hFile, [In] uint dwFlags); + + [DllImport(KERNEL32, ExactSpelling = true, SetLastError = true)] + internal static extern unsafe bool FreeLibrary([In] IntPtr hModule); + } + + internal static class WebSocketProtocolComponent + { + private static readonly string DllFileName; + private static readonly string DummyWebsocketKeyBase64 = Convert.ToBase64String(new byte[16]); + private static readonly SafeLoadLibrary WebSocketDllHandle; + private static readonly string PrivateSupportedVersion; + + private static readonly HttpHeader[] InitialClientRequestHeaders = new HttpHeader[] + { + new HttpHeader() + { + Name = HttpKnownHeaderNames.Connection, + NameLength = (uint)HttpKnownHeaderNames.Connection.Length, + Value = HttpKnownHeaderNames.Upgrade, + ValueLength = (uint)HttpKnownHeaderNames.Upgrade.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.Upgrade, + NameLength = (uint)HttpKnownHeaderNames.Upgrade.Length, + Value = WebSocketHelpers.WebSocketUpgradeToken, + ValueLength = (uint)WebSocketHelpers.WebSocketUpgradeToken.Length + } + }; + + private static readonly HttpHeader[] ServerFakeRequestHeaders; + + internal static class Errors + { + internal const int E_INVALID_OPERATION = unchecked((int)0x80000050); + internal const int E_INVALID_PROTOCOL_OPERATION = unchecked((int)0x80000051); + internal const int E_INVALID_PROTOCOL_FORMAT = unchecked((int)0x80000052); + internal const int E_NUMERIC_OVERFLOW = unchecked((int)0x80000053); + internal const int E_FAIL = unchecked((int)0x80004005); + } + + internal enum Action + { + NoAction = 0, + SendToNetwork = 1, + IndicateSendComplete = 2, + ReceiveFromNetwork = 3, + IndicateReceiveComplete = 4, + } + + internal enum BufferType : uint + { + None = 0x00000000, + UTF8Message = 0x80000000, + UTF8Fragment = 0x80000001, + BinaryMessage = 0x80000002, + BinaryFragment = 0x80000003, + Close = 0x80000004, + PingPong = 0x80000005, + UnsolicitedPong = 0x80000006 + } + + internal enum PropertyType + { + ReceiveBufferSize = 0, + SendBufferSize = 1, + DisableMasking = 2, + AllocatedBuffer = 3, + DisableUtf8Verification = 4, + KeepAliveInterval = 5, + } + + internal enum ActionQueue + { + Send = 1, + Receive = 2, + } + + [StructLayout(LayoutKind.Sequential)] + internal struct Property + { + internal PropertyType Type; + internal IntPtr PropertyData; + internal uint PropertySize; + } + + [StructLayout(LayoutKind.Explicit)] + internal struct Buffer + { + [FieldOffset(0)] + internal DataBuffer Data; + [FieldOffset(0)] + internal CloseBuffer CloseStatus; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct DataBuffer + { + internal IntPtr BufferData; + internal uint BufferLength; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct CloseBuffer + { + internal IntPtr ReasonData; + internal uint ReasonLength; + internal ushort CloseStatus; + } + + [StructLayout(LayoutKind.Sequential)] + internal struct HttpHeader + { + [MarshalAs(UnmanagedType.LPStr)] + internal string Name; + internal uint NameLength; + [MarshalAs(UnmanagedType.LPStr)] + internal string Value; + internal uint ValueLength; + } + + static WebSocketProtocolComponent() + { + DllFileName = Path.Combine(Environment.SystemDirectory, WEBSOCKET); + WebSocketDllHandle = SafeLoadLibrary.LoadLibraryEx(DllFileName); + + if (!WebSocketDllHandle.IsInvalid) + { + PrivateSupportedVersion = GetSupportedVersion(); + + ServerFakeRequestHeaders = new HttpHeader[] + { + new HttpHeader() + { + Name = HttpKnownHeaderNames.Connection, + NameLength = (uint)HttpKnownHeaderNames.Connection.Length, + Value = HttpKnownHeaderNames.Upgrade, + ValueLength = (uint)HttpKnownHeaderNames.Upgrade.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.Upgrade, + NameLength = (uint)HttpKnownHeaderNames.Upgrade.Length, + Value = WebSocketHelpers.WebSocketUpgradeToken, + ValueLength = (uint)WebSocketHelpers.WebSocketUpgradeToken.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.Host, + NameLength = (uint)HttpKnownHeaderNames.Host.Length, + Value = string.Empty, + ValueLength = 0 + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.SecWebSocketVersion, + NameLength = (uint)HttpKnownHeaderNames.SecWebSocketVersion.Length, + Value = SupportedVersion, + ValueLength = (uint)SupportedVersion.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.SecWebSocketKey, + NameLength = (uint)HttpKnownHeaderNames.SecWebSocketKey.Length, + Value = DummyWebsocketKeyBase64, + ValueLength = (uint)DummyWebsocketKeyBase64.Length + } + }; + } + } + + internal static string SupportedVersion + { + get + { + if (WebSocketDllHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + return PrivateSupportedVersion; + } + } + + internal static bool IsSupported + { + get + { + return !WebSocketDllHandle.IsInvalid; + } + } + + internal static string GetSupportedVersion() + { + if (WebSocketDllHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + SafeWebSocketHandle webSocketHandle = null; + try + { + int errorCode = WebSocketCreateClientHandle_Raw(null, 0, out webSocketHandle); + ThrowOnError(errorCode); + + if (webSocketHandle == null || + webSocketHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + IntPtr additionalHeadersPtr; + uint additionalHeaderCount; + + errorCode = WebSocketBeginClientHandshake_Raw(webSocketHandle, + IntPtr.Zero, + 0, + IntPtr.Zero, + 0, + InitialClientRequestHeaders, + (uint)InitialClientRequestHeaders.Length, + out additionalHeadersPtr, + out additionalHeaderCount); + ThrowOnError(errorCode); + + HttpHeader[] additionalHeaders = MarshalHttpHeaders(additionalHeadersPtr, (int)additionalHeaderCount); + + string version = null; + foreach (HttpHeader header in additionalHeaders) + { + if (string.Compare(header.Name, + HttpKnownHeaderNames.SecWebSocketVersion, + StringComparison.OrdinalIgnoreCase) == 0) + { + version = header.Value; + break; + } + } + Contract.Assert(version != null, "'version' MUST NOT be NULL."); + + return version; + } + finally + { + if (webSocketHandle != null) + { + webSocketHandle.Dispose(); + } + } + } + + internal static void WebSocketCreateClientHandle(Property[] properties, + out SafeWebSocketHandle webSocketHandle) + { + uint propertyCount = properties == null ? 0 : (uint)properties.Length; + + if (WebSocketDllHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + int errorCode = WebSocketCreateClientHandle_Raw(properties, propertyCount, out webSocketHandle); + ThrowOnError(errorCode); + + if (webSocketHandle == null || + webSocketHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + IntPtr additionalHeadersPtr; + uint additionalHeaderCount; + + // Currently the WSPC doesn't allow to initiate a data session + // without also being involved in the http handshake + // There is no information whatsoever, which is needed by the + // WSPC for parsing WebSocket frames from the HTTP handshake + // In the managed implementation the HTTP header handling + // will be done using the managed HTTP stack and we will + // just fake an HTTP handshake for the WSPC calling + // WebSocketBeginClientHandshake and WebSocketEndClientHandshake + // with statically defined dummy headers. + errorCode = WebSocketBeginClientHandshake_Raw(webSocketHandle, + IntPtr.Zero, + 0, + IntPtr.Zero, + 0, + InitialClientRequestHeaders, + (uint)InitialClientRequestHeaders.Length, + out additionalHeadersPtr, + out additionalHeaderCount); + + ThrowOnError(errorCode); + + HttpHeader[] additionalHeaders = MarshalHttpHeaders(additionalHeadersPtr, (int)additionalHeaderCount); + + string key = null; + foreach (HttpHeader header in additionalHeaders) + { + if (string.Compare(header.Name, + HttpKnownHeaderNames.SecWebSocketKey, + StringComparison.OrdinalIgnoreCase) == 0) + { + key = header.Value; + break; + } + } + Contract.Assert(key != null, "'key' MUST NOT be NULL."); + + string acceptValue = WebSocketHelpers.GetSecWebSocketAcceptString(key); + HttpHeader[] responseHeaders = new HttpHeader[] + { + new HttpHeader() + { + Name = HttpKnownHeaderNames.Connection, + NameLength = (uint)HttpKnownHeaderNames.Connection.Length, + Value = HttpKnownHeaderNames.Upgrade, + ValueLength = (uint)HttpKnownHeaderNames.Upgrade.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.Upgrade, + NameLength = (uint)HttpKnownHeaderNames.Upgrade.Length, + Value = WebSocketHelpers.WebSocketUpgradeToken, + ValueLength = (uint)WebSocketHelpers.WebSocketUpgradeToken.Length + }, + new HttpHeader() + { + Name = HttpKnownHeaderNames.SecWebSocketAccept, + NameLength = (uint)HttpKnownHeaderNames.SecWebSocketAccept.Length, + Value = acceptValue, + ValueLength = (uint)acceptValue.Length + } + }; + + errorCode = WebSocketEndClientHandshake_Raw(webSocketHandle, + responseHeaders, + (uint)responseHeaders.Length, + IntPtr.Zero, + IntPtr.Zero, + IntPtr.Zero); + + ThrowOnError(errorCode); + + Contract.Assert(webSocketHandle != null, "'webSocketHandle' MUST NOT be NULL at this point."); + } + + internal static void WebSocketCreateServerHandle(Property[] properties, + int propertyCount, + out SafeWebSocketHandle webSocketHandle) + { + Contract.Assert(propertyCount >= 0, "'propertyCount' MUST NOT be negative."); + Contract.Assert((properties == null && propertyCount == 0) || + (properties != null && propertyCount == properties.Length), + "'propertyCount' MUST MATCH 'properties.Length'."); + + if (WebSocketDllHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + int errorCode = WebSocketCreateServerHandle_Raw(properties, (uint)propertyCount, out webSocketHandle); + ThrowOnError(errorCode); + + if (webSocketHandle == null || + webSocketHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + IntPtr responseHeadersPtr; + uint responseHeaderCount; + + // Currently the WSPC doesn't allow to initiate a data session + // without also being involved in the http handshake + // There is no information whatsoever, which is needed by the + // WSPC for parsing WebSocket frames from the HTTP handshake + // In the managed implementation the HTTP header handling + // will be done using the managed HTTP stack and we will + // just fake an HTTP handshake for the WSPC calling + // WebSocketBeginServerHandshake and WebSocketEndServerHandshake + // with statically defined dummy headers. + errorCode = WebSocketBeginServerHandshake_Raw(webSocketHandle, + IntPtr.Zero, + IntPtr.Zero, + 0, + ServerFakeRequestHeaders, + (uint)ServerFakeRequestHeaders.Length, + out responseHeadersPtr, + out responseHeaderCount); + + ThrowOnError(errorCode); + + HttpHeader[] responseHeaders = MarshalHttpHeaders(responseHeadersPtr, (int)responseHeaderCount); + errorCode = WebSocketEndServerHandshake_Raw(webSocketHandle); + + ThrowOnError(errorCode); + + Contract.Assert(webSocketHandle != null, "'webSocketHandle' MUST NOT be NULL at this point."); + } + + internal static void WebSocketAbortHandle(SafeHandle webSocketHandle) + { + Contract.Assert(webSocketHandle != null && !webSocketHandle.IsInvalid, + "'webSocketHandle' MUST NOT be NULL or INVALID."); + + WebSocketAbortHandle_Raw(webSocketHandle); + + DrainActionQueue(webSocketHandle, ActionQueue.Send); + DrainActionQueue(webSocketHandle, ActionQueue.Receive); + } + + internal static void WebSocketDeleteHandle(IntPtr webSocketPtr) + { + Contract.Assert(webSocketPtr != IntPtr.Zero, "'webSocketPtr' MUST NOT be IntPtr.Zero."); + WebSocketDeleteHandle_Raw(webSocketPtr); + } + + internal static void WebSocketSend(WebSocketBase webSocket, + BufferType bufferType, + Buffer buffer) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + + ThrowIfSessionHandleClosed(webSocket); + + int errorCode; + try + { + errorCode = WebSocketSend_Raw(webSocket.SessionHandle, bufferType, ref buffer, IntPtr.Zero); + } + catch (ObjectDisposedException innerException) + { + throw ConvertObjectDisposedException(webSocket, innerException); + } + + ThrowOnError(errorCode); + } + + internal static void WebSocketSendWithoutBody(WebSocketBase webSocket, + BufferType bufferType) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + + ThrowIfSessionHandleClosed(webSocket); + + int errorCode; + try + { + errorCode = WebSocketSendWithoutBody_Raw(webSocket.SessionHandle, bufferType, IntPtr.Zero, IntPtr.Zero); + } + catch (ObjectDisposedException innerException) + { + throw ConvertObjectDisposedException(webSocket, innerException); + } + + ThrowOnError(errorCode); + } + + internal static void WebSocketReceive(WebSocketBase webSocket) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + + ThrowIfSessionHandleClosed(webSocket); + + int errorCode; + try + { + errorCode = WebSocketReceive_Raw(webSocket.SessionHandle, IntPtr.Zero, IntPtr.Zero); + } + catch (ObjectDisposedException innerException) + { + throw ConvertObjectDisposedException(webSocket, innerException); + } + + ThrowOnError(errorCode); + } + + internal static void WebSocketGetAction(WebSocketBase webSocket, + ActionQueue actionQueue, + Buffer[] dataBuffers, + ref uint dataBufferCount, + out Action action, + out BufferType bufferType, + out IntPtr actionContext) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + Contract.Assert(dataBufferCount >= 0, "'dataBufferCount' MUST NOT be negative."); + Contract.Assert((dataBuffers == null && dataBufferCount == 0) || + (dataBuffers != null && dataBufferCount == dataBuffers.Length), + "'dataBufferCount' MUST MATCH 'dataBuffers.Length'."); + + action = Action.NoAction; + bufferType = BufferType.None; + actionContext = IntPtr.Zero; + + IntPtr dummy; + ThrowIfSessionHandleClosed(webSocket); + + int errorCode; + try + { + errorCode = WebSocketGetAction_Raw(webSocket.SessionHandle, + actionQueue, + dataBuffers, + ref dataBufferCount, + out action, + out bufferType, + out dummy, + out actionContext); + } + catch (ObjectDisposedException innerException) + { + throw ConvertObjectDisposedException(webSocket, innerException); + } + ThrowOnError(errorCode); + + webSocket.ValidateNativeBuffers(action, bufferType, dataBuffers, dataBufferCount); + + Contract.Assert(dataBufferCount >= 0); + Contract.Assert((dataBufferCount == 0 && dataBuffers == null) || + (dataBufferCount <= dataBuffers.Length)); + } + + internal static void WebSocketCompleteAction(WebSocketBase webSocket, + IntPtr actionContext, + int bytesTransferred) + { + Contract.Assert(webSocket != null, + "'webSocket' MUST NOT be NULL or INVALID."); + Contract.Assert(webSocket.SessionHandle != null && !webSocket.SessionHandle.IsInvalid, + "'webSocket.SessionHandle' MUST NOT be NULL or INVALID."); + Contract.Assert(actionContext != IntPtr.Zero, "'actionContext' MUST NOT be IntPtr.Zero."); + Contract.Assert(bytesTransferred >= 0, "'bytesTransferred' MUST NOT be negative."); + + if (webSocket.SessionHandle.IsClosed) + { + return; + } + + try + { + WebSocketCompleteAction_Raw(webSocket.SessionHandle, actionContext, (uint)bytesTransferred); + } + catch (ObjectDisposedException) + { + } + } + + internal static TimeSpan WebSocketGetDefaultKeepAliveInterval() + { + uint result = 0; + uint size = sizeof(uint); + int errorCode = WebSocketGetGlobalProperty_Raw(PropertyType.KeepAliveInterval, ref result, ref size); + if (!Succeeded(errorCode)) + { + Contract.Assert(errorCode == 0, "errorCode: " + errorCode); + return Timeout.InfiniteTimeSpan; + } + return TimeSpan.FromMilliseconds(result); + } + + private static void DrainActionQueue(SafeHandle webSocketHandle, ActionQueue actionQueue) + { + Contract.Assert(webSocketHandle != null && !webSocketHandle.IsInvalid, + "'webSocketHandle' MUST NOT be NULL or INVALID."); + + IntPtr actionContext; + IntPtr dummy; + Action action; + BufferType bufferType; + + while (true) + { + Buffer[] dataBuffers = new Buffer[1]; + uint dataBufferCount = 1; + int errorCode = WebSocketGetAction_Raw(webSocketHandle, + actionQueue, + dataBuffers, + ref dataBufferCount, + out action, + out bufferType, + out dummy, + out actionContext); + + if (!Succeeded(errorCode)) + { + Contract.Assert(errorCode == 0, "'errorCode' MUST be 0."); + return; + } + + if (action == Action.NoAction) + { + return; + } + + WebSocketCompleteAction_Raw(webSocketHandle, actionContext, 0); + } + } + + private static void MarshalAndVerifyHttpHeader(IntPtr httpHeaderPtr, + ref HttpHeader httpHeader) + { + Contract.Assert(httpHeaderPtr != IntPtr.Zero, "'currentHttpHeaderPtr' MUST NOT be IntPtr.Zero."); + + IntPtr httpHeaderNamePtr = Marshal.ReadIntPtr(httpHeaderPtr); + IntPtr lengthPtr = IntPtr.Add(httpHeaderPtr, IntPtr.Size); + int length = Marshal.ReadInt32(lengthPtr); + Contract.Assert(length >= 0, "'length' MUST NOT be negative."); + + if (httpHeaderNamePtr != IntPtr.Zero) + { + httpHeader.Name = Marshal.PtrToStringAnsi(httpHeaderNamePtr, length); + } + + if ((httpHeader.Name == null && length != 0) || + (httpHeader.Name != null && length != httpHeader.Name.Length)) + { + Contract.Assert(false, "The length of 'httpHeader.Name' MUST MATCH 'length'."); + throw new AccessViolationException(); + } + + // structure of HttpHeader: + // Name = string* + // NameLength = uint* + // Value = string* + // ValueLength = uint* + // NOTE - All fields in the object are pointers to the actual value, hence the use of + // n * IntPtr.Size to get to the correct place in the object. + int valueOffset = 2 * IntPtr.Size; + int lengthOffset = 3 * IntPtr.Size; + + IntPtr httpHeaderValuePtr = + Marshal.ReadIntPtr(IntPtr.Add(httpHeaderPtr, valueOffset)); + lengthPtr = IntPtr.Add(httpHeaderPtr, lengthOffset); + length = Marshal.ReadInt32(lengthPtr); + httpHeader.Value = Marshal.PtrToStringAnsi(httpHeaderValuePtr, (int)length); + + if ((httpHeader.Value == null && length != 0) || + (httpHeader.Value != null && length != httpHeader.Value.Length)) + { + Contract.Assert(false, "The length of 'httpHeader.Value' MUST MATCH 'length'."); + throw new AccessViolationException(); + } + } + + private static HttpHeader[] MarshalHttpHeaders(IntPtr nativeHeadersPtr, + int nativeHeaderCount) + { + Contract.Assert(nativeHeaderCount >= 0, "'nativeHeaderCount' MUST NOT be negative."); + Contract.Assert(nativeHeadersPtr != IntPtr.Zero || nativeHeaderCount == 0, + "'nativeHeaderCount' MUST be 0."); + + HttpHeader[] httpHeaders = new HttpHeader[nativeHeaderCount]; + + // structure of HttpHeader: + // Name = string* + // NameLength = uint* + // Value = string* + // ValueLength = uint* + // NOTE - All fields in the object are pointers to the actual value, hence the use of + // 4 * IntPtr.Size to get to the next header. + int httpHeaderStructSize = 4 * IntPtr.Size; + + for (int i = 0; i < nativeHeaderCount; i++) + { + int offset = httpHeaderStructSize * i; + IntPtr currentHttpHeaderPtr = IntPtr.Add(nativeHeadersPtr, offset); + MarshalAndVerifyHttpHeader(currentHttpHeaderPtr, ref httpHeaders[i]); + } + + Contract.Assert(httpHeaders != null); + Contract.Assert(httpHeaders.Length == nativeHeaderCount); + + return httpHeaders; + } + + public static bool Succeeded(int hr) + { + return (hr >= 0); + } + + private static void ThrowOnError(int errorCode) + { + if (Succeeded(errorCode)) + { + return; + } + + throw new WebSocketException(errorCode); + } + + private static void ThrowIfSessionHandleClosed(WebSocketBase webSocket) + { + if (webSocket.SessionHandle.IsClosed) + { + throw new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, webSocket.GetType().FullName, webSocket.State)); + } + } + + private static WebSocketException ConvertObjectDisposedException(WebSocketBase webSocket, ObjectDisposedException innerException) + { + return new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, webSocket.GetType().FullName, webSocket.State), + innerException); + } + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketCreateClientHandle", ExactSpelling = true)] + private static extern int WebSocketCreateClientHandle_Raw( + [In]Property[] properties, + [In] uint propertyCount, + [Out] out SafeWebSocketHandle webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketBeginClientHandshake", ExactSpelling = true)] + private static extern int WebSocketBeginClientHandshake_Raw( + [In] SafeHandle webSocketHandle, + [In] IntPtr subProtocols, + [In] uint subProtocolCount, + [In] IntPtr extensions, + [In] uint extensionCount, + [In] HttpHeader[] initialHeaders, + [In] uint initialHeaderCount, + [Out] out IntPtr additionalHeadersPtr, + [Out] out uint additionalHeaderCount); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketEndClientHandshake", ExactSpelling = true)] + private static extern int WebSocketEndClientHandshake_Raw([In] SafeHandle webSocketHandle, + [In] HttpHeader[] responseHeaders, + [In] uint responseHeaderCount, + [In, Out] IntPtr selectedExtensions, + [In] IntPtr selectedExtensionCount, + [In] IntPtr selectedSubProtocol); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketBeginServerHandshake", ExactSpelling = true)] + private static extern int WebSocketBeginServerHandshake_Raw( + [In] SafeHandle webSocketHandle, + [In] IntPtr subProtocol, + [In] IntPtr extensions, + [In] uint extensionCount, + [In] HttpHeader[] requestHeaders, + [In] uint requestHeaderCount, + [Out] out IntPtr responseHeadersPtr, + [Out] out uint responseHeaderCount); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketEndServerHandshake", ExactSpelling = true)] + private static extern int WebSocketEndServerHandshake_Raw([In] SafeHandle webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketCreateServerHandle", ExactSpelling = true)] + private static extern int WebSocketCreateServerHandle_Raw( + [In]Property[] properties, + [In] uint propertyCount, + [Out] out SafeWebSocketHandle webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketAbortHandle", ExactSpelling = true)] + private static extern void WebSocketAbortHandle_Raw( + [In] SafeHandle webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketDeleteHandle", ExactSpelling = true)] + private static extern void WebSocketDeleteHandle_Raw( + [In] IntPtr webSocketHandle); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketSend", ExactSpelling = true)] + private static extern int WebSocketSend_Raw( + [In] SafeHandle webSocketHandle, + [In] BufferType bufferType, + [In] ref Buffer buffer, + [In] IntPtr applicationContext); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketSend", ExactSpelling = true)] + private static extern int WebSocketSendWithoutBody_Raw( + [In] SafeHandle webSocketHandle, + [In] BufferType bufferType, + [In] IntPtr buffer, + [In] IntPtr applicationContext); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketReceive", ExactSpelling = true)] + private static extern int WebSocketReceive_Raw( + [In] SafeHandle webSocketHandle, + [In] IntPtr buffers, + [In] IntPtr applicationContext); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketGetAction", ExactSpelling = true)] + private static extern int WebSocketGetAction_Raw( + [In] SafeHandle webSocketHandle, + [In] ActionQueue actionQueue, + [In, Out] Buffer[] dataBuffers, + [In, Out] ref uint dataBufferCount, + [Out] out Action action, + [Out] out BufferType bufferType, + [Out] out IntPtr applicationContext, + [Out] out IntPtr actionContext); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketCompleteAction", ExactSpelling = true)] + private static extern void WebSocketCompleteAction_Raw( + [In] SafeHandle webSocketHandle, + [In] IntPtr actionContext, + [In] uint bytesTransferred); + + [DllImport(WEBSOCKET, EntryPoint = "WebSocketGetGlobalProperty", ExactSpelling = true)] + private static extern int WebSocketGetGlobalProperty_Raw( + [In] PropertyType property, + [In, Out] ref uint value, + [In, Out] ref uint size); + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/OwinWebSocketWrapper.cs b/src/Microsoft.AspNet.WebSockets/OwinWebSocketWrapper.cs new file mode 100644 index 0000000000..a11c34488d --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/OwinWebSocketWrapper.cs @@ -0,0 +1,137 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebSockets +{ + using WebSocketCloseAsync = + Func; + using WebSocketReceiveAsync = + Func /* data */, + CancellationToken /* cancel */, + Task>>; + using WebSocketReceiveTuple = + Tuple; + using WebSocketSendAsync = + Func /* data */, + int /* messageType */, + bool /* endOfMessage */, + CancellationToken /* cancel */, + Task>; + + internal class OwinWebSocketWrapper + { + private readonly WebSocket _webSocket; + private readonly IDictionary _environment; + private readonly CancellationToken _cancellationToken; + + internal OwinWebSocketWrapper(WebSocket webSocket, CancellationToken ct) + { + _webSocket = webSocket; + _cancellationToken = ct; + + _environment = new Dictionary(); + _environment[Constants.WebSocketSendAsyncKey] = new WebSocketSendAsync(SendAsync); + _environment[Constants.WebSocketReceiveAyncKey] = new WebSocketReceiveAsync(ReceiveAsync); + _environment[Constants.WebSocketCloseAsyncKey] = new WebSocketCloseAsync(CloseAsync); + _environment[Constants.WebSocketCallCancelledKey] = ct; + _environment[Constants.WebSocketVersionKey] = Constants.WebSocketVersion; + } + + internal IDictionary Environment + { + get { return _environment; } + } + + internal Task SendAsync(ArraySegment buffer, int messageType, bool endOfMessage, CancellationToken cancel) + { + // Remap close messages to CloseAsync. System.Net.WebSockets.WebSocket.SendAsync does not allow close messages. + if (messageType == 0x8) + { + return RedirectSendToCloseAsync(buffer, cancel); + } + else if (messageType == 0x9 || messageType == 0xA) + { + // Ping & Pong, not allowed by the underlying APIs, silently discard. + return Task.FromResult(0); + } + + return _webSocket.SendAsync(buffer, (WebSocketMessageType)messageType, endOfMessage, cancel); + } + + internal async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancel) + { + WebSocketReceiveResult nativeResult = await _webSocket.ReceiveAsync(buffer, cancel); + + if (nativeResult.MessageType == WebSocketMessageType.Close) + { + _environment[Constants.WebSocketCloseStatusKey] = (int)(nativeResult.CloseStatus ?? WebSocketCloseStatus.NormalClosure); + _environment[Constants.WebSocketCloseDescriptionKey] = nativeResult.CloseStatusDescription ?? string.Empty; + } + + return new WebSocketReceiveTuple( + (int)nativeResult.MessageType, + nativeResult.EndOfMessage, + nativeResult.Count); + } + + internal Task CloseAsync(int status, string description, CancellationToken cancel) + { + return _webSocket.CloseOutputAsync((WebSocketCloseStatus)status, description, cancel); + } + + private Task RedirectSendToCloseAsync(ArraySegment buffer, CancellationToken cancel) + { + if (buffer.Array == null || buffer.Count == 0) + { + return CloseAsync(1000, string.Empty, cancel); + } + else if (buffer.Count >= 2) + { + // Unpack the close message. + int statusCode = + (buffer.Array[buffer.Offset] << 8) + | buffer.Array[buffer.Offset + 1]; + string description = Encoding.UTF8.GetString(buffer.Array, buffer.Offset + 2, buffer.Count - 2); + + return CloseAsync(statusCode, description, cancel); + } + else + { + throw new ArgumentOutOfRangeException("buffer"); + } + } + + internal async Task CleanupAsync() + { + switch (_webSocket.State) + { + case WebSocketState.Closed: // Closed gracefully, no action needed. + case WebSocketState.Aborted: // Closed abortively, no action needed. + break; + case WebSocketState.CloseReceived: + // Echo what the client said, if anything. + await _webSocket.CloseAsync(_webSocket.CloseStatus ?? WebSocketCloseStatus.NormalClosure, + _webSocket.CloseStatusDescription ?? string.Empty, _cancellationToken); + break; + case WebSocketState.Open: + case WebSocketState.CloseSent: // No close received, abort so we don't have to drain the pipe. + _webSocket.Abort(); + break; + default: + throw new ArgumentOutOfRangeException("state", _webSocket.State, string.Empty); + } + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/Properties/AssemblyInfo.cs b/src/Microsoft.AspNet.WebSockets/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..1abade9a33 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/Properties/AssemblyInfo.cs @@ -0,0 +1,36 @@ +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.WebSockets")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.WebSockets")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("1f471909-581f-4060-a147-430891e9c3c1")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/src/Microsoft.AspNet.WebSockets/ServerWebSocket.cs b/src/Microsoft.AspNet.WebSockets/ServerWebSocket.cs new file mode 100644 index 0000000000..f23d13af54 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/ServerWebSocket.cs @@ -0,0 +1,62 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.IO; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.WebSockets +{ + internal sealed class ServerWebSocket : WebSocketBase + { + private readonly SafeHandle _sessionHandle; + private readonly UnsafeNativeMethods.WebSocketProtocolComponent.Property[] _properties; + + public ServerWebSocket(Stream innerStream, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + : base(innerStream, subProtocol, keepAliveInterval, + WebSocketBuffer.CreateServerBuffer(internalBuffer, receiveBufferSize)) + { + _properties = this.InternalBuffer.CreateProperties(false); + _sessionHandle = this.CreateWebSocketHandle(); + + if (_sessionHandle == null || _sessionHandle.IsInvalid) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + StartKeepAliveTimer(); + } + + internal override SafeHandle SessionHandle + { + get + { + Contract.Assert(_sessionHandle != null, "'m_SessionHandle MUST NOT be NULL."); + return _sessionHandle; + } + } + + [SuppressMessage("Microsoft.Security", "CA2122:DoNotIndirectlyExposeMethodsWithLinkDemands", + Justification = "No arbitrary data controlled by PT code is leaking into native code.")] + private SafeHandle CreateWebSocketHandle() + { + Contract.Assert(_properties != null, "'m_Properties' MUST NOT be NULL."); + SafeWebSocketHandle sessionHandle; + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCreateServerHandle(_properties, + _properties.Length, + out sessionHandle); + Contract.Assert(sessionHandle != null, "'sessionHandle MUST NOT be NULL."); + + return sessionHandle; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocket.cs b/src/Microsoft.AspNet.WebSockets/WebSocket.cs new file mode 100644 index 0000000000..f75e50a3c6 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocket.cs @@ -0,0 +1,124 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebSockets +{ + public abstract class WebSocket : IDisposable + { + private static TimeSpan? defaultKeepAliveInterval; + + public abstract WebSocketCloseStatus? CloseStatus { get; } + public abstract string CloseStatusDescription { get; } + public abstract string SubProtocol { get; } + public abstract WebSocketState State { get; } + + public static TimeSpan DefaultKeepAliveInterval + { + [SuppressMessage("Microsoft.Security", "CA2122:DoNotIndirectlyExposeMethodsWithLinkDemands", + Justification = "This is a harmless read-only operation")] + get + { + if (defaultKeepAliveInterval == null) + { + if (UnsafeNativeMethods.WebSocketProtocolComponent.IsSupported) + { + defaultKeepAliveInterval = UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketGetDefaultKeepAliveInterval(); + } + else + { + defaultKeepAliveInterval = Timeout.InfiniteTimeSpan; + } + } + return defaultKeepAliveInterval.Value; + } + } + + public static ArraySegment CreateClientBuffer(int receiveBufferSize, int sendBufferSize) + { + WebSocketHelpers.ValidateBufferSizes(receiveBufferSize, sendBufferSize); + + return WebSocketBuffer.CreateInternalBufferArraySegment(receiveBufferSize, sendBufferSize, false); + } + + public static ArraySegment CreateServerBuffer(int receiveBufferSize) + { + WebSocketHelpers.ValidateBufferSizes(receiveBufferSize, WebSocketBuffer.MinSendBufferSize); + + return WebSocketBuffer.CreateInternalBufferArraySegment(receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + } + + internal static WebSocket CreateServerWebSocket(Stream innerStream, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + { + if (!UnsafeNativeMethods.WebSocketProtocolComponent.IsSupported) + { + WebSocketHelpers.ThrowPlatformNotSupportedException_WSPC(); + } + + WebSocketHelpers.ValidateInnerStream(innerStream); + WebSocketHelpers.ValidateOptions(subProtocol, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, keepAliveInterval); + WebSocketHelpers.ValidateArraySegment(internalBuffer, "internalBuffer"); + WebSocketBuffer.Validate(internalBuffer.Count, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + + return new ServerWebSocket(innerStream, + subProtocol, + receiveBufferSize, + keepAliveInterval, + internalBuffer); + } + + public abstract void Abort(); + public abstract Task CloseAsync(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken); + public abstract Task CloseOutputAsync(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken); + [SuppressMessage("Microsoft.Design", "CA1063:ImplementIDisposableCorrectly", Justification = "This rule is outdated")] + public abstract void Dispose(); + public abstract Task ReceiveAsync(ArraySegment buffer, + CancellationToken cancellationToken); + public abstract Task SendAsync(ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken); + + protected static void ThrowOnInvalidState(WebSocketState state, params WebSocketState[] validStates) + { + string validStatesText = string.Empty; + + if (validStates != null && validStates.Length > 0) + { + foreach (WebSocketState currentState in validStates) + { + if (state == currentState) + { + return; + } + } + + validStatesText = string.Join(", ", validStates); + } + + throw new WebSocketException(SR.GetString(SR.net_WebSockets_InvalidState, state, validStatesText)); + } + + protected static bool IsStateTerminal(WebSocketState state) + { + return state == WebSocketState.Closed || + state == WebSocketState.Aborted; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketBase.cs b/src/Microsoft.AspNet.WebSockets/WebSocketBase.cs new file mode 100644 index 0000000000..3e3b798396 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketBase.cs @@ -0,0 +1,2481 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebSockets +{ + internal abstract class WebSocketBase : WebSocket, IDisposable + { + // private static volatile bool s_LoggingEnabled; + + private readonly OutstandingOperationHelper _closeOutstandingOperationHelper; + private readonly OutstandingOperationHelper _closeOutputOutstandingOperationHelper; + private readonly OutstandingOperationHelper _receiveOutstandingOperationHelper; + private readonly OutstandingOperationHelper _sendOutstandingOperationHelper; + private readonly Stream _innerStream; + private readonly IWebSocketStream _innerStreamAsWebSocketStream; + private readonly string _subProtocol; + + // We are not calling Dispose method on this object in Cleanup method to avoid a race condition while one thread is calling disposing on + // this object and another one is still using WaitAsync. According to Dev11 358715, this should be fine as long as we are not accessing the + // AvailableWaitHandle on this SemaphoreSlim object. + private readonly SemaphoreSlim _sendFrameThrottle; + // locking m_ThisLock protects access to + // - State + // - m_CloseStack + // - m_CloseAsyncStartedReceive + // - m_CloseReceivedTaskCompletionSource + // - m_CloseNetworkConnectionTask + private readonly object _thisLock; + private readonly WebSocketBuffer _internalBuffer; + private readonly KeepAliveTracker _keepAliveTracker; + +#if DEBUG + private volatile string _closeStack; +#endif + + private volatile bool _cleanedUp; + private volatile TaskCompletionSource _closeReceivedTaskCompletionSource; + private volatile Task _closeOutputTask; + private volatile bool _isDisposed; + private volatile Task _closeNetworkConnectionTask; + private volatile bool _closeAsyncStartedReceive; + private volatile WebSocketState _state; + private volatile Task _keepAliveTask; + private volatile WebSocketOperation.ReceiveOperation _receiveOperation; + private volatile WebSocketOperation.SendOperation _sendOperation; + private volatile WebSocketOperation.SendOperation _keepAliveOperation; + private volatile WebSocketOperation.CloseOutputOperation _closeOutputOperation; + private WebSocketCloseStatus? _closeStatus; + private string _closeStatusDescription; + private int _receiveState; + private Exception _pendingException; + + protected WebSocketBase(Stream innerStream, + string subProtocol, + TimeSpan keepAliveInterval, + WebSocketBuffer internalBuffer) + { + Contract.Assert(internalBuffer != null, "'internalBuffer' MUST NOT be NULL."); + WebSocketHelpers.ValidateInnerStream(innerStream); + WebSocketHelpers.ValidateOptions(subProtocol, internalBuffer.ReceiveBufferSize, + internalBuffer.SendBufferSize, keepAliveInterval); + + // s_LoggingEnabled = Logging.On && Logging.WebSockets.Switch.ShouldTrace(TraceEventType.Critical); + string parameters = string.Empty; + /* + if (s_LoggingEnabled) + { + parameters = string.Format(CultureInfo.InvariantCulture, + "ReceiveBufferSize: {0}, SendBufferSize: {1}, Protocols: {2}, KeepAliveInterval: {3}, innerStream: {4}, internalBuffer: {5}", + internalBuffer.ReceiveBufferSize, + internalBuffer.SendBufferSize, + subProtocol, + keepAliveInterval, + Logging.GetObjectLogHash(innerStream), + Logging.GetObjectLogHash(internalBuffer)); + + Logging.Enter(Logging.WebSockets, this, Methods.Initialize, parameters); + } + */ + _thisLock = new object(); + + try + { + _innerStream = innerStream; + _internalBuffer = internalBuffer; + /*if (s_LoggingEnabled) + { + Logging.Associate(Logging.WebSockets, this, m_InnerStream); + Logging.Associate(Logging.WebSockets, this, m_InternalBuffer); + }*/ + + _closeOutstandingOperationHelper = new OutstandingOperationHelper(); + _closeOutputOutstandingOperationHelper = new OutstandingOperationHelper(); + _receiveOutstandingOperationHelper = new OutstandingOperationHelper(); + _sendOutstandingOperationHelper = new OutstandingOperationHelper(); + _state = WebSocketState.Open; + _subProtocol = subProtocol; + _sendFrameThrottle = new SemaphoreSlim(1, 1); + _closeStatus = null; + _closeStatusDescription = null; + _innerStreamAsWebSocketStream = innerStream as IWebSocketStream; + if (_innerStreamAsWebSocketStream != null) + { + _innerStreamAsWebSocketStream.SwitchToOpaqueMode(this); + } + _keepAliveTracker = KeepAliveTracker.Create(keepAliveInterval); + } + finally + { + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.Initialize, parameters); + }*/ + } + } + /* + internal static bool LoggingEnabled + { + get + { + return s_LoggingEnabled; + } + } + */ + public override WebSocketState State + { + get + { + Contract.Assert(_state != WebSocketState.None, "'m_State' MUST NOT be 'WebSocketState.None'."); + return _state; + } + } + + public override string SubProtocol + { + get + { + return _subProtocol; + } + } + + public override WebSocketCloseStatus? CloseStatus + { + get + { + return _closeStatus; + } + } + + public override string CloseStatusDescription + { + get + { + return _closeStatusDescription; + } + } + + internal WebSocketBuffer InternalBuffer + { + get + { + Contract.Assert(_internalBuffer != null, "'m_InternalBuffer' MUST NOT be NULL."); + return _internalBuffer; + } + } + + protected void StartKeepAliveTimer() + { + _keepAliveTracker.StartTimer(this); + } + + // locking SessionHandle protects access to + // - WSPC (WebSocketProtocolComponent) + // - m_KeepAliveTask + // - m_CloseOutputTask + // - m_LastSendActivity + internal abstract SafeHandle SessionHandle { get; } + + // MultiThreading: ThreadSafe; At most one outstanding call to ReceiveAsync is allowed + public override Task ReceiveAsync(ArraySegment buffer, + CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateArraySegment(buffer, "buffer"); + return ReceiveAsyncCore(buffer, cancellationToken); + } + + private async Task ReceiveAsyncCore(ArraySegment buffer, + CancellationToken cancellationToken) + { + Contract.Assert(buffer.Array != null); + /* + if (s_LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.ReceiveAsync, string.Empty); + } + */ + WebSocketReceiveResult receiveResult; + try + { + ThrowIfPendingException(); + ThrowIfDisposed(); + ThrowOnInvalidState(State, WebSocketState.Open, WebSocketState.CloseSent); + + bool ownsCancellationTokenSource = false; + CancellationToken linkedCancellationToken = CancellationToken.None; + try + { + ownsCancellationTokenSource = _receiveOutstandingOperationHelper.TryStartOperation(cancellationToken, + out linkedCancellationToken); + if (!ownsCancellationTokenSource) + { + lock (_thisLock) + { + if (_closeAsyncStartedReceive) + { + throw new InvalidOperationException( + SR.GetString(SR.net_WebSockets_ReceiveAsyncDisallowedAfterCloseAsync, Methods.CloseAsync, Methods.CloseOutputAsync)); + } + + throw new InvalidOperationException( + SR.GetString(SR.net_Websockets_AlreadyOneOutstandingOperation, Methods.ReceiveAsync)); + } + } + + EnsureReceiveOperation(); + receiveResult = await _receiveOperation.Process(buffer, linkedCancellationToken).SuppressContextFlow(); + /* + if (s_LoggingEnabled && receiveResult.Count > 0) + { + Logging.Dump(Logging.WebSockets, + this, + Methods.ReceiveAsync, + buffer.Array, + buffer.Offset, + receiveResult.Count); + }*/ + } + catch (Exception exception) + { + bool aborted = linkedCancellationToken.IsCancellationRequested; + Abort(); + ThrowIfConvertibleException(Methods.ReceiveAsync, exception, cancellationToken, aborted); + throw; + } + finally + { + _receiveOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + } + } + finally + {/* + if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.ReceiveAsync, string.Empty); + }*/ + } + + return receiveResult; + } + + // MultiThreading: ThreadSafe; At most one outstanding call to SendAsync is allowed + public override Task SendAsync(ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken) + { + if (messageType != WebSocketMessageType.Binary && + messageType != WebSocketMessageType.Text) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_Argument_InvalidMessageType, + messageType, + Methods.SendAsync, + WebSocketMessageType.Binary, + WebSocketMessageType.Text, + Methods.CloseOutputAsync), + "messageType"); + } + + WebSocketHelpers.ValidateArraySegment(buffer, "buffer"); + + return SendAsyncCore(buffer, messageType, endOfMessage, cancellationToken); + } + + private async Task SendAsyncCore(ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken) + { + Contract.Assert(messageType == WebSocketMessageType.Binary || messageType == WebSocketMessageType.Text, + "'messageType' MUST be either 'WebSocketMessageType.Binary' or 'WebSocketMessageType.Text'."); + Contract.Assert(buffer.Array != null); + + string inputParameter = string.Empty; + /*if (s_LoggingEnabled) + { + inputParameter = string.Format(CultureInfo.InvariantCulture, + "messageType: {0}, endOfMessage: {1}", + messageType, + endOfMessage); + Logging.Enter(Logging.WebSockets, this, Methods.SendAsync, inputParameter); + }*/ + + try + { + ThrowIfPendingException(); + ThrowIfDisposed(); + ThrowOnInvalidState(State, WebSocketState.Open, WebSocketState.CloseReceived); + bool ownsCancellationTokenSource = false; + CancellationToken linkedCancellationToken = CancellationToken.None; + + try + { + while (!(ownsCancellationTokenSource = _sendOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken))) + { + Task keepAliveTask; + + lock (SessionHandle) + { + keepAliveTask = _keepAliveTask; + + if (keepAliveTask == null) + { + // Check whether there is still another outstanding send operation + // Potentially the keepAlive operation has completed before this thread + // was able to enter the SessionHandle-lock. + _sendOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + if (ownsCancellationTokenSource = _sendOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken)) + { + break; + } + else + { + throw new InvalidOperationException( + SR.GetString(SR.net_Websockets_AlreadyOneOutstandingOperation, Methods.SendAsync)); + } + } + } + + await keepAliveTask.SuppressContextFlow(); + ThrowIfPendingException(); + + _sendOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + } + /* + if (s_LoggingEnabled && buffer.Count > 0) + { + Logging.Dump(Logging.WebSockets, + this, + Methods.SendAsync, + buffer.Array, + buffer.Offset, + buffer.Count); + }*/ + + int position = buffer.Offset; + + EnsureSendOperation(); + _sendOperation.BufferType = GetBufferType(messageType, endOfMessage); + await _sendOperation.Process(buffer, linkedCancellationToken).SuppressContextFlow(); + } + catch (Exception exception) + { + bool aborted = linkedCancellationToken.IsCancellationRequested; + Abort(); + ThrowIfConvertibleException(Methods.SendAsync, exception, cancellationToken, aborted); + throw; + } + finally + { + _sendOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + } + } + finally + { + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.SendAsync, inputParameter); + }*/ + } + } + + private async Task SendFrameAsync(IList> sendBuffers, CancellationToken cancellationToken) + { + bool sendFrameLockTaken = false; + try + { + await _sendFrameThrottle.WaitAsync(cancellationToken).SuppressContextFlow(); + sendFrameLockTaken = true; + + if (sendBuffers.Count > 1 && + _innerStreamAsWebSocketStream != null && + _innerStreamAsWebSocketStream.SupportsMultipleWrite) + { + await _innerStreamAsWebSocketStream.MultipleWriteAsync(sendBuffers, + cancellationToken).SuppressContextFlow(); + } + else + { + foreach (ArraySegment buffer in sendBuffers) + { + await _innerStream.WriteAsync(buffer.Array, + buffer.Offset, + buffer.Count, + cancellationToken).SuppressContextFlow(); + } + } + } + catch (ObjectDisposedException objectDisposedException) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, objectDisposedException); + } + catch (NotSupportedException notSupportedException) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, notSupportedException); + } + finally + { + if (sendFrameLockTaken) + { + _sendFrameThrottle.Release(); + } + } + } + + // MultiThreading: ThreadSafe; No-op if already in a terminal state + public override void Abort() + { + /*if (s_LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, this, Methods.Abort, string.Empty); + }*/ + + bool thisLockTaken = false; + bool sessionHandleLockTaken = false; + try + { + if (IsStateTerminal(State)) + { + return; + } + + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + if (IsStateTerminal(State)) + { + return; + } + + _state = WebSocketState.Aborted; + +#if DEBUG + string stackTrace = new StackTrace().ToString(); + if (_closeStack == null) + { + _closeStack = stackTrace; + } + /* + if (s_LoggingEnabled) + { + string message = string.Format(CultureInfo.InvariantCulture, "Stack: {0}", stackTrace); + Logging.PrintWarning(Logging.WebSockets, this, Methods.Abort, message); + }*/ +#endif + + // Abort any outstanding IO operations. + if (SessionHandle != null && !SessionHandle.IsClosed && !SessionHandle.IsInvalid) + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketAbortHandle(SessionHandle); + } + + _receiveOutstandingOperationHelper.CancelIO(); + _sendOutstandingOperationHelper.CancelIO(); + _closeOutputOutstandingOperationHelper.CancelIO(); + _closeOutstandingOperationHelper.CancelIO(); + if (_innerStreamAsWebSocketStream != null) + { + _innerStreamAsWebSocketStream.Abort(); + } + CleanUp(); + } + finally + { + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.Abort, string.Empty); + }*/ + } + } + + // MultiThreading: ThreadSafe; No-op if already in a terminal state + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateCloseStatus(closeStatus, statusDescription); + + return CloseOutputAsyncCore(closeStatus, statusDescription, cancellationToken); + } + + private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + string inputParameter = string.Empty; + /*if (s_LoggingEnabled) + { + inputParameter = string.Format(CultureInfo.InvariantCulture, + "closeStatus: {0}, statusDescription: {1}", + closeStatus, + statusDescription); + Logging.Enter(Logging.WebSockets, this, Methods.CloseOutputAsync, inputParameter); + }*/ + + try + { + ThrowIfPendingException(); + if (IsStateTerminal(State)) + { + return; + } + ThrowIfDisposed(); + + bool thisLockTaken = false; + bool sessionHandleLockTaken = false; + bool needToCompleteSendOperation = false; + bool ownsCloseOutputCancellationTokenSource = false; + bool ownsSendCancellationTokenSource = false; + CancellationToken linkedCancellationToken = CancellationToken.None; + try + { + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + ThrowIfPendingException(); + ThrowIfDisposed(); + + if (IsStateTerminal(State)) + { + return; + } + + ThrowOnInvalidState(State, WebSocketState.Open, WebSocketState.CloseReceived); + ownsCloseOutputCancellationTokenSource = _closeOutputOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken); + if (!ownsCloseOutputCancellationTokenSource) + { + Task closeOutputTask = _closeOutputTask; + + if (closeOutputTask != null) + { + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + await closeOutputTask.SuppressContextFlow(); + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + } + else + { + needToCompleteSendOperation = true; + while (!(ownsSendCancellationTokenSource = + _sendOutstandingOperationHelper.TryStartOperation(cancellationToken, + out linkedCancellationToken))) + { + if (_keepAliveTask != null) + { + Task keepAliveTask = _keepAliveTask; + + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + await keepAliveTask.SuppressContextFlow(); + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + + ThrowIfPendingException(); + } + else + { + throw new InvalidOperationException( + SR.GetString(SR.net_Websockets_AlreadyOneOutstandingOperation, Methods.SendAsync)); + } + + _sendOutstandingOperationHelper.CompleteOperation(ownsSendCancellationTokenSource); + } + + EnsureCloseOutputOperation(); + _closeOutputOperation.CloseStatus = closeStatus; + _closeOutputOperation.CloseReason = statusDescription; + _closeOutputTask = _closeOutputOperation.Process(null, linkedCancellationToken); + + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + await _closeOutputTask.SuppressContextFlow(); + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + + if (OnCloseOutputCompleted()) + { + bool callCompleteOnCloseCompleted = false; + + try + { + callCompleteOnCloseCompleted = await StartOnCloseCompleted( + thisLockTaken, sessionHandleLockTaken, linkedCancellationToken).SuppressContextFlow(); + } + catch (Exception) + { + // If an exception is thrown we know that the locks have been released, + // because we enforce IWebSocketStream.CloseNetworkConnectionAsync to yield + ResetFlagsAndTakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + throw; + } + + if (callCompleteOnCloseCompleted) + { + ResetFlagsAndTakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + FinishOnCloseCompleted(); + } + } + } + } + catch (Exception exception) + { + bool aborted = linkedCancellationToken.IsCancellationRequested; + Abort(); + ThrowIfConvertibleException(Methods.CloseOutputAsync, exception, cancellationToken, aborted); + throw; + } + finally + { + _closeOutputOutstandingOperationHelper.CompleteOperation(ownsCloseOutputCancellationTokenSource); + + if (needToCompleteSendOperation) + { + _sendOutstandingOperationHelper.CompleteOperation(ownsSendCancellationTokenSource); + } + + _closeOutputTask = null; + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + } + finally + { + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.CloseOutputAsync, inputParameter); + }*/ + } + } + + // returns TRUE if the caller should also call StartOnCloseCompleted + private bool OnCloseOutputCompleted() + { + if (IsStateTerminal(State)) + { + return false; + } + + switch (State) + { + case WebSocketState.Open: + _state = WebSocketState.CloseSent; + return false; + case WebSocketState.CloseReceived: + return true; + default: + return false; + } + } + + // MultiThreading: This method has to be called under a m_ThisLock-lock + // ReturnValue: This method returns true only if CompleteOnCloseCompleted needs to be called + // If this method returns true all locks were released before starting the IO operation + // and they have to be retaken by the caller before calling CompleteOnCloseCompleted + // Exception handling: If an exception is thrown from await StartOnCloseCompleted + // it always means the locks have been released already - so the caller has to retake the + // locks in the catch-block. + // This is ensured by enforcing a Task.Yield for IWebSocketStream.CloseNetowrkConnectionAsync + private async Task StartOnCloseCompleted(bool thisLockTakenSnapshot, + bool sessionHandleLockTakenSnapshot, + CancellationToken cancellationToken) + { + Contract.Assert(thisLockTakenSnapshot, "'thisLockTakenSnapshot' MUST be 'true' at this point."); + + if (IsStateTerminal(_state)) + { + return false; + } + + _state = WebSocketState.Closed; + +#if DEBUG + if (_closeStack == null) + { + _closeStack = new StackTrace().ToString(); + } +#endif + + if (_innerStreamAsWebSocketStream != null) + { + bool thisLockTaken = thisLockTakenSnapshot; + bool sessionHandleLockTaken = sessionHandleLockTakenSnapshot; + + try + { + if (_closeNetworkConnectionTask == null) + { + _closeNetworkConnectionTask = + _innerStreamAsWebSocketStream.CloseNetworkConnectionAsync(cancellationToken); + } + + if (thisLockTaken && sessionHandleLockTaken) + { + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + else if (thisLockTaken) + { + ReleaseLock(_thisLock, ref thisLockTaken); + } + + await _closeNetworkConnectionTask.SuppressContextFlow(); + } + catch (Exception closeNetworkConnectionTaskException) + { + if (!CanHandleExceptionDuringClose(closeNetworkConnectionTaskException)) + { + ThrowIfConvertibleException(Methods.StartOnCloseCompleted, + closeNetworkConnectionTaskException, + cancellationToken, + cancellationToken.IsCancellationRequested); + throw; + } + } + } + + return true; + } + + // MultiThreading: This method has to be called under a thisLock-lock + private void FinishOnCloseCompleted() + { + CleanUp(); + } + + // MultiThreading: ThreadSafe; No-op if already in a terminal state + public override Task CloseAsync(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + WebSocketHelpers.ValidateCloseStatus(closeStatus, statusDescription); + return CloseAsyncCore(closeStatus, statusDescription, cancellationToken); + } + + private async Task CloseAsyncCore(WebSocketCloseStatus closeStatus, + string statusDescription, + CancellationToken cancellationToken) + { + string inputParameter = string.Empty; + /*if (s_LoggingEnabled) + { + inputParameter = string.Format(CultureInfo.InvariantCulture, + "closeStatus: {0}, statusDescription: {1}", + closeStatus, + statusDescription); + Logging.Enter(Logging.WebSockets, this, Methods.CloseAsync, inputParameter); + }*/ + + try + { + ThrowIfPendingException(); + if (IsStateTerminal(State)) + { + return; + } + ThrowIfDisposed(); + + bool lockTaken = false; + Monitor.Enter(_thisLock, ref lockTaken); + bool ownsCloseCancellationTokenSource = false; + CancellationToken linkedCancellationToken = CancellationToken.None; + try + { + ThrowIfPendingException(); + if (IsStateTerminal(State)) + { + return; + } + ThrowIfDisposed(); + ThrowOnInvalidState(State, + WebSocketState.Open, WebSocketState.CloseReceived, WebSocketState.CloseSent); + + Task closeOutputTask; + ownsCloseCancellationTokenSource = _closeOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken); + if (ownsCloseCancellationTokenSource) + { + closeOutputTask = _closeOutputTask; + if (closeOutputTask == null && State != WebSocketState.CloseSent) + { + if (_closeReceivedTaskCompletionSource == null) + { + _closeReceivedTaskCompletionSource = new TaskCompletionSource(); + } + + closeOutputTask = CloseOutputAsync(closeStatus, + statusDescription, + linkedCancellationToken); + } + } + else + { + Contract.Assert(_closeReceivedTaskCompletionSource != null, + "'m_CloseReceivedTaskCompletionSource' MUST NOT be NULL."); + closeOutputTask = _closeReceivedTaskCompletionSource.Task; + } + + if (closeOutputTask != null) + { + ReleaseLock(_thisLock, ref lockTaken); + try + { + await closeOutputTask.SuppressContextFlow(); + } + catch (Exception closeOutputError) + { + Monitor.Enter(_thisLock, ref lockTaken); + + if (!CanHandleExceptionDuringClose(closeOutputError)) + { + ThrowIfConvertibleException(Methods.CloseOutputAsync, + closeOutputError, + cancellationToken, + linkedCancellationToken.IsCancellationRequested); + throw; + } + } + + // When closeOutputTask != null and an exception thrown from await closeOutputTask is handled, + // the lock will be taken in the catch-block. So the logic here avoids taking the lock twice. + if (!lockTaken) + { + Monitor.Enter(_thisLock, ref lockTaken); + } + } + + if (OnCloseOutputCompleted()) + { + bool callCompleteOnCloseCompleted = false; + + try + { + // linkedCancellationToken can be CancellationToken.None if ownsCloseCancellationTokenSource==false + // This is still ok because OnCloseOutputCompleted won't start any IO operation in this case + callCompleteOnCloseCompleted = await StartOnCloseCompleted( + lockTaken, false, linkedCancellationToken).SuppressContextFlow(); + } + catch (Exception) + { + // If an exception is thrown we know that the locks have been released, + // because we enforce IWebSocketStream.CloseNetworkConnectionAsync to yield + ResetFlagAndTakeLock(_thisLock, ref lockTaken); + throw; + } + + if (callCompleteOnCloseCompleted) + { + ResetFlagAndTakeLock(_thisLock, ref lockTaken); + FinishOnCloseCompleted(); + } + } + + if (IsStateTerminal(State)) + { + return; + } + + linkedCancellationToken = CancellationToken.None; + + bool ownsReceiveCancellationTokenSource = _receiveOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken); + if (ownsReceiveCancellationTokenSource) + { + _closeAsyncStartedReceive = true; + ArraySegment closeMessageBuffer = + new ArraySegment(new byte[WebSocketBuffer.MinReceiveBufferSize]); + EnsureReceiveOperation(); + Task receiveAsyncTask = _receiveOperation.Process(closeMessageBuffer, + linkedCancellationToken); + ReleaseLock(_thisLock, ref lockTaken); + + WebSocketReceiveResult receiveResult = null; + try + { + receiveResult = await receiveAsyncTask.SuppressContextFlow(); + } + catch (Exception receiveException) + { + Monitor.Enter(_thisLock, ref lockTaken); + + if (!CanHandleExceptionDuringClose(receiveException)) + { + ThrowIfConvertibleException(Methods.CloseAsync, + receiveException, + cancellationToken, + linkedCancellationToken.IsCancellationRequested); + throw; + } + } + + // receiveResult is NEVER NULL if WebSocketBase.ReceiveOperation.Process completes successfully + // - but in the close code path we handle some exception if another thread was able to tranistion + // the state into Closed successfully. In this case receiveResult can be NULL and it is safe to + // skip the statements in the if-block. + if (receiveResult != null) + { + /*if (s_LoggingEnabled && receiveResult.Count > 0) + { + Logging.Dump(Logging.WebSockets, + this, + Methods.ReceiveAsync, + closeMessageBuffer.Array, + closeMessageBuffer.Offset, + receiveResult.Count); + }*/ + + if (receiveResult.MessageType != WebSocketMessageType.Close) + { + throw new WebSocketException(WebSocketError.InvalidMessageType, + SR.GetString(SR.net_WebSockets_InvalidMessageType, + typeof(WebSocket).Name + "." + Methods.CloseAsync, + typeof(WebSocket).Name + "." + Methods.CloseOutputAsync, + receiveResult.MessageType)); + } + } + } + else + { + _receiveOutstandingOperationHelper.CompleteOperation(ownsReceiveCancellationTokenSource); + ReleaseLock(_thisLock, ref lockTaken); + await _closeReceivedTaskCompletionSource.Task.SuppressContextFlow(); + } + + // When ownsReceiveCancellationTokenSource is true and an exception is thrown, the lock will be taken. + // So this logic here is to avoid taking the lock twice. + if (!lockTaken) + { + Monitor.Enter(_thisLock, ref lockTaken); + } + + if (!IsStateTerminal(State)) + { + bool ownsSendCancellationSource = false; + try + { + // We know that the CloseFrame has been sent at this point. So no Send-operation is allowed anymore and we + // can hijack the m_SendOutstandingOperationHelper to create a linkedCancellationToken + ownsSendCancellationSource = _sendOutstandingOperationHelper.TryStartOperation(cancellationToken, out linkedCancellationToken); + Contract.Assert(ownsSendCancellationSource, "'ownsSendCancellationSource' MUST be 'true' at this point."); + + bool callCompleteOnCloseCompleted = false; + + try + { + // linkedCancellationToken can be CancellationToken.None if ownsCloseCancellationTokenSource==false + // This is still ok because OnCloseOutputCompleted won't start any IO operation in this case + callCompleteOnCloseCompleted = await StartOnCloseCompleted( + lockTaken, false, linkedCancellationToken).SuppressContextFlow(); + } + catch (Exception) + { + // If an exception is thrown we know that the locks have been released, + // because we enforce IWebSocketStream.CloseNetworkConnectionAsync to yield + ResetFlagAndTakeLock(_thisLock, ref lockTaken); + throw; + } + + if (callCompleteOnCloseCompleted) + { + ResetFlagAndTakeLock(_thisLock, ref lockTaken); + FinishOnCloseCompleted(); + } + } + finally + { + _sendOutstandingOperationHelper.CompleteOperation(ownsSendCancellationSource); + } + } + } + catch (Exception exception) + { + bool aborted = linkedCancellationToken.IsCancellationRequested; + Abort(); + ThrowIfConvertibleException(Methods.CloseAsync, exception, cancellationToken, aborted); + throw; + } + finally + { + _closeOutstandingOperationHelper.CompleteOperation(ownsCloseCancellationTokenSource); + ReleaseLock(_thisLock, ref lockTaken); + } + } + finally + { + /*if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, this, Methods.CloseAsync, inputParameter); + }*/ + } + } + + // MultiThreading: ThreadSafe; No-op if already in a terminal state + [SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_sendFrameThrottle", + Justification = "SemaphoreSlim.Dispose is not threadsafe and can cause NullRef exceptions on other threads." + + "Also according to the CLR Dev11#358715) there is no need to dispose SemaphoreSlim if the ManualResetEvent " + + "is not used.")] + public override void Dispose() + { + if (_isDisposed) + { + return; + } + + bool thisLockTaken = false; + bool sessionHandleLockTaken = false; + + try + { + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + + if (_isDisposed) + { + return; + } + + if (!IsStateTerminal(State)) + { + Abort(); + } + else + { + CleanUp(); + } + + _isDisposed = true; + } + finally + { + ReleaseLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + } + + private void ResetFlagAndTakeLock(object lockObject, ref bool thisLockTaken) + { + Contract.Assert(lockObject != null, "'lockObject' MUST NOT be NULL."); + thisLockTaken = false; + Monitor.Enter(lockObject, ref thisLockTaken); + } + + private void ResetFlagsAndTakeLocks(ref bool thisLockTaken, ref bool sessionHandleLockTaken) + { + thisLockTaken = false; + sessionHandleLockTaken = false; + TakeLocks(ref thisLockTaken, ref sessionHandleLockTaken); + } + + private void TakeLocks(ref bool thisLockTaken, ref bool sessionHandleLockTaken) + { + Contract.Assert(_thisLock != null, "'m_ThisLock' MUST NOT be NULL."); + Contract.Assert(SessionHandle != null, "'SessionHandle' MUST NOT be NULL."); + + Monitor.Enter(SessionHandle, ref sessionHandleLockTaken); + Monitor.Enter(_thisLock, ref thisLockTaken); + } + + private void ReleaseLocks(ref bool thisLockTaken, ref bool sessionHandleLockTaken) + { + Contract.Assert(_thisLock != null, "'m_ThisLock' MUST NOT be NULL."); + Contract.Assert(SessionHandle != null, "'SessionHandle' MUST NOT be NULL."); + + if (thisLockTaken || sessionHandleLockTaken) + { +#if NET45 + RuntimeHelpers.PrepareConstrainedRegions(); +#endif + try + { + } + finally + { + if (thisLockTaken) + { + Monitor.Exit(_thisLock); + thisLockTaken = false; + } + + if (sessionHandleLockTaken) + { + Monitor.Exit(SessionHandle); + sessionHandleLockTaken = false; + } + } + } + } + + private void EnsureReceiveOperation() + { + if (_receiveOperation == null) + { + lock (_thisLock) + { + if (_receiveOperation == null) + { + _receiveOperation = new WebSocketOperation.ReceiveOperation(this); + } + } + } + } + + private void EnsureSendOperation() + { + if (_sendOperation == null) + { + lock (_thisLock) + { + if (_sendOperation == null) + { + _sendOperation = new WebSocketOperation.SendOperation(this); + } + } + } + } + + private void EnsureKeepAliveOperation() + { + if (_keepAliveOperation == null) + { + lock (_thisLock) + { + if (_keepAliveOperation == null) + { + WebSocketOperation.SendOperation keepAliveOperation = new WebSocketOperation.SendOperation(this); + keepAliveOperation.BufferType = UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UnsolicitedPong; + _keepAliveOperation = keepAliveOperation; + } + } + } + } + + private void EnsureCloseOutputOperation() + { + if (_closeOutputOperation == null) + { + lock (_thisLock) + { + if (_closeOutputOperation == null) + { + _closeOutputOperation = new WebSocketOperation.CloseOutputOperation(this); + } + } + } + } + + private static void ReleaseLock(object lockObject, ref bool lockTaken) + { + Contract.Assert(lockObject != null, "'lockObject' MUST NOT be NULL."); + if (lockTaken) + { +#if NET45 + RuntimeHelpers.PrepareConstrainedRegions(); +#endif + try + { + } + finally + { + Monitor.Exit(lockObject); + lockTaken = false; + } + } + } + + private static UnsafeNativeMethods.WebSocketProtocolComponent.BufferType GetBufferType(WebSocketMessageType messageType, + bool endOfMessage) + { + Contract.Assert(messageType == WebSocketMessageType.Binary || messageType == WebSocketMessageType.Text, + string.Format(CultureInfo.InvariantCulture, + "The value of 'messageType' ({0}) is invalid. Valid message types: '{1}, {2}'", + messageType, + WebSocketMessageType.Binary, + WebSocketMessageType.Text)); + + if (messageType == WebSocketMessageType.Text) + { + if (endOfMessage) + { + return UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message; + } + + return UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment; + } + else + { + if (endOfMessage) + { + return UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage; + } + + return UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment; + } + } + + private static WebSocketMessageType GetMessageType(UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType) + { + switch (bufferType) + { + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close: + return WebSocketMessageType.Close; + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage: + return WebSocketMessageType.Binary; + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message: + return WebSocketMessageType.Text; + default: + // This indicates a contract violation of the websocket protocol component, + // because we currently don't support any WebSocket extensions and would + // not accept a Websocket handshake requesting extensions + Contract.Assert(false, + string.Format(CultureInfo.InvariantCulture, + "The value of 'bufferType' ({0}) is invalid. Valid buffer types: {1}, {2}, {3}, {4}, {5}.", + bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message)); + + throw new WebSocketException(WebSocketError.NativeError, + SR.GetString(SR.net_WebSockets_InvalidBufferType, + bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message)); + } + } + + internal void ValidateNativeBuffers(UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers, + uint dataBufferCount) + { + _internalBuffer.ValidateNativeBuffers(action, bufferType, dataBuffers, dataBufferCount); + } + + internal void ThrowIfClosedOrAborted() + { + if (State == WebSocketState.Closed || State == WebSocketState.Aborted) + { + throw new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, GetType().FullName, State)); + } + } + + private void ThrowIfAborted(bool aborted, Exception innerException) + { + if (aborted) + { + throw new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, GetType().FullName, WebSocketState.Aborted), + innerException); + } + } + + private bool CanHandleExceptionDuringClose(Exception error) + { + Contract.Assert(error != null, "'error' MUST NOT be NULL."); + + if (State != WebSocketState.Closed) + { + return false; + } + + return error is OperationCanceledException || + error is WebSocketException || + // error is SocketException || + // error is HttpListenerException || + error is IOException; + } + + // We only want to throw an OperationCanceledException if the CancellationToken passed + // down from the caller is canceled - not when Abort is called on another thread and + // the linkedCancellationToken is canceled. + private void ThrowIfConvertibleException(string methodName, + Exception exception, + CancellationToken cancellationToken, + bool aborted) + { + Contract.Assert(exception != null, "'exception' MUST NOT be NULL."); + /* + if (s_LoggingEnabled && !string.IsNullOrEmpty(methodName)) + { + Logging.Exception(Logging.WebSockets, this, methodName, exception); + }*/ + + OperationCanceledException operationCanceledException = exception as OperationCanceledException; + if (operationCanceledException != null) + { + if (cancellationToken.IsCancellationRequested || + !aborted) + { + return; + } + ThrowIfAborted(aborted, exception); + } + + WebSocketException convertedException = exception as WebSocketException; + if (convertedException != null) + { + cancellationToken.ThrowIfCancellationRequested(); + ThrowIfAborted(aborted, convertedException); + return; + } + /* + SocketException socketException = exception as SocketException; + if (socketException != null) + { + convertedException = new WebSocketException(socketException.NativeErrorCode, socketException); + } + HttpListenerException httpListenerException = exception as HttpListenerException; + if (httpListenerException != null) + { + convertedException = new WebSocketException(httpListenerException.ErrorCode, httpListenerException); + } + + IOException ioException = exception as IOException; + if (ioException != null) + { + socketException = exception.InnerException as SocketException; + if (socketException != null) + { + convertedException = new WebSocketException(socketException.NativeErrorCode, ioException); + } + } +*/ + if (convertedException != null) + { + cancellationToken.ThrowIfCancellationRequested(); + ThrowIfAborted(aborted, convertedException); + throw convertedException; + } + + AggregateException aggregateException = exception as AggregateException; + if (aggregateException != null) + { + // Collapse possibly nested graph into a flat list. + // Empty inner exception list is unlikely but possible via public api. + ReadOnlyCollection unwrappedExceptions = aggregateException.Flatten().InnerExceptions; + if (unwrappedExceptions.Count == 0) + { + return; + } + + foreach (Exception unwrappedException in unwrappedExceptions) + { + ThrowIfConvertibleException(null, unwrappedException, cancellationToken, aborted); + } + } + } + + private void CleanUp() + { + // Multithreading: This method is always called under the m_ThisLock lock + if (_cleanedUp) + { + return; + } + + _cleanedUp = true; + + if (SessionHandle != null) + { + SessionHandle.Dispose(); + } + + if (_internalBuffer != null) + { + _internalBuffer.Dispose(this.State); + } + + if (_receiveOutstandingOperationHelper != null) + { + _receiveOutstandingOperationHelper.Dispose(); + } + + if (_sendOutstandingOperationHelper != null) + { + _sendOutstandingOperationHelper.Dispose(); + } + + if (_closeOutputOutstandingOperationHelper != null) + { + _closeOutputOutstandingOperationHelper.Dispose(); + } + + if (_closeOutstandingOperationHelper != null) + { + _closeOutstandingOperationHelper.Dispose(); + } + + if (_innerStream != null) + { + try + { + _innerStream.Dispose(); + } + catch (ObjectDisposedException) + { + } + catch (IOException) + { + } + /*catch (SocketException) + { + }*/ + catch (Exception) + { + } + } + + _keepAliveTracker.Dispose(); + } + + private void OnBackgroundTaskException(Exception exception) + { + if (Interlocked.CompareExchange(ref _pendingException, exception, null) == null) + { + /*if (s_LoggingEnabled) + { + Logging.Exception(Logging.WebSockets, this, Methods.Fault, exception); + }*/ + Abort(); + } + } + + private void ThrowIfPendingException() + { + Exception pendingException = Interlocked.Exchange(ref _pendingException, null); + if (pendingException != null) + { + throw new WebSocketException(WebSocketError.Faulted, pendingException); + } + } + + private void ThrowIfDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + } + + private void UpdateReceiveState(int newReceiveState, int expectedReceiveState) + { + int receiveState; + if ((receiveState = Interlocked.Exchange(ref _receiveState, newReceiveState)) != expectedReceiveState) + { + Contract.Assert(false, + string.Format(CultureInfo.InvariantCulture, + "'m_ReceiveState' had an invalid value '{0}'. The expected value was '{1}'.", + receiveState, + expectedReceiveState)); + } + } + + private bool StartOnCloseReceived(ref bool thisLockTaken) + { + ThrowIfDisposed(); + + if (IsStateTerminal(State) || State == WebSocketState.CloseReceived) + { + return false; + } + + Monitor.Enter(_thisLock, ref thisLockTaken); + if (IsStateTerminal(State) || State == WebSocketState.CloseReceived) + { + return false; + } + + if (State == WebSocketState.Open) + { + _state = WebSocketState.CloseReceived; + + if (_closeReceivedTaskCompletionSource == null) + { + _closeReceivedTaskCompletionSource = new TaskCompletionSource(); + } + + return false; + } + + return true; + } + + private void FinishOnCloseReceived(WebSocketCloseStatus closeStatus, + string closeStatusDescription) + { + if (_closeReceivedTaskCompletionSource != null) + { + _closeReceivedTaskCompletionSource.TrySetResult(null); + } + + _closeStatus = closeStatus; + _closeStatusDescription = closeStatusDescription; + /* + if (s_LoggingEnabled) + { + string parameters = string.Format(CultureInfo.InvariantCulture, + "closeStatus: {0}, closeStatusDescription: {1}, m_State: {2}", + closeStatus, closeStatusDescription, m_State); + + Logging.PrintInfo(Logging.WebSockets, this, Methods.FinishOnCloseReceived, parameters); + }*/ + } + + private async static void OnKeepAlive(object sender) + { + Contract.Assert(sender != null, "'sender' MUST NOT be NULL."); + Contract.Assert((sender as WebSocketBase) != null, "'sender as WebSocketBase' MUST NOT be NULL."); + + WebSocketBase thisPtr = sender as WebSocketBase; + bool lockTaken = false; + /* + if (s_LoggingEnabled) + { + Logging.Enter(Logging.WebSockets, thisPtr, Methods.OnKeepAlive, string.Empty); + }*/ + + CancellationToken linkedCancellationToken = CancellationToken.None; + try + { + Monitor.Enter(thisPtr.SessionHandle, ref lockTaken); + + if (thisPtr._isDisposed || + thisPtr._state != WebSocketState.Open || + thisPtr._closeOutputTask != null) + { + return; + } + + if (thisPtr._keepAliveTracker.ShouldSendKeepAlive()) + { + bool ownsCancellationTokenSource = false; + try + { + ownsCancellationTokenSource = thisPtr._sendOutstandingOperationHelper.TryStartOperation(CancellationToken.None, out linkedCancellationToken); + if (ownsCancellationTokenSource) + { + thisPtr.EnsureKeepAliveOperation(); + thisPtr._keepAliveTask = thisPtr._keepAliveOperation.Process(null, linkedCancellationToken); + ReleaseLock(thisPtr.SessionHandle, ref lockTaken); + await thisPtr._keepAliveTask.SuppressContextFlow(); + } + } + finally + { + if (!lockTaken) + { + Monitor.Enter(thisPtr.SessionHandle, ref lockTaken); + } + thisPtr._sendOutstandingOperationHelper.CompleteOperation(ownsCancellationTokenSource); + thisPtr._keepAliveTask = null; + } + + thisPtr._keepAliveTracker.ResetTimer(); + } + } + catch (Exception exception) + { + try + { + thisPtr.ThrowIfConvertibleException(Methods.OnKeepAlive, + exception, + CancellationToken.None, + linkedCancellationToken.IsCancellationRequested); + throw; + } + catch (Exception backgroundException) + { + thisPtr.OnBackgroundTaskException(backgroundException); + } + } + finally + { + ReleaseLock(thisPtr.SessionHandle, ref lockTaken); + /* + if (s_LoggingEnabled) + { + Logging.Exit(Logging.WebSockets, thisPtr, Methods.OnKeepAlive, string.Empty); + }*/ + } + } + + private abstract class WebSocketOperation + { + private readonly WebSocketBase _webSocket; + + internal WebSocketOperation(WebSocketBase webSocket) + { + Contract.Assert(webSocket != null, "'webSocket' MUST NOT be NULL."); + _webSocket = webSocket; + } + + public WebSocketReceiveResult ReceiveResult { get; protected set; } + protected abstract int BufferCount { get; } + protected abstract UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue ActionQueue { get; } + protected abstract void Initialize(ArraySegment? buffer, CancellationToken cancellationToken); + protected abstract bool ShouldContinue(CancellationToken cancellationToken); + + // Multi-Threading: This method has to be called under a SessionHandle-lock. It returns true if a + // close frame was received. Handling the received close frame might involve IO - to make the locking + // strategy easier and reduce one level in the await-hierarchy the IO is kicked off by the caller. + protected abstract bool ProcessAction_NoAction(); + + protected virtual void ProcessAction_IndicateReceiveComplete( + ArraySegment? buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers, + uint dataBufferCount, + IntPtr actionContext) + { + throw new NotImplementedException(); + } + + protected abstract void Cleanup(); + + internal async Task Process(ArraySegment? buffer, + CancellationToken cancellationToken) + { + Contract.Assert(BufferCount >= 1 && BufferCount <= 2, "'bufferCount' MUST ONLY BE '1' or '2'."); + + bool sessionHandleLockTaken = false; + ReceiveResult = null; + try + { + Monitor.Enter(_webSocket.SessionHandle, ref sessionHandleLockTaken); + _webSocket.ThrowIfPendingException(); + Initialize(buffer, cancellationToken); + + while (ShouldContinue(cancellationToken)) + { + UnsafeNativeMethods.WebSocketProtocolComponent.Action action; + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType; + + bool completed = false; + while (!completed) + { + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers = + new UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[BufferCount]; + uint dataBufferCount = (uint)BufferCount; + IntPtr actionContext; + + _webSocket.ThrowIfDisposed(); + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketGetAction(_webSocket, + ActionQueue, + dataBuffers, + ref dataBufferCount, + out action, + out bufferType, + out actionContext); + + switch (action) + { + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.NoAction: + if (ProcessAction_NoAction()) + { + // A close frame was received + + Contract.Assert(ReceiveResult.Count == 0, "'receiveResult.Count' MUST be 0."); + Contract.Assert(ReceiveResult.CloseStatus != null, "'receiveResult.CloseStatus' MUST NOT be NULL for message type 'Close'."); + bool thisLockTaken = false; + try + { + if (_webSocket.StartOnCloseReceived(ref thisLockTaken)) + { + // If StartOnCloseReceived returns true the WebSocket close handshake has been completed + // so there is no need to retake the SessionHandle-lock. + // m_ThisLock lock is guaranteed to be taken by StartOnCloseReceived when returning true + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + bool callCompleteOnCloseCompleted = false; + + try + { + callCompleteOnCloseCompleted = await _webSocket.StartOnCloseCompleted( + thisLockTaken, sessionHandleLockTaken, cancellationToken).SuppressContextFlow(); + } + catch (Exception) + { + // If an exception is thrown we know that the locks have been released, + // because we enforce IWebSocketStream.CloseNetworkConnectionAsync to yield + _webSocket.ResetFlagAndTakeLock(_webSocket._thisLock, ref thisLockTaken); + throw; + } + + if (callCompleteOnCloseCompleted) + { + _webSocket.ResetFlagAndTakeLock(_webSocket._thisLock, ref thisLockTaken); + _webSocket.FinishOnCloseCompleted(); + } + } + _webSocket.FinishOnCloseReceived(ReceiveResult.CloseStatus.Value, ReceiveResult.CloseStatusDescription); + } + finally + { + if (thisLockTaken) + { + ReleaseLock(_webSocket._thisLock, ref thisLockTaken); + } + } + } + completed = true; + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateReceiveComplete: + ProcessAction_IndicateReceiveComplete(buffer, + bufferType, + action, + dataBuffers, + dataBufferCount, + actionContext); + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.ReceiveFromNetwork: + int count = 0; + try + { + ArraySegment payload = _webSocket._internalBuffer.ConvertNativeBuffer(action, dataBuffers[0], bufferType); + + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + WebSocketHelpers.ThrowIfConnectionAborted(_webSocket._innerStream, true); + try + { + Task readTask = _webSocket._innerStream.ReadAsync(payload.Array, + payload.Offset, + payload.Count, + cancellationToken); + count = await readTask.SuppressContextFlow(); + _webSocket._keepAliveTracker.OnDataReceived(); + } + catch (ObjectDisposedException objectDisposedException) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, objectDisposedException); + } + catch (NotSupportedException notSupportedException) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, notSupportedException); + } + Monitor.Enter(_webSocket.SessionHandle, ref sessionHandleLockTaken); + _webSocket.ThrowIfPendingException(); + // If the client unexpectedly closed the socket we throw an exception as we didn't get any close message + if (count == 0) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); + } + } + finally + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, + actionContext, + count); + } + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateSendComplete: + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, actionContext, 0); + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + await _webSocket._innerStream.FlushAsync().SuppressContextFlow(); + Monitor.Enter(_webSocket.SessionHandle, ref sessionHandleLockTaken); + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.Action.SendToNetwork: + int bytesSent = 0; + try + { + if (_webSocket.State != WebSocketState.CloseSent || + (bufferType != UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.PingPong && + bufferType != UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UnsolicitedPong)) + { + if (dataBufferCount == 0) + { + break; + } + + List> sendBuffers = new List>((int)dataBufferCount); + int sendBufferSize = 0; + ArraySegment framingBuffer = _webSocket._internalBuffer.ConvertNativeBuffer(action, dataBuffers[0], bufferType); + sendBuffers.Add(framingBuffer); + sendBufferSize += framingBuffer.Count; + + // There can be at most 2 dataBuffers + // - one for the framing header and one for the payload + if (dataBufferCount == 2) + { + ArraySegment payload = _webSocket._internalBuffer.ConvertPinnedSendPayloadFromNative(dataBuffers[1], bufferType); + sendBuffers.Add(payload); + sendBufferSize += payload.Count; + } + + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + WebSocketHelpers.ThrowIfConnectionAborted(_webSocket._innerStream, false); + await _webSocket.SendFrameAsync(sendBuffers, cancellationToken).SuppressContextFlow(); + Monitor.Enter(_webSocket.SessionHandle, ref sessionHandleLockTaken); + _webSocket.ThrowIfPendingException(); + bytesSent += sendBufferSize; + _webSocket._keepAliveTracker.OnDataSent(); + } + } + finally + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, + actionContext, + bytesSent); + } + + break; + default: + string assertMessage = string.Format(CultureInfo.InvariantCulture, + "Invalid action '{0}' returned from WebSocketGetAction.", + action); + Contract.Assert(false, assertMessage); + throw new InvalidOperationException(); + } + } + } + } + finally + { + Cleanup(); + ReleaseLock(_webSocket.SessionHandle, ref sessionHandleLockTaken); + } + + return ReceiveResult; + } + + public class ReceiveOperation : WebSocketOperation + { + private int _receiveState; + private bool _pongReceived; + private bool _receiveCompleted; + + public ReceiveOperation(WebSocketBase webSocket) + : base(webSocket) + { + } + + protected override UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue ActionQueue + { + get { return UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue.Receive; } + } + + protected override int BufferCount + { + get { return 1; } + } + + protected override void Initialize(ArraySegment? buffer, CancellationToken cancellationToken) + { + Contract.Assert(buffer != null, "'buffer' MUST NOT be NULL."); + _pongReceived = false; + _receiveCompleted = false; + _webSocket.ThrowIfDisposed(); + + int originalReceiveState = Interlocked.CompareExchange(ref _webSocket._receiveState, + ReceiveState.Application, ReceiveState.Idle); + + switch (originalReceiveState) + { + case ReceiveState.Idle: + _receiveState = ReceiveState.Application; + break; + case ReceiveState.Application: + Contract.Assert(false, "'originalReceiveState' MUST NEVER be ReceiveState.Application at this point."); + break; + case ReceiveState.PayloadAvailable: + WebSocketReceiveResult receiveResult; + if (!_webSocket._internalBuffer.ReceiveFromBufferedPayload(buffer.Value, out receiveResult)) + { + _webSocket.UpdateReceiveState(ReceiveState.Idle, ReceiveState.PayloadAvailable); + } + ReceiveResult = receiveResult; + _receiveCompleted = true; + break; + default: + Contract.Assert(false, + string.Format(CultureInfo.InvariantCulture, "Invalid ReceiveState '{0}'.", originalReceiveState)); + break; + } + } + + protected override void Cleanup() + { + } + + protected override bool ShouldContinue(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (_receiveCompleted) + { + return false; + } + + _webSocket.ThrowIfDisposed(); + _webSocket.ThrowIfPendingException(); + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketReceive(_webSocket); + + return true; + } + + protected override bool ProcessAction_NoAction() + { + if (_pongReceived) + { + _receiveCompleted = false; + _pongReceived = false; + return false; + } + + Contract.Assert(ReceiveResult != null, + "'ReceiveResult' MUST NOT be NULL."); + _receiveCompleted = true; + + if (ReceiveResult.MessageType == WebSocketMessageType.Close) + { + return true; + } + + return false; + } + + protected override void ProcessAction_IndicateReceiveComplete( + ArraySegment? buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers, + uint dataBufferCount, + IntPtr actionContext) + { + Contract.Assert(buffer != null, "'buffer MUST NOT be NULL."); + + int bytesTransferred = 0; + _pongReceived = false; + + if (bufferType == UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.PingPong) + { + // ignoring received pong frame + _pongReceived = true; + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, + actionContext, + bytesTransferred); + return; + } + + WebSocketReceiveResult receiveResult; + try + { + ArraySegment payload; + WebSocketMessageType messageType = GetMessageType(bufferType); + int newReceiveState = ReceiveState.Idle; + + if (bufferType == UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close) + { + payload = WebSocketHelpers.EmptyPayload; + string reason; + WebSocketCloseStatus closeStatus; + _webSocket._internalBuffer.ConvertCloseBuffer(action, dataBuffers[0], out closeStatus, out reason); + + receiveResult = new WebSocketReceiveResult(bytesTransferred, + messageType, true, closeStatus, reason); + } + else + { + payload = _webSocket._internalBuffer.ConvertNativeBuffer(action, dataBuffers[0], bufferType); + + bool endOfMessage = bufferType == + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage || + bufferType == UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message || + bufferType == UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close; + + if (payload.Count > buffer.Value.Count) + { + _webSocket._internalBuffer.BufferPayload(payload, buffer.Value.Count, messageType, endOfMessage); + newReceiveState = ReceiveState.PayloadAvailable; + endOfMessage = false; + } + + bytesTransferred = Math.Min(payload.Count, (int)buffer.Value.Count); + if (bytesTransferred > 0) + { + Buffer.BlockCopy(payload.Array, + payload.Offset, + buffer.Value.Array, + buffer.Value.Offset, + bytesTransferred); + } + + receiveResult = new WebSocketReceiveResult(bytesTransferred, messageType, endOfMessage); + } + + _webSocket.UpdateReceiveState(newReceiveState, _receiveState); + } + finally + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketCompleteAction(_webSocket, + actionContext, + bytesTransferred); + } + + ReceiveResult = receiveResult; + } + } + + public class SendOperation : WebSocketOperation + { + private bool _completed; + protected bool _bufferHasBeenPinned; + + public SendOperation(WebSocketBase webSocket) + : base(webSocket) + { + } + + protected override UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue ActionQueue + { + get { return UnsafeNativeMethods.WebSocketProtocolComponent.ActionQueue.Send; } + } + + protected override int BufferCount + { + get { return 2; } + } + + protected virtual UnsafeNativeMethods.WebSocketProtocolComponent.Buffer? CreateBuffer(ArraySegment? buffer) + { + if (buffer == null) + { + return null; + } + + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer payloadBuffer; + payloadBuffer = new UnsafeNativeMethods.WebSocketProtocolComponent.Buffer(); + _webSocket._internalBuffer.PinSendBuffer(buffer.Value, out _bufferHasBeenPinned); + payloadBuffer.Data.BufferData = _webSocket._internalBuffer.ConvertPinnedSendPayloadToNative(buffer.Value); + payloadBuffer.Data.BufferLength = (uint)buffer.Value.Count; + return payloadBuffer; + } + + protected override bool ProcessAction_NoAction() + { + _completed = true; + return false; + } + + protected override void Cleanup() + { + if (_bufferHasBeenPinned) + { + _bufferHasBeenPinned = false; + _webSocket._internalBuffer.ReleasePinnedSendBuffer(); + } + } + + internal UnsafeNativeMethods.WebSocketProtocolComponent.BufferType BufferType { get; set; } + + protected override void Initialize(ArraySegment? buffer, + CancellationToken cancellationToken) + { + Contract.Assert(!_bufferHasBeenPinned, "'m_BufferHasBeenPinned' MUST NOT be pinned at this point."); + _webSocket.ThrowIfDisposed(); + _webSocket.ThrowIfPendingException(); + _completed = false; + + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer? payloadBuffer = CreateBuffer(buffer); + if (payloadBuffer != null) + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketSend(_webSocket, BufferType, payloadBuffer.Value); + } + else + { + UnsafeNativeMethods.WebSocketProtocolComponent.WebSocketSendWithoutBody(_webSocket, BufferType); + } + } + + protected override bool ShouldContinue(CancellationToken cancellationToken) + { + Contract.Assert(ReceiveResult == null, "'ReceiveResult' MUST be NULL."); + if (_completed) + { + return false; + } + + cancellationToken.ThrowIfCancellationRequested(); + return true; + } + } + + public class CloseOutputOperation : SendOperation + { + public CloseOutputOperation(WebSocketBase webSocket) + : base(webSocket) + { + BufferType = UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close; + } + + internal WebSocketCloseStatus CloseStatus { get; set; } + internal string CloseReason { get; set; } + + protected override UnsafeNativeMethods.WebSocketProtocolComponent.Buffer? CreateBuffer(ArraySegment? buffer) + { + Contract.Assert(buffer == null, "'buffer' MUST BE NULL."); + _webSocket.ThrowIfDisposed(); + _webSocket.ThrowIfPendingException(); + + if (CloseStatus == WebSocketCloseStatus.Empty) + { + return null; + } + + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer payloadBuffer = new UnsafeNativeMethods.WebSocketProtocolComponent.Buffer(); + if (CloseReason != null) + { + byte[] blob = UTF8Encoding.UTF8.GetBytes(CloseReason); + Contract.Assert(blob.Length <= WebSocketHelpers.MaxControlFramePayloadLength, + "The close reason is too long."); + ArraySegment closeBuffer = new ArraySegment(blob, 0, Math.Min(WebSocketHelpers.MaxControlFramePayloadLength, blob.Length)); + _webSocket._internalBuffer.PinSendBuffer(closeBuffer, out _bufferHasBeenPinned); + payloadBuffer.CloseStatus.ReasonData = _webSocket._internalBuffer.ConvertPinnedSendPayloadToNative(closeBuffer); + payloadBuffer.CloseStatus.ReasonLength = (uint)closeBuffer.Count; + } + + payloadBuffer.CloseStatus.CloseStatus = (ushort)CloseStatus; + return payloadBuffer; + } + } + } + + private abstract class KeepAliveTracker : IDisposable + { + // Multi-Threading: only one thread at a time is allowed to call OnDataReceived or OnDataSent + // - but both methods can be called from different threads at the same time. + public abstract void OnDataReceived(); + public abstract void OnDataSent(); + public abstract void Dispose(); + public abstract void StartTimer(WebSocketBase webSocket); + public abstract void ResetTimer(); + public abstract bool ShouldSendKeepAlive(); + + public static KeepAliveTracker Create(TimeSpan keepAliveInterval) + { + if ((int)keepAliveInterval.TotalMilliseconds > 0) + { + return new DefaultKeepAliveTracker(keepAliveInterval); + } + + return new DisabledKeepAliveTracker(); + } + + private class DisabledKeepAliveTracker : KeepAliveTracker + { + public override void OnDataReceived() + { + } + + public override void OnDataSent() + { + } + + public override void ResetTimer() + { + } + + public override void StartTimer(WebSocketBase webSocket) + { + } + + public override bool ShouldSendKeepAlive() + { + return false; + } + + public override void Dispose() + { + } + } + + private class DefaultKeepAliveTracker : KeepAliveTracker + { + private static readonly TimerCallback _keepAliveTimerElapsedCallback = new TimerCallback(OnKeepAlive); + private readonly TimeSpan _keepAliveInterval; + private readonly Stopwatch _lastSendActivity; + private readonly Stopwatch _lastReceiveActivity; + private Timer _keepAliveTimer; + + public DefaultKeepAliveTracker(TimeSpan keepAliveInterval) + { + _keepAliveInterval = keepAliveInterval; + _lastSendActivity = new Stopwatch(); + _lastReceiveActivity = new Stopwatch(); + } + + public override void OnDataReceived() + { + _lastReceiveActivity.Restart(); + } + + public override void OnDataSent() + { + _lastSendActivity.Restart(); + } + + public override void ResetTimer() + { + ResetTimer((int)_keepAliveInterval.TotalMilliseconds); + } + + public override void StartTimer(WebSocketBase webSocket) + { + Contract.Assert(webSocket != null, "'webSocket' MUST NOT be NULL."); + Contract.Assert(webSocket._keepAliveTracker != null, + "'webSocket.m_KeepAliveTracker' MUST NOT be NULL at this point."); + int keepAliveIntervalMilliseconds = (int)_keepAliveInterval.TotalMilliseconds; + Contract.Assert(keepAliveIntervalMilliseconds > 0, "'keepAliveIntervalMilliseconds' MUST be POSITIVE."); +#if NET45 + if (ExecutionContext.IsFlowSuppressed()) + { + _keepAliveTimer = new Timer(_keepAliveTimerElapsedCallback, webSocket, keepAliveIntervalMilliseconds, Timeout.Infinite); + } + else + { + using (ExecutionContext.SuppressFlow()) + { + _keepAliveTimer = new Timer(_keepAliveTimerElapsedCallback, webSocket, keepAliveIntervalMilliseconds, Timeout.Infinite); + } + } +#else + _keepAliveTimer = new Timer(_keepAliveTimerElapsedCallback, webSocket, keepAliveIntervalMilliseconds, Timeout.Infinite); +#endif + } + + public override bool ShouldSendKeepAlive() + { + TimeSpan idleTime = GetIdleTime(); + if (idleTime >= _keepAliveInterval) + { + return true; + } + + ResetTimer((int)(_keepAliveInterval - idleTime).TotalMilliseconds); + return false; + } + + public override void Dispose() + { + _keepAliveTimer.Dispose(); + } + + private void ResetTimer(int dueInMilliseconds) + { + _keepAliveTimer.Change(dueInMilliseconds, Timeout.Infinite); + } + + private TimeSpan GetIdleTime() + { + TimeSpan sinceLastSendActivity = GetTimeElapsed(_lastSendActivity); + TimeSpan sinceLastReceiveActivity = GetTimeElapsed(_lastReceiveActivity); + + if (sinceLastReceiveActivity < sinceLastSendActivity) + { + return sinceLastReceiveActivity; + } + + return sinceLastSendActivity; + } + + private TimeSpan GetTimeElapsed(Stopwatch watch) + { + if (watch.IsRunning) + { + return watch.Elapsed; + } + + return _keepAliveInterval; + } + } + } + + private class OutstandingOperationHelper : IDisposable + { + private volatile int _operationsOutstanding; + private volatile CancellationTokenSource _cancellationTokenSource; + private volatile bool _isDisposed; + private readonly object _thisLock = new object(); + + public bool TryStartOperation(CancellationToken userCancellationToken, out CancellationToken linkedCancellationToken) + { + linkedCancellationToken = CancellationToken.None; + ThrowIfDisposed(); + + lock (_thisLock) + { + int operationsOutstanding = ++_operationsOutstanding; + + if (operationsOutstanding == 1) + { + linkedCancellationToken = CreateLinkedCancellationToken(userCancellationToken); + return true; + } + + Contract.Assert(operationsOutstanding >= 1, "'operationsOutstanding' must never be smaller than 1."); + return false; + } + } + + public void CompleteOperation(bool ownsCancellationTokenSource) + { + if (_isDisposed) + { + // no-op if the WebSocket is already aborted + return; + } + + CancellationTokenSource snapshot = null; + + lock (_thisLock) + { + --_operationsOutstanding; + Contract.Assert(_operationsOutstanding >= 0, "'m_OperationsOutstanding' must never be smaller than 0."); + + if (ownsCancellationTokenSource) + { + snapshot = _cancellationTokenSource; + _cancellationTokenSource = null; + } + } + + if (snapshot != null) + { + snapshot.Dispose(); + } + } + + // Has to be called under m_ThisLock lock + private CancellationToken CreateLinkedCancellationToken(CancellationToken cancellationToken) + { + CancellationTokenSource linkedCancellationTokenSource; + + if (cancellationToken == CancellationToken.None) + { + linkedCancellationTokenSource = new CancellationTokenSource(); + } + else + { + linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, + new CancellationTokenSource().Token); + } + + Contract.Assert(_cancellationTokenSource == null, "'m_CancellationTokenSource' MUST be NULL."); + _cancellationTokenSource = linkedCancellationTokenSource; + + return linkedCancellationTokenSource.Token; + } + + public void CancelIO() + { + CancellationTokenSource cancellationTokenSourceSnapshot = null; + + lock (_thisLock) + { + if (_operationsOutstanding == 0) + { + return; + } + + cancellationTokenSourceSnapshot = _cancellationTokenSource; + } + + if (cancellationTokenSourceSnapshot != null) + { + try + { + cancellationTokenSourceSnapshot.Cancel(); + } + catch (ObjectDisposedException) + { + // Simply ignore this exception - There is apparently a rare race condition + // where the cancellationTokensource is disposed before the Cancel method call completed. + } + } + } + + public void Dispose() + { + if (_isDisposed) + { + return; + } + + CancellationTokenSource snapshot = null; + lock (_thisLock) + { + if (_isDisposed) + { + return; + } + + _isDisposed = true; + snapshot = _cancellationTokenSource; + _cancellationTokenSource = null; + } + + if (snapshot != null) + { + snapshot.Dispose(); + } + } + + private void ThrowIfDisposed() + { + if (_isDisposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + } + } + + internal interface IWebSocketStream + { + // Switching to opaque mode will change the behavior to use the knowledge that the WebSocketBase class + // is pinning all payloads already and that we will have at most one outstanding send and receive at any + // given time. This allows us to avoid creation of OverlappedData and pinning for each operation. + + void SwitchToOpaqueMode(WebSocketBase webSocket); + void Abort(); + bool SupportsMultipleWrite { get; } + Task MultipleWriteAsync(IList> buffers, CancellationToken cancellationToken); + + // Any implementation has to guarantee that no exception is thrown synchronously + // for example by enforcing a Task.Yield at the beginning of the method + // This is necessary to enforce an API contract (for WebSocketBase.StartOnCloseCompleted) that ensures + // that all locks have been released whenever an exception is thrown from it. + Task CloseNetworkConnectionAsync(CancellationToken cancellationToken); + } + + private static class ReceiveState + { + internal const int SendOperation = -1; + internal const int Idle = 0; + internal const int Application = 1; + internal const int PayloadAvailable = 2; + } + + internal static class Methods + { + internal const string ReceiveAsync = "ReceiveAsync"; + internal const string SendAsync = "SendAsync"; + internal const string CloseAsync = "CloseAsync"; + internal const string CloseOutputAsync = "CloseOutputAsync"; + internal const string Abort = "Abort"; + internal const string Initialize = "Initialize"; + internal const string Fault = "Fault"; + internal const string StartOnCloseCompleted = "StartOnCloseCompleted"; + internal const string FinishOnCloseReceived = "FinishOnCloseReceived"; + internal const string OnKeepAlive = "OnKeepAlive"; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketBuffer.cs b/src/Microsoft.AspNet.WebSockets/WebSocketBuffer.cs new file mode 100644 index 0000000000..56b24a4fed --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketBuffer.cs @@ -0,0 +1,698 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; + +namespace Microsoft.AspNet.WebSockets +{ + // This class helps to abstract the internal WebSocket buffer, which is used to interact with the native WebSocket + // protocol component (WSPC). It helps to shield the details of the layout and the involved pointer arithmetic. + // The internal WebSocket buffer also contains a segment, which is used by the WebSocketBase class to buffer + // payload (parsed by WSPC already) for the application, if the application requested fewer bytes than the + // WSPC returned. The internal buffer is pinned for the whole lifetime if this class. + // LAYOUT: + // | Native buffer | PayloadReceiveBuffer | PropertyBuffer | + // | RBS + SBS + 144 | RBS | PBS | + // | Only WSPC may modify | Only WebSocketBase may modify | + // + // *RBS = ReceiveBufferSize, *SBS = SendBufferSize + // *PBS = PropertyBufferSize (32-bit: 16, 64 bit: 20 bytes) + internal class WebSocketBuffer : IDisposable + { + private const int NativeOverheadBufferSize = 144; + internal const int MinSendBufferSize = 16; + internal const int MinReceiveBufferSize = 256; + internal const int MaxBufferSize = 64 * 1024; +#if NET45 + private static readonly int SizeOfUInt = Marshal.SizeOf(typeof(uint)); + private static readonly int SizeOfBool = Marshal.SizeOf(typeof(bool)); +#else + private static readonly int SizeOfUInt = Marshal.SizeOf(); + private static readonly int SizeOfBool = Marshal.SizeOf(); +#endif + private static readonly int PropertyBufferSize = (2 * SizeOfUInt) + SizeOfBool + IntPtr.Size; + + private readonly int _ReceiveBufferSize; + + // Indicates the range of the pinned byte[] that can be used by the WSPC (nativeBuffer + pinnedSendBuffer) + private readonly long _StartAddress; + private readonly long _EndAddress; + private readonly GCHandle _GCHandle; + private readonly ArraySegment _InternalBuffer; + private readonly ArraySegment _NativeBuffer; + private readonly ArraySegment _PayloadBuffer; + private readonly ArraySegment _PropertyBuffer; + private readonly int _SendBufferSize; + private volatile int _PayloadOffset; + private volatile WebSocketReceiveResult _BufferedPayloadReceiveResult; + private long _PinnedSendBufferStartAddress; + private long _PinnedSendBufferEndAddress; + private ArraySegment _PinnedSendBuffer; + private GCHandle _PinnedSendBufferHandle; + private int _StateWhenDisposing = int.MinValue; + private int _SendBufferState; + + private WebSocketBuffer(ArraySegment internalBuffer, int receiveBufferSize, int sendBufferSize) + { + Contract.Assert(internalBuffer.Array != null, "'internalBuffer' MUST NOT be NULL."); + Contract.Assert(receiveBufferSize >= MinReceiveBufferSize, + "'receiveBufferSize' MUST be at least " + MinReceiveBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize >= MinSendBufferSize, + "'sendBufferSize' MUST be at least " + MinSendBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(receiveBufferSize <= MaxBufferSize, + "'receiveBufferSize' MUST NOT exceed " + MaxBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize <= MaxBufferSize, + "'sendBufferSize' MUST NOT exceed " + MaxBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + _ReceiveBufferSize = receiveBufferSize; + _SendBufferSize = sendBufferSize; + _InternalBuffer = internalBuffer; + _GCHandle = GCHandle.Alloc(internalBuffer.Array, GCHandleType.Pinned); + // Size of the internal buffer owned exclusively by the WSPC. + int nativeBufferSize = _ReceiveBufferSize + _SendBufferSize + NativeOverheadBufferSize; + _StartAddress = Marshal.UnsafeAddrOfPinnedArrayElement(internalBuffer.Array, internalBuffer.Offset).ToInt64(); + _EndAddress = _StartAddress + nativeBufferSize; + _NativeBuffer = new ArraySegment(internalBuffer.Array, internalBuffer.Offset, nativeBufferSize); + _PayloadBuffer = new ArraySegment(internalBuffer.Array, + _NativeBuffer.Offset + _NativeBuffer.Count, + _ReceiveBufferSize); + _PropertyBuffer = new ArraySegment(internalBuffer.Array, + _PayloadBuffer.Offset + _PayloadBuffer.Count, + PropertyBufferSize); + _SendBufferState = SendBufferState.None; + } + + public int ReceiveBufferSize + { + get { return _ReceiveBufferSize; } + } + + public int SendBufferSize + { + get { return _SendBufferSize; } + } + + internal static WebSocketBuffer CreateClientBuffer(ArraySegment internalBuffer, int receiveBufferSize, int sendBufferSize) + { + Contract.Assert(internalBuffer.Count >= GetInternalBufferSize(receiveBufferSize, sendBufferSize, false), + "Array 'internalBuffer' is TOO SMALL. Call Validate before instantiating WebSocketBuffer."); + + return new WebSocketBuffer(internalBuffer, receiveBufferSize, GetNativeSendBufferSize(sendBufferSize, false)); + } + + internal static WebSocketBuffer CreateServerBuffer(ArraySegment internalBuffer, int receiveBufferSize) + { + int sendBufferSize = GetNativeSendBufferSize(MinSendBufferSize, true); + Contract.Assert(internalBuffer.Count >= GetInternalBufferSize(receiveBufferSize, sendBufferSize, true), + "Array 'internalBuffer' is TOO SMALL. Call Validate before instantiating WebSocketBuffer."); + + return new WebSocketBuffer(internalBuffer, receiveBufferSize, sendBufferSize); + } + + public void Dispose(WebSocketState webSocketState) + { + if (Interlocked.CompareExchange(ref _StateWhenDisposing, (int)webSocketState, int.MinValue) != int.MinValue) + { + return; + } + + this.CleanUp(); + } + + public void Dispose() + { + this.Dispose(WebSocketState.None); + } + + internal UnsafeNativeMethods.WebSocketProtocolComponent.Property[] CreateProperties(bool useZeroMaskingKey) + { + ThrowIfDisposed(); + // serialize marshaled property values in the property segment of the internal buffer + IntPtr internalBufferPtr = _GCHandle.AddrOfPinnedObject(); + int offset = _PropertyBuffer.Offset; + Marshal.WriteInt32(internalBufferPtr, offset, _ReceiveBufferSize); + offset += SizeOfUInt; + Marshal.WriteInt32(internalBufferPtr, offset, _SendBufferSize); + offset += SizeOfUInt; + Marshal.WriteIntPtr(internalBufferPtr, offset, internalBufferPtr); + offset += IntPtr.Size; + Marshal.WriteInt32(internalBufferPtr, offset, useZeroMaskingKey ? (int)1 : (int)0); + + int propertyCount = useZeroMaskingKey ? 4 : 3; + UnsafeNativeMethods.WebSocketProtocolComponent.Property[] properties = + new UnsafeNativeMethods.WebSocketProtocolComponent.Property[propertyCount]; + + // Calculate the pointers to the positions of the properties within the internal buffer + offset = _PropertyBuffer.Offset; + properties[0] = new UnsafeNativeMethods.WebSocketProtocolComponent.Property() + { + Type = UnsafeNativeMethods.WebSocketProtocolComponent.PropertyType.ReceiveBufferSize, + PropertySize = (uint)SizeOfUInt, + PropertyData = IntPtr.Add(internalBufferPtr, offset) + }; + offset += SizeOfUInt; + + properties[1] = new UnsafeNativeMethods.WebSocketProtocolComponent.Property() + { + Type = UnsafeNativeMethods.WebSocketProtocolComponent.PropertyType.SendBufferSize, + PropertySize = (uint)SizeOfUInt, + PropertyData = IntPtr.Add(internalBufferPtr, offset) + }; + offset += SizeOfUInt; + + properties[2] = new UnsafeNativeMethods.WebSocketProtocolComponent.Property() + { + Type = UnsafeNativeMethods.WebSocketProtocolComponent.PropertyType.AllocatedBuffer, + PropertySize = (uint)_NativeBuffer.Count, + PropertyData = IntPtr.Add(internalBufferPtr, offset) + }; + offset += IntPtr.Size; + + if (useZeroMaskingKey) + { + properties[3] = new UnsafeNativeMethods.WebSocketProtocolComponent.Property() + { + Type = UnsafeNativeMethods.WebSocketProtocolComponent.PropertyType.DisableMasking, + PropertySize = (uint)SizeOfBool, + PropertyData = IntPtr.Add(internalBufferPtr, offset) + }; + } + + return properties; + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + internal void PinSendBuffer(ArraySegment payload, out bool bufferHasBeenPinned) + { + bufferHasBeenPinned = false; + WebSocketHelpers.ValidateBuffer(payload.Array, payload.Offset, payload.Count); + int previousState = Interlocked.Exchange(ref _SendBufferState, SendBufferState.SendPayloadSpecified); + + if (previousState != SendBufferState.None) + { + Contract.Assert(false, "'m_SendBufferState' MUST BE 'None' at this point."); + // Indicates a violation in the API contract that could indicate + // memory corruption because the pinned sendbuffer is shared between managed and native code + throw new AccessViolationException(); + } + _PinnedSendBuffer = payload; + _PinnedSendBufferHandle = GCHandle.Alloc(_PinnedSendBuffer.Array, GCHandleType.Pinned); + bufferHasBeenPinned = true; + _PinnedSendBufferStartAddress = + Marshal.UnsafeAddrOfPinnedArrayElement(_PinnedSendBuffer.Array, _PinnedSendBuffer.Offset).ToInt64(); + _PinnedSendBufferEndAddress = _PinnedSendBufferStartAddress + _PinnedSendBuffer.Count; + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + internal IntPtr ConvertPinnedSendPayloadToNative(ArraySegment payload) + { + return ConvertPinnedSendPayloadToNative(payload.Array, payload.Offset, payload.Count); + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + internal IntPtr ConvertPinnedSendPayloadToNative(byte[] buffer, int offset, int count) + { + if (!IsPinnedSendPayloadBuffer(buffer, offset, count)) + { + // Indicates a violation in the API contract that could indicate + // memory corruption because the pinned sendbuffer is shared between managed and native code + throw new AccessViolationException(); + } + + Contract.Assert(Marshal.UnsafeAddrOfPinnedArrayElement(_PinnedSendBuffer.Array, + _PinnedSendBuffer.Offset).ToInt64() == _PinnedSendBufferStartAddress, + "'m_PinnedSendBuffer.Array' MUST be pinned during the entire send operation."); + + return new IntPtr(_PinnedSendBufferStartAddress + offset - _PinnedSendBuffer.Offset); + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + internal ArraySegment ConvertPinnedSendPayloadFromNative(UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType) + { + if (!IsPinnedSendPayloadBuffer(buffer, bufferType)) + { + // Indicates a violation in the API contract that could indicate + // memory corruption because the pinned sendbuffer is shared between managed and native code + throw new AccessViolationException(); + } + + Contract.Assert(Marshal.UnsafeAddrOfPinnedArrayElement(_PinnedSendBuffer.Array, + _PinnedSendBuffer.Offset).ToInt64() == _PinnedSendBufferStartAddress, + "'m_PinnedSendBuffer.Array' MUST be pinned during the entire send operation."); + + IntPtr bufferData; + uint bufferSize; + + UnwrapWebSocketBuffer(buffer, bufferType, out bufferData, out bufferSize); + + int internalOffset = (int)(bufferData.ToInt64() - _PinnedSendBufferStartAddress); + + return new ArraySegment(_PinnedSendBuffer.Array, _PinnedSendBuffer.Offset + internalOffset, (int)bufferSize); + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + private bool IsPinnedSendPayloadBuffer(byte[] buffer, int offset, int count) + { + if (_SendBufferState != SendBufferState.SendPayloadSpecified) + { + return false; + } + + return object.ReferenceEquals(buffer, _PinnedSendBuffer.Array) && + offset >= _PinnedSendBuffer.Offset && + offset + count <= _PinnedSendBuffer.Offset + _PinnedSendBuffer.Count; + } + + // This method is not thread safe. It must only be called after enforcing at most 1 outstanding send operation + private bool IsPinnedSendPayloadBuffer(UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType) + { + if (_SendBufferState != SendBufferState.SendPayloadSpecified) + { + return false; + } + + IntPtr bufferData; + uint bufferSize; + + UnwrapWebSocketBuffer(buffer, bufferType, out bufferData, out bufferSize); + + long nativeBufferStartAddress = bufferData.ToInt64(); + long nativeBufferEndAddress = nativeBufferStartAddress + bufferSize; + + return nativeBufferStartAddress >= _PinnedSendBufferStartAddress && + nativeBufferEndAddress >= _PinnedSendBufferStartAddress && + nativeBufferStartAddress <= _PinnedSendBufferEndAddress && + nativeBufferEndAddress <= _PinnedSendBufferEndAddress; + } + + // This method is only thread safe for races between Abort and at most 1 uncompleted send operation + internal void ReleasePinnedSendBuffer() + { + int previousState = Interlocked.Exchange(ref _SendBufferState, SendBufferState.None); + + if (previousState != SendBufferState.SendPayloadSpecified) + { + return; + } + + if (_PinnedSendBufferHandle.IsAllocated) + { + _PinnedSendBufferHandle.Free(); + } + + _PinnedSendBuffer = WebSocketHelpers.EmptyPayload; + } + + internal void BufferPayload(ArraySegment payload, + int unconsumedDataOffset, + WebSocketMessageType messageType, + bool endOfMessage) + { + ThrowIfDisposed(); + int bytesBuffered = payload.Count - unconsumedDataOffset; + + Contract.Assert(_PayloadOffset == 0, + "'m_PayloadOffset' MUST be '0' at this point."); + Contract.Assert(_BufferedPayloadReceiveResult == null || _BufferedPayloadReceiveResult.Count == 0, + "'m_BufferedPayloadReceiveResult.Count' MUST be '0' at this point."); + + Buffer.BlockCopy(payload.Array, + payload.Offset + unconsumedDataOffset, + _PayloadBuffer.Array, + _PayloadBuffer.Offset, + bytesBuffered); + + _BufferedPayloadReceiveResult = + new WebSocketReceiveResult(bytesBuffered, messageType, endOfMessage); + + this.ValidateBufferedPayload(); + } + + internal bool ReceiveFromBufferedPayload(ArraySegment buffer, out WebSocketReceiveResult receiveResult) + { + ThrowIfDisposed(); + ValidateBufferedPayload(); + + int bytesTransferred = Math.Min(buffer.Count, _BufferedPayloadReceiveResult.Count); + receiveResult = _BufferedPayloadReceiveResult.Copy(bytesTransferred); + + Buffer.BlockCopy(_PayloadBuffer.Array, + _PayloadBuffer.Offset + _PayloadOffset, + buffer.Array, + buffer.Offset, + bytesTransferred); + + bool morePayloadBuffered; + if (_BufferedPayloadReceiveResult.Count == 0) + { + _PayloadOffset = 0; + _BufferedPayloadReceiveResult = null; + morePayloadBuffered = false; + } + else + { + _PayloadOffset += bytesTransferred; + morePayloadBuffered = true; + this.ValidateBufferedPayload(); + } + + return morePayloadBuffered; + } + + internal ArraySegment ConvertNativeBuffer(UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType) + { + ThrowIfDisposed(); + + IntPtr bufferData; + uint bufferLength; + + UnwrapWebSocketBuffer(buffer, bufferType, out bufferData, out bufferLength); + + if (bufferData == IntPtr.Zero) + { + return WebSocketHelpers.EmptyPayload; + } + + if (this.IsNativeBuffer(bufferData, bufferLength)) + { + return new ArraySegment(_InternalBuffer.Array, + this.GetOffset(bufferData), + (int)bufferLength); + } + + Contract.Assert(false, "'buffer' MUST reference a memory segment within the pinned InternalBuffer."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + + internal void ConvertCloseBuffer(UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + out WebSocketCloseStatus closeStatus, + out string reason) + { + ThrowIfDisposed(); + IntPtr bufferData; + uint bufferLength; + closeStatus = (WebSocketCloseStatus)buffer.CloseStatus.CloseStatus; + + UnwrapWebSocketBuffer(buffer, UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close, out bufferData, out bufferLength); + + if (bufferData == IntPtr.Zero) + { + reason = null; + } + else + { + ArraySegment reasonBlob; + if (this.IsNativeBuffer(bufferData, bufferLength)) + { + reasonBlob = new ArraySegment(_InternalBuffer.Array, + this.GetOffset(bufferData), + (int)bufferLength); + } + else + { + Contract.Assert(false, "'buffer' MUST reference a memory segment within the pinned InternalBuffer."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + + // No need to wrap DecoderFallbackException for invalid UTF8 chacters, because + // Encoding.UTF8 will not throw but replace invalid characters instead. + reason = Encoding.UTF8.GetString(reasonBlob.Array, reasonBlob.Offset, reasonBlob.Count); + } + } + + internal void ValidateNativeBuffers(UnsafeNativeMethods.WebSocketProtocolComponent.Action action, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer[] dataBuffers, + uint dataBufferCount) + { + Contract.Assert(dataBufferCount <= (uint)int.MaxValue, + "'dataBufferCount' MUST NOT be bigger than Int32.MaxValue."); + Contract.Assert(dataBuffers != null, "'dataBuffers' MUST NOT be NULL."); + + ThrowIfDisposed(); + if (dataBufferCount > dataBuffers.Length) + { + Contract.Assert(false, "'dataBufferCount' MUST NOT be bigger than 'dataBuffers.Length'."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + + int count = dataBuffers.Length; + bool isSendActivity = action == UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateSendComplete || + action == UnsafeNativeMethods.WebSocketProtocolComponent.Action.SendToNetwork; + + if (isSendActivity) + { + count = (int)dataBufferCount; + } + + bool nonZeroBufferFound = false; + for (int i = 0; i < count; i++) + { + UnsafeNativeMethods.WebSocketProtocolComponent.Buffer dataBuffer = dataBuffers[i]; + + IntPtr bufferData; + uint bufferLength; + UnwrapWebSocketBuffer(dataBuffer, bufferType, out bufferData, out bufferLength); + + if (bufferData == IntPtr.Zero) + { + continue; + } + + nonZeroBufferFound = true; + + bool isPinnedSendPayloadBuffer = IsPinnedSendPayloadBuffer(dataBuffer, bufferType); + + if (bufferLength > GetMaxBufferSize()) + { + if (!isSendActivity || !isPinnedSendPayloadBuffer) + { + Contract.Assert(false, + "'dataBuffer.BufferLength' MUST NOT be bigger than 'm_ReceiveBufferSize' and 'm_SendBufferSize'."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + } + + if (!isPinnedSendPayloadBuffer && !IsNativeBuffer(bufferData, bufferLength)) + { + Contract.Assert(false, + "WebSocketGetAction MUST return a pointer within the pinned internal buffer."); + // Indicates a violation in the contract with native Websocket.dll and could indicate + // memory corruption because the internal buffer is shared between managed and native code + throw new AccessViolationException(); + } + } + + if (!nonZeroBufferFound && + action != UnsafeNativeMethods.WebSocketProtocolComponent.Action.NoAction && + action != UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateReceiveComplete && + action != UnsafeNativeMethods.WebSocketProtocolComponent.Action.IndicateSendComplete) + { + Contract.Assert(false, "At least one 'dataBuffer.Buffer' MUST NOT be NULL."); + } + } + + private static int GetNativeSendBufferSize(int sendBufferSize, bool isServerBuffer) + { + return isServerBuffer ? MinSendBufferSize : sendBufferSize; + } + + internal static void UnwrapWebSocketBuffer(UnsafeNativeMethods.WebSocketProtocolComponent.Buffer buffer, + UnsafeNativeMethods.WebSocketProtocolComponent.BufferType bufferType, + out IntPtr bufferData, + out uint bufferLength) + { + bufferData = IntPtr.Zero; + bufferLength = 0; + + switch (bufferType) + { + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.Close: + bufferData = buffer.CloseStatus.ReasonData; + bufferLength = buffer.CloseStatus.ReasonLength; + break; + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.None: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryFragment: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.BinaryMessage: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Fragment: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UTF8Message: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.PingPong: + case UnsafeNativeMethods.WebSocketProtocolComponent.BufferType.UnsolicitedPong: + bufferData = buffer.Data.BufferData; + bufferLength = buffer.Data.BufferLength; + break; + default: + Contract.Assert(false, + string.Format(CultureInfo.InvariantCulture, + "BufferType '{0}' is invalid/unknown.", + bufferType)); + break; + } + } + + private void ThrowIfDisposed() + { + switch (_StateWhenDisposing) + { + case int.MinValue: + return; + case (int)WebSocketState.Closed: + case (int)WebSocketState.Aborted: + throw new WebSocketException(WebSocketError.InvalidState, + SR.GetString(SR.net_WebSockets_InvalidState_ClosedOrAborted, typeof(WebSocketBase), _StateWhenDisposing)); + default: + throw new ObjectDisposedException(GetType().FullName); + } + } + + [Conditional("DEBUG"), Conditional("CONTRACTS_FULL")] + private void ValidateBufferedPayload() + { + Contract.Assert(_BufferedPayloadReceiveResult != null, + "'m_BufferedPayloadReceiveResult' MUST NOT be NULL."); + Contract.Assert(_BufferedPayloadReceiveResult.Count >= 0, + "'m_BufferedPayloadReceiveResult.Count' MUST NOT be negative."); + Contract.Assert(_PayloadOffset >= 0, "'m_PayloadOffset' MUST NOT be smaller than 0."); + Contract.Assert(_PayloadOffset <= _PayloadBuffer.Count, + "'m_PayloadOffset' MUST NOT be bigger than 'm_PayloadBuffer.Count'."); + Contract.Assert(_PayloadOffset + _BufferedPayloadReceiveResult.Count <= _PayloadBuffer.Count, + "'m_PayloadOffset + m_PayloadBytesBuffered' MUST NOT be bigger than 'm_PayloadBuffer.Count'."); + } + + private int GetOffset(IntPtr pBuffer) + { + Contract.Assert(pBuffer != IntPtr.Zero, "'pBuffer' MUST NOT be IntPtr.Zero."); + int offset = (int)(pBuffer.ToInt64() - _StartAddress + _InternalBuffer.Offset); + + Contract.Assert(offset >= 0, "'offset' MUST NOT be negative."); + return offset; + } + + [Pure] + private int GetMaxBufferSize() + { + return Math.Max(_ReceiveBufferSize, _SendBufferSize); + } + + internal bool IsInternalBuffer(byte[] buffer, int offset, int count) + { + Contract.Assert(buffer != null, "'buffer' MUST NOT be NULL."); + Contract.Assert(_InternalBuffer.Array != null, "'m_InternalBuffer.Array' MUST NOT be NULL."); + Contract.Assert(offset >= 0, "'offset' MUST NOT be negative."); + Contract.Assert(count >= 0, "'count' MUST NOT be negative."); + Contract.Assert(offset + count <= buffer.Length, "'offset + count' MUST NOT exceed 'buffer.Length'."); + + return object.ReferenceEquals(buffer, _InternalBuffer.Array); + } + + internal IntPtr ToIntPtr(int offset) + { + Contract.Assert(offset >= 0, "'offset' MUST NOT be negative."); + Contract.Assert(_StartAddress + offset <= _EndAddress, "'offset' is TOO BIG."); + return new IntPtr(_StartAddress + offset); + } + + private bool IsNativeBuffer(IntPtr pBuffer, uint bufferSize) + { + Contract.Assert(pBuffer != IntPtr.Zero, "'pBuffer' MUST NOT be NULL."); + Contract.Assert(bufferSize <= GetMaxBufferSize(), + "'bufferSize' MUST NOT be bigger than 'm_ReceiveBufferSize' and 'm_SendBufferSize'."); + + long nativeBufferStartAddress = pBuffer.ToInt64(); + long nativeBufferEndAddress = bufferSize + nativeBufferStartAddress; + + Contract.Assert(Marshal.UnsafeAddrOfPinnedArrayElement(_InternalBuffer.Array, _InternalBuffer.Offset).ToInt64() == _StartAddress, + "'m_InternalBuffer.Array' MUST be pinned for the whole lifetime of a WebSocket."); + + if (nativeBufferStartAddress >= _StartAddress && + nativeBufferStartAddress <= _EndAddress && + nativeBufferEndAddress >= _StartAddress && + nativeBufferEndAddress <= _EndAddress) + { + return true; + } + + return false; + } + + private void CleanUp() + { + if (_GCHandle.IsAllocated) + { + _GCHandle.Free(); + } + + ReleasePinnedSendBuffer(); + } + + internal static ArraySegment CreateInternalBufferArraySegment(int receiveBufferSize, int sendBufferSize, bool isServerBuffer) + { + Contract.Assert(receiveBufferSize >= MinReceiveBufferSize, + "'receiveBufferSize' MUST be at least " + MinReceiveBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize >= MinSendBufferSize, + "'sendBufferSize' MUST be at least " + MinSendBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + int internalBufferSize = GetInternalBufferSize(receiveBufferSize, sendBufferSize, isServerBuffer); + return new ArraySegment(new byte[internalBufferSize]); + } + + internal static void Validate(int count, int receiveBufferSize, int sendBufferSize, bool isServerBuffer) + { + Contract.Assert(receiveBufferSize >= MinReceiveBufferSize, + "'receiveBufferSize' MUST be at least " + MinReceiveBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize >= MinSendBufferSize, + "'sendBufferSize' MUST be at least " + MinSendBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + int minBufferSize = GetInternalBufferSize(receiveBufferSize, sendBufferSize, isServerBuffer); + if (count < minBufferSize) + { + throw new ArgumentOutOfRangeException("internalBuffer", + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_InternalBuffer, minBufferSize)); + } + } + + private static int GetInternalBufferSize(int receiveBufferSize, int sendBufferSize, bool isServerBuffer) + { + Contract.Assert(receiveBufferSize >= MinReceiveBufferSize, + "'receiveBufferSize' MUST be at least " + MinReceiveBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize >= MinSendBufferSize, + "'sendBufferSize' MUST be at least " + MinSendBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + Contract.Assert(receiveBufferSize <= MaxBufferSize, + "'receiveBufferSize' MUST be less than or equal to " + MaxBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + Contract.Assert(sendBufferSize <= MaxBufferSize, + "'sendBufferSize' MUST be at less than or equal to " + MaxBufferSize.ToString(NumberFormatInfo.InvariantInfo) + "."); + + int nativeSendBufferSize = GetNativeSendBufferSize(sendBufferSize, isServerBuffer); + return (2 * receiveBufferSize) + nativeSendBufferSize + NativeOverheadBufferSize + PropertyBufferSize; + } + + private static class SendBufferState + { + public const int None = 0; + public const int SendPayloadSpecified = 1; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketCloseStatus.cs b/src/Microsoft.AspNet.WebSockets/WebSocketCloseStatus.cs new file mode 100644 index 0000000000..a773179e21 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketCloseStatus.cs @@ -0,0 +1,40 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.AspNet.WebSockets +{ + [SuppressMessage("Microsoft.Design", + "CA1008:EnumsShouldHaveZeroValue", + Justification = "This enum is reflecting the IETF's WebSocket specification. " + + "'0' is a disallowed value for the close status code")] + public enum WebSocketCloseStatus + { + NormalClosure = 1000, + EndpointUnavailable = 1001, + ProtocolError = 1002, + InvalidMessageType = 1003, + Empty = 1005, + // AbnormalClosure = 1006, // 1006 is reserved and should never be used by user + InvalidPayloadData = 1007, + PolicyViolation = 1008, + MessageTooBig = 1009, + MandatoryExtension = 1010, + InternalServerError = 1011 + // TLSHandshakeFailed = 1015, // 1015 is reserved and should never be used by user + + // 0 - 999 Status codes in the range 0-999 are not used. + // 1000 - 1999 Status codes in the range 1000-1999 are reserved for definition by this protocol. + // 2000 - 2999 Status codes in the range 2000-2999 are reserved for use by extensions. + // 3000 - 3999 Status codes in the range 3000-3999 MAY be used by libraries and frameworks. The + // interpretation of these codes is undefined by this protocol. End applications MUST + // NOT use status codes in this range. + // 4000 - 4999 Status codes in the range 4000-4999 MAY be used by application code. The interpretaion + // of these codes is undefined by this protocol. + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketError.cs b/src/Microsoft.AspNet.WebSockets/WebSocketError.cs new file mode 100644 index 0000000000..473bb56349 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketError.cs @@ -0,0 +1,22 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.WebSockets +{ + public enum WebSocketError + { + Success = 0, + InvalidMessageType = 1, + Faulted = 2, + NativeError = 3, + NotAWebSocket = 4, + UnsupportedVersion = 5, + UnsupportedProtocol = 6, + HeaderError = 7, + ConnectionClosedPrematurely = 8, + InvalidState = 9 + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketException.cs b/src/Microsoft.AspNet.WebSockets/WebSocketException.cs new file mode 100644 index 0000000000..64edc3d0cb --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketException.cs @@ -0,0 +1,155 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; + +namespace Microsoft.AspNet.WebSockets +{ +#if NET45 + [System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Usage", "CA2237:MarkISerializableTypesWithSerializable")] +#endif + internal sealed class WebSocketException : Win32Exception + { + private WebSocketError _WebSocketErrorCode; + + public WebSocketException() + : this(Marshal.GetLastWin32Error()) + { + } + + public WebSocketException(WebSocketError error) + : this(error, GetErrorMessage(error)) + { + } + + public WebSocketException(WebSocketError error, string message) : base(message) + { + _WebSocketErrorCode = error; + } + + public WebSocketException(WebSocketError error, Exception innerException) + : this(error, GetErrorMessage(error), innerException) + { + } + + public WebSocketException(WebSocketError error, string message, Exception innerException) + : base(message, innerException) + { + _WebSocketErrorCode = error; + } + + public WebSocketException(int nativeError) + : base(nativeError) + { + _WebSocketErrorCode = !UnsafeNativeMethods.WebSocketProtocolComponent.Succeeded(nativeError) ? WebSocketError.NativeError : WebSocketError.Success; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(int nativeError, string message) + : base(nativeError, message) + { + _WebSocketErrorCode = !UnsafeNativeMethods.WebSocketProtocolComponent.Succeeded(nativeError) ? WebSocketError.NativeError : WebSocketError.Success; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(int nativeError, Exception innerException) + : base(SR.GetString(SR.net_WebSockets_Generic), innerException) + { + _WebSocketErrorCode = !UnsafeNativeMethods.WebSocketProtocolComponent.Succeeded(nativeError) ? WebSocketError.NativeError : WebSocketError.Success; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(WebSocketError error, int nativeError) + : this(error, nativeError, GetErrorMessage(error)) + { + } + + public WebSocketException(WebSocketError error, int nativeError, string message) + : base(message) + { + _WebSocketErrorCode = error; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(WebSocketError error, int nativeError, Exception innerException) + : this(error, nativeError, GetErrorMessage(error), innerException) + { + } + + public WebSocketException(WebSocketError error, int nativeError, string message, Exception innerException) + : base(message, innerException) + { + _WebSocketErrorCode = error; + this.SetErrorCodeOnError(nativeError); + } + + public WebSocketException(string message) + : base(message) + { + } + + public WebSocketException(string message, Exception innerException) + : base(message, innerException) + { + } + + public override int ErrorCode + { + get + { + return base.NativeErrorCode; + } + } + + public WebSocketError WebSocketErrorCode + { + get + { + return _WebSocketErrorCode; + } + } + + private static string GetErrorMessage(WebSocketError error) + { + // provide a canned message for the error type + switch (error) + { + case WebSocketError.InvalidMessageType: + return SR.GetString(SR.net_WebSockets_InvalidMessageType_Generic, + typeof(WebSocket).Name + WebSocketBase.Methods.CloseAsync, + typeof(WebSocket).Name + WebSocketBase.Methods.CloseOutputAsync); + case WebSocketError.Faulted: + return SR.GetString(SR.net_Websockets_WebSocketBaseFaulted); + case WebSocketError.NotAWebSocket: + return SR.GetString(SR.net_WebSockets_NotAWebSocket_Generic); + case WebSocketError.UnsupportedVersion: + return SR.GetString(SR.net_WebSockets_UnsupportedWebSocketVersion_Generic); + case WebSocketError.UnsupportedProtocol: + return SR.GetString(SR.net_WebSockets_UnsupportedProtocol_Generic); + case WebSocketError.HeaderError: + return SR.GetString(SR.net_WebSockets_HeaderError_Generic); + case WebSocketError.ConnectionClosedPrematurely: + return SR.GetString(SR.net_WebSockets_ConnectionClosedPrematurely_Generic); + case WebSocketError.InvalidState: + return SR.GetString(SR.net_WebSockets_InvalidState_Generic); + default: + return SR.GetString(SR.net_WebSockets_Generic); + } + } + + // Set the error code only if there is an error (i.e. nativeError >= 0). Otherwise the code blows up on deserialization + // as the Exception..ctor() throws on setting HResult to 0. The default for HResult is -2147467259. + private void SetErrorCodeOnError(int nativeError) + { + if (!UnsafeNativeMethods.WebSocketProtocolComponent.Succeeded(nativeError)) + { + this.HResult = nativeError; + } + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketExtensions.cs b/src/Microsoft.AspNet.WebSockets/WebSocketExtensions.cs new file mode 100644 index 0000000000..abcbd12865 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketExtensions.cs @@ -0,0 +1,14 @@ +#if NET45 +using Microsoft.AspNet.WebSockets; + +namespace Owin +{ + public static class WebSocketExtensions + { + public static IAppBuilder UseWebSockets(this IAppBuilder app) + { + return app.Use(); + } + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketHelpers.cs b/src/Microsoft.AspNet.WebSockets/WebSocketHelpers.cs new file mode 100644 index 0000000000..3479e137fa --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketHelpers.cs @@ -0,0 +1,522 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Globalization; +using System.IO; +using System.Runtime.CompilerServices; +#if NET45 +using System.Security.Cryptography; +#endif +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.WebSockets +{ + internal static class WebSocketHelpers + { + internal const string SecWebSocketKeyGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + internal const string WebSocketUpgradeToken = "websocket"; + internal const int DefaultReceiveBufferSize = 16 * 1024; + internal const int DefaultClientSendBufferSize = 16 * 1024; + internal const int MaxControlFramePayloadLength = 123; + + // RFC 6455 requests WebSocket clients to let the server initiate the TCP close to avoid that client sockets + // end up in TIME_WAIT-state + // + // After both sending and receiving a Close message, an endpoint considers the WebSocket connection closed and + // MUST close the underlying TCP connection. The server MUST close the underlying TCP connection immediately; + // the client SHOULD wait for the server to close the connection but MAY close the connection at any time after + // sending and receiving a Close message, e.g., if it has not received a TCP Close from the server in a + // reasonable time period. + internal const int ClientTcpCloseTimeout = 1000; // 1s + + private const int CloseStatusCodeAbort = 1006; + private const int CloseStatusCodeFailedTLSHandshake = 1015; + private const int InvalidCloseStatusCodesFrom = 0; + private const int InvalidCloseStatusCodesTo = 999; + private const string Separators = "()<>@,;:\\\"/[]?={} "; + + internal static readonly ArraySegment EmptyPayload = new ArraySegment(new byte[] { }, 0, 0); + private static readonly Random KeyGenerator = new Random(); + +/* + internal static Task AcceptWebSocketAsync(HttpListenerContext context, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + { + WebSocketHelpers.ValidateOptions(subProtocol, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, keepAliveInterval); + WebSocketHelpers.ValidateArraySegment(internalBuffer, "internalBuffer"); + WebSocketBuffer.Validate(internalBuffer.Count, receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + + return AcceptWebSocketAsyncCore(context, subProtocol, receiveBufferSize, keepAliveInterval, internalBuffer); + } + + private static async Task AcceptWebSocketAsyncCore(HttpListenerContext context, + string subProtocol, + int receiveBufferSize, + TimeSpan keepAliveInterval, + ArraySegment internalBuffer) + { + HttpListenerWebSocketContext webSocketContext = null; + /*if (Logging.On) + { + Logging.Enter(Logging.WebSockets, context, "AcceptWebSocketAsync", ""); + }* / + + try + { + // get property will create a new response if one doesn't exist. + HttpListenerResponse response = context.Response; + HttpListenerRequest request = context.Request; + ValidateWebSocketHeaders(context); + + string secWebSocketVersion = request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; + + // Optional for non-browser client + string origin = request.Headers[HttpKnownHeaderNames.Origin]; + + List secWebSocketProtocols = new List(); + string outgoingSecWebSocketProtocolString; + bool shouldSendSecWebSocketProtocolHeader = + WebSocketHelpers.ProcessWebSocketProtocolHeader( + request.Headers[HttpKnownHeaderNames.SecWebSocketProtocol], + subProtocol, + out outgoingSecWebSocketProtocolString); + + if (shouldSendSecWebSocketProtocolHeader) + { + secWebSocketProtocols.Add(outgoingSecWebSocketProtocolString); + response.Headers.Add(HttpKnownHeaderNames.SecWebSocketProtocol, + outgoingSecWebSocketProtocolString); + } + + // negotiate the websocket key return value + string secWebSocketKey = request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; + string secWebSocketAccept = WebSocketHelpers.GetSecWebSocketAcceptString(secWebSocketKey); + + response.Headers.Add(HttpKnownHeaderNames.Connection, HttpKnownHeaderNames.Upgrade); + response.Headers.Add(HttpKnownHeaderNames.Upgrade, WebSocketHelpers.WebSocketUpgradeToken); + response.Headers.Add(HttpKnownHeaderNames.SecWebSocketAccept, secWebSocketAccept); + + response.StatusCode = (int)HttpStatusCode.SwitchingProtocols; // HTTP 101 + response.ComputeCoreHeaders(); + ulong hresult = SendWebSocketHeaders(response); + if (hresult != 0) + { + throw new WebSocketException((int)hresult, + SR.GetString(SR.net_WebSockets_NativeSendResponseHeaders, + WebSocketHelpers.MethodNames.AcceptWebSocketAsync, + hresult)); + } + + await response.OutputStream.FlushAsync().SuppressContextFlow(); // TODO:??? FlushAsync was never implemented + + HttpResponseStream responseStream = response.OutputStream as HttpResponseStream; + Contract.Assert(responseStream != null, "'responseStream' MUST be castable to System.Net.HttpResponseStream."); + ((HttpResponseStream)response.OutputStream).SwitchToOpaqueMode(); + HttpRequestStream requestStream = new HttpRequestStream(context); + requestStream.SwitchToOpaqueMode(); + WebSocketHttpListenerDuplexStream webSocketStream = + new WebSocketHttpListenerDuplexStream(requestStream, responseStream, context); + WebSocket webSocket = WebSocket.CreateServerWebSocket(webSocketStream, + subProtocol, + receiveBufferSize, + keepAliveInterval, + internalBuffer); + + webSocketContext = new HttpListenerWebSocketContext( + request.Url, + request.Headers, + request.Cookies, + context.User, + request.IsAuthenticated, + request.IsLocal, + request.IsSecureConnection, + origin, + secWebSocketProtocols.AsReadOnly(), + secWebSocketVersion, + secWebSocketKey, + webSocket); + + if (Logging.On) + { + Logging.Associate(Logging.WebSockets, context, webSocketContext); + Logging.Associate(Logging.WebSockets, webSocketContext, webSocket); + } + } + catch (Exception ex) + { + if (Logging.On) + { + Logging.Exception(Logging.WebSockets, context, "AcceptWebSocketAsync", ex); + } + throw; + } + finally + { + if (Logging.On) + { + Logging.Exit(Logging.WebSockets, context, "AcceptWebSocketAsync", ""); + } + } + return webSocketContext; + } + +*/ + [SuppressMessage("Microsoft.Cryptographic.Standard", "CA5354:SHA1CannotBeUsed", + Justification = "SHA1 used only for hashing purposes, not for crypto.")] + internal static string GetSecWebSocketAcceptString(string secWebSocketKey) + { + string retVal; +#if NET45 + // SHA1 used only for hashing purposes, not for crypto. Check here for FIPS compat. + using (SHA1 sha1 = SHA1.Create()) + { + string acceptString = string.Concat(secWebSocketKey, WebSocketHelpers.SecWebSocketKeyGuid); + byte[] toHash = Encoding.UTF8.GetBytes(acceptString); + retVal = Convert.ToBase64String(sha1.ComputeHash(toHash)); + } +#endif + return retVal; + } + + internal static string GetTraceMsgForParameters(int offset, int count, CancellationToken cancellationToken) + { + return string.Format(CultureInfo.InvariantCulture, + "offset: {0}, count: {1}, cancellationToken.CanBeCanceled: {2}", + offset, + count, + cancellationToken.CanBeCanceled); + } + + // return value here signifies if a Sec-WebSocket-Protocol header should be returned by the server. + internal static bool ProcessWebSocketProtocolHeader(string clientSecWebSocketProtocol, + string subProtocol, + out string acceptProtocol) + { + acceptProtocol = string.Empty; + if (string.IsNullOrEmpty(clientSecWebSocketProtocol)) + { + // client hasn't specified any Sec-WebSocket-Protocol header + if (subProtocol != null) + { + // If the server specified _anything_ this isn't valid. + throw new WebSocketException(WebSocketError.UnsupportedProtocol, + SR.GetString(SR.net_WebSockets_ClientAcceptingNoProtocols, subProtocol)); + } + // Treat empty and null from the server as the same thing here, server should not send headers. + return false; + } + + // here, we know the client specified something and it's non-empty. + + if (subProtocol == null) + { + // client specified some protocols, server specified 'null'. So server should send headers. + return true; + } + + // here, we know that the client has specified something, it's not empty + // and the server has specified exactly one protocol + + string[] requestProtocols = clientSecWebSocketProtocol.Split(new char[] { ',' }, + StringSplitOptions.RemoveEmptyEntries); + acceptProtocol = subProtocol; + + // client specified protocols, serverOptions has exactly 1 non-empty entry. Check that + // this exists in the list the client specified. + for (int i = 0; i < requestProtocols.Length; i++) + { + string currentRequestProtocol = requestProtocols[i].Trim(); + if (string.Compare(acceptProtocol, currentRequestProtocol, StringComparison.OrdinalIgnoreCase) == 0) + { + return true; + } + } + + throw new WebSocketException(WebSocketError.UnsupportedProtocol, + SR.GetString(SR.net_WebSockets_AcceptUnsupportedProtocol, + clientSecWebSocketProtocol, + subProtocol)); + } + + internal static ConfiguredTaskAwaitable SuppressContextFlow(this Task task) + { + // We don't flow the synchronization context within WebSocket.xxxAsync - but the calling application + // can decide whether the completion callback for the task returned from WebSocket.xxxAsync runs + // under the caller's synchronization context. + return task.ConfigureAwait(false); + } + + internal static ConfiguredTaskAwaitable SuppressContextFlow(this Task task) + { + // We don't flow the synchronization context within WebSocket.xxxAsync - but the calling application + // can decide whether the completion callback for the task returned from WebSocket.xxxAsync runs + // under the caller's synchronization context. + return task.ConfigureAwait(false); + } + + internal static void ValidateBuffer(byte[] buffer, int offset, int count) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset"); + } + + if (count < 0 || count > (buffer.Length - offset)) + { + throw new ArgumentOutOfRangeException("count"); + } + } + /* + private static unsafe ulong SendWebSocketHeaders(HttpListenerResponse response) + { + return response.SendHeaders(null, null, + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_OPAQUE | + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_MORE_DATA | + UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.HTTP_SEND_RESPONSE_FLAG_BUFFER_DATA, + true); + } + private static void ValidateWebSocketHeaders(HttpListenerContext context) + { + EnsureHttpSysSupportsWebSockets(); + + if (!context.Request.IsWebSocketRequest) + { + throw new WebSocketException(WebSocketError.NotAWebSocket, + SR.GetString(SR.net_WebSockets_AcceptNotAWebSocket, + WebSocketHelpers.MethodNames.ValidateWebSocketHeaders, + HttpKnownHeaderNames.Connection, + HttpKnownHeaderNames.Upgrade, + WebSocketHelpers.WebSocketUpgradeToken, + context.Request.Headers[HttpKnownHeaderNames.Upgrade])); + } + + string secWebSocketVersion = context.Request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; + if (string.IsNullOrEmpty(secWebSocketVersion)) + { + throw new WebSocketException(WebSocketError.HeaderError, + SR.GetString(SR.net_WebSockets_AcceptHeaderNotFound, + WebSocketHelpers.MethodNames.ValidateWebSocketHeaders, + HttpKnownHeaderNames.SecWebSocketVersion)); + } + + if (string.Compare(secWebSocketVersion, WebSocketProtocolComponent.SupportedVersion, StringComparison.OrdinalIgnoreCase) != 0) + { + throw new WebSocketException(WebSocketError.UnsupportedVersion, + SR.GetString(SR.net_WebSockets_AcceptUnsupportedWebSocketVersion, + WebSocketHelpers.MethodNames.ValidateWebSocketHeaders, + secWebSocketVersion, + WebSocketProtocolComponent.SupportedVersion)); + } + + if (string.IsNullOrWhiteSpace(context.Request.Headers[HttpKnownHeaderNames.SecWebSocketKey])) + { + throw new WebSocketException(WebSocketError.HeaderError, + SR.GetString(SR.net_WebSockets_AcceptHeaderNotFound, + WebSocketHelpers.MethodNames.ValidateWebSocketHeaders, + HttpKnownHeaderNames.SecWebSocketKey)); + } + } + */ + + internal static void ValidateSubprotocol(string subProtocol) + { + if (string.IsNullOrWhiteSpace(subProtocol)) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_InvalidEmptySubProtocol), "subProtocol"); + } + + char[] chars = subProtocol.ToCharArray(); + string invalidChar = null; + int i = 0; + while (i < chars.Length) + { + char ch = chars[i]; + if (ch < 0x21 || ch > 0x7e) + { + invalidChar = string.Format(CultureInfo.InvariantCulture, "[{0}]", (int)ch); + break; + } + + if (!char.IsLetterOrDigit(ch) && + Separators.IndexOf(ch) >= 0) + { + invalidChar = ch.ToString(); + break; + } + + i++; + } + + if (invalidChar != null) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_InvalidCharInProtocolString, subProtocol, invalidChar), + "subProtocol"); + } + } + + internal static void ValidateCloseStatus(WebSocketCloseStatus closeStatus, string statusDescription) + { + if (closeStatus == WebSocketCloseStatus.Empty && !string.IsNullOrEmpty(statusDescription)) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_ReasonNotNull, + statusDescription, + WebSocketCloseStatus.Empty), + "statusDescription"); + } + + int closeStatusCode = (int)closeStatus; + + if ((closeStatusCode >= InvalidCloseStatusCodesFrom && + closeStatusCode <= InvalidCloseStatusCodesTo) || + closeStatusCode == CloseStatusCodeAbort || + closeStatusCode == CloseStatusCodeFailedTLSHandshake) + { + // CloseStatus 1006 means Aborted - this will never appear on the wire and is reflected by calling WebSocket.Abort + throw new ArgumentException(SR.GetString(SR.net_WebSockets_InvalidCloseStatusCode, + closeStatusCode), + "closeStatus"); + } + + int length = 0; + if (!string.IsNullOrEmpty(statusDescription)) + { + length = UTF8Encoding.UTF8.GetByteCount(statusDescription); + } + + if (length > WebSocketHelpers.MaxControlFramePayloadLength) + { + throw new ArgumentException(SR.GetString(SR.net_WebSockets_InvalidCloseStatusDescription, + statusDescription, + WebSocketHelpers.MaxControlFramePayloadLength), + "statusDescription"); + } + } + + internal static void ValidateOptions(string subProtocol, + int receiveBufferSize, + int sendBufferSize, + TimeSpan keepAliveInterval) + { + // We allow the subProtocol to be null. Validate if it is not null. + if (subProtocol != null) + { + ValidateSubprotocol(subProtocol); + } + + ValidateBufferSizes(receiveBufferSize, sendBufferSize); + + // -1 + if (keepAliveInterval < Timeout.InfiniteTimeSpan) + { + throw new ArgumentOutOfRangeException("keepAliveInterval", keepAliveInterval, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, Timeout.InfiniteTimeSpan.ToString())); + } + } + + internal static void ValidateBufferSizes(int receiveBufferSize, int sendBufferSize) + { + if (receiveBufferSize < WebSocketBuffer.MinReceiveBufferSize) + { + throw new ArgumentOutOfRangeException("receiveBufferSize", receiveBufferSize, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, WebSocketBuffer.MinReceiveBufferSize)); + } + + if (sendBufferSize < WebSocketBuffer.MinSendBufferSize) + { + throw new ArgumentOutOfRangeException("sendBufferSize", sendBufferSize, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooSmall, WebSocketBuffer.MinSendBufferSize)); + } + + if (receiveBufferSize > WebSocketBuffer.MaxBufferSize) + { + throw new ArgumentOutOfRangeException("receiveBufferSize", receiveBufferSize, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooBig, + "receiveBufferSize", + receiveBufferSize, + WebSocketBuffer.MaxBufferSize)); + } + + if (sendBufferSize > WebSocketBuffer.MaxBufferSize) + { + throw new ArgumentOutOfRangeException("sendBufferSize", sendBufferSize, + SR.GetString(SR.net_WebSockets_ArgumentOutOfRange_TooBig, + "sendBufferSize", + sendBufferSize, + WebSocketBuffer.MaxBufferSize)); + } + } + + internal static void ValidateInnerStream(Stream innerStream) + { + if (innerStream == null) + { + throw new ArgumentNullException("innerStream"); + } + + if (!innerStream.CanRead) + { + throw new ArgumentException(SR.GetString(SR.NotReadableStream), "innerStream"); + } + + if (!innerStream.CanWrite) + { + throw new ArgumentException(SR.GetString(SR.NotWriteableStream), "innerStream"); + } + } + + internal static void ThrowIfConnectionAborted(Stream connection, bool read) + { + if ((!read && !connection.CanWrite) || + (read && !connection.CanRead)) + { + throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely); + } + } + + internal static void ThrowPlatformNotSupportedException_WSPC() + { + throw new PlatformNotSupportedException(SR.GetString(SR.net_WebSockets_UnsupportedPlatform)); + } + + internal static void ValidateArraySegment(ArraySegment arraySegment, string parameterName) + { + Contract.Requires(!string.IsNullOrEmpty(parameterName), "'parameterName' MUST NOT be NULL or string.Empty"); + + if (arraySegment.Array == null) + { + throw new ArgumentNullException(parameterName + ".Array"); + } + + if (arraySegment.Offset < 0 || arraySegment.Offset > arraySegment.Array.Length) + { + throw new ArgumentOutOfRangeException(parameterName + ".Offset"); + } + if (arraySegment.Count < 0 || arraySegment.Count > (arraySegment.Array.Length - arraySegment.Offset)) + { + throw new ArgumentOutOfRangeException(parameterName + ".Count"); + } + } + + internal static class MethodNames + { + internal const string AcceptWebSocketAsync = "AcceptWebSocketAsync"; + internal const string ValidateWebSocketHeaders = "ValidateWebSocketHeaders"; + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketMessageType.cs b/src/Microsoft.AspNet.WebSockets/WebSocketMessageType.cs new file mode 100644 index 0000000000..1661721a14 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketMessageType.cs @@ -0,0 +1,15 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.WebSockets +{ + public enum WebSocketMessageType + { + Text = 0x1, + Binary = 0x2, + Close = 0x8, + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketMiddleware.cs b/src/Microsoft.AspNet.WebSockets/WebSocketMiddleware.cs new file mode 100644 index 0000000000..7a77a6a4f6 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketMiddleware.cs @@ -0,0 +1,167 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Owin; + +namespace Microsoft.AspNet.WebSockets +{ + using AppFunc = Func, Task>; + using OpaqueUpgrade = + Action + < + IDictionary, // Parameters + Func // OpaqueFunc callback + < + IDictionary, // Opaque environment + Task // Complete + > + >; + using WebSocketAccept = + Action + < + IDictionary, // WebSocket Accept parameters + Func // WebSocketFunc callback + < + IDictionary, // WebSocket environment + Task // Complete + > + >; + using WebSocketFunc = + Func + < + IDictionary, // WebSocket Environment + Task // Complete + >; + + public class WebSocketMiddleware + { + private AppFunc _next; + + public WebSocketMiddleware(AppFunc next) + { + _next = next; + } + + public Task Invoke(IDictionary environment) + { + IOwinContext context = new OwinContext(environment); + // Detect if an opaque upgrade is available, and if websocket upgrade headers are present. + // If so, add a websocket upgrade. + OpaqueUpgrade upgrade = context.Get("opaque.Upgrade"); + if (upgrade != null) + { + // Headers and values: + // Connection: Upgrade + // Upgrade: WebSocket + // Sec-WebSocket-Version: (WebSocketProtocolComponent.SupportedVersion) + // Sec-WebSocket-Key: (hash, see WebSocketHelpers.GetSecWebSocketAcceptString) + // Sec-WebSocket-Protocol: (optional, list) + IList connectionHeaders = context.Request.Headers.GetCommaSeparatedValues(HttpKnownHeaderNames.Connection); // "Upgrade, KeepAlive" + string upgradeHeader = context.Request.Headers[HttpKnownHeaderNames.Upgrade]; + string versionHeader = context.Request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; + string keyHeader = context.Request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; + + if (connectionHeaders != null && connectionHeaders.Count > 0 + && connectionHeaders.Contains(HttpKnownHeaderNames.Upgrade, StringComparer.OrdinalIgnoreCase) + && string.Equals(upgradeHeader, WebSocketHelpers.WebSocketUpgradeToken, StringComparison.OrdinalIgnoreCase) + && string.Equals(versionHeader, UnsafeNativeMethods.WebSocketProtocolComponent.SupportedVersion, StringComparison.OrdinalIgnoreCase) + && !string.IsNullOrWhiteSpace(keyHeader)) + { + environment["websocket.Accept"] = new WebSocketAccept(new UpgradeHandshake(context, upgrade).AcceptWebSocket); + } + } + + return _next(environment); + } + + private class UpgradeHandshake + { + private IOwinContext _context; + private OpaqueUpgrade _upgrade; + private WebSocketFunc _webSocketFunc; + + private string _subProtocol; + private int _receiveBufferSize = WebSocketHelpers.DefaultReceiveBufferSize; + private TimeSpan _keepAliveInterval = WebSocket.DefaultKeepAliveInterval; + private ArraySegment _internalBuffer; + + internal UpgradeHandshake(IOwinContext context, OpaqueUpgrade upgrade) + { + _context = context; + _upgrade = upgrade; + } + + internal void AcceptWebSocket(IDictionary options, WebSocketFunc webSocketFunc) + { + _webSocketFunc = webSocketFunc; + + // Get options + object temp; + if (options != null && options.TryGetValue("websocket.SubProtocol", out temp)) + { + _subProtocol = temp as string; + } + if (options != null && options.TryGetValue("websocket.ReceiveBufferSize", out temp)) + { + _receiveBufferSize = (int)temp; + } + if (options != null && options.TryGetValue("websocket.KeepAliveInterval", out temp)) + { + _keepAliveInterval = (TimeSpan)temp; + } + if (options != null && options.TryGetValue("websocket.Buffer", out temp)) + { + _internalBuffer = (ArraySegment)temp; + } + else + { + _internalBuffer = WebSocketBuffer.CreateInternalBufferArraySegment(_receiveBufferSize, WebSocketBuffer.MinSendBufferSize, true); + } + + // Set WebSocket upgrade response headers + + string outgoingSecWebSocketProtocolString; + bool shouldSendSecWebSocketProtocolHeader = + WebSocketHelpers.ProcessWebSocketProtocolHeader( + _context.Request.Headers[HttpKnownHeaderNames.SecWebSocketProtocol], + _subProtocol, + out outgoingSecWebSocketProtocolString); + + if (shouldSendSecWebSocketProtocolHeader) + { + _context.Response.Headers[HttpKnownHeaderNames.SecWebSocketProtocol] = outgoingSecWebSocketProtocolString; + } + + string secWebSocketKey = _context.Request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; + string secWebSocketAccept = WebSocketHelpers.GetSecWebSocketAcceptString(secWebSocketKey); + + _context.Response.Headers[HttpKnownHeaderNames.Connection] = HttpKnownHeaderNames.Upgrade; + _context.Response.Headers[HttpKnownHeaderNames.Upgrade] = WebSocketHelpers.WebSocketUpgradeToken; + _context.Response.Headers[HttpKnownHeaderNames.SecWebSocketAccept] = secWebSocketAccept; + + _context.Response.StatusCode = 101; // Switching Protocols; + + _upgrade(options, OpaqueCallback); + } + + internal async Task OpaqueCallback(IDictionary opaqueEnv) + { + // Create WebSocket wrapper around the opaque env + WebSocket webSocket = CreateWebSocket(opaqueEnv); + OwinWebSocketWrapper wrapper = new OwinWebSocketWrapper(webSocket, (CancellationToken)opaqueEnv["opaque.CallCancelled"]); + await _webSocketFunc(wrapper.Environment); + // Close down the WebSocekt, gracefully if possible + await wrapper.CleanupAsync(); + } + + private WebSocket CreateWebSocket(IDictionary opaqueEnv) + { + Stream stream = (Stream)opaqueEnv["opaque.Stream"]; + return new ServerWebSocket(stream, _subProtocol, _receiveBufferSize, _keepAliveInterval, _internalBuffer); + } + } + } +} diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketReceiveResult.cs b/src/Microsoft.AspNet.WebSockets/WebSocketReceiveResult.cs new file mode 100644 index 0000000000..a48be62a2d --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketReceiveResult.cs @@ -0,0 +1,55 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +using System; +using System.Diagnostics.Contracts; + +namespace Microsoft.AspNet.WebSockets +{ + public class WebSocketReceiveResult + { + public WebSocketReceiveResult(int count, WebSocketMessageType messageType, bool endOfMessage) + : this(count, messageType, endOfMessage, null, null) + { + } + + public WebSocketReceiveResult(int count, + WebSocketMessageType messageType, + bool endOfMessage, + WebSocketCloseStatus? closeStatus, + string closeStatusDescription) + { + if (count < 0) + { + throw new ArgumentOutOfRangeException("count"); + } + + this.Count = count; + this.EndOfMessage = endOfMessage; + this.MessageType = messageType; + this.CloseStatus = closeStatus; + this.CloseStatusDescription = closeStatusDescription; + } + + public int Count { get; private set; } + public bool EndOfMessage { get; private set; } + public WebSocketMessageType MessageType { get; private set; } + public WebSocketCloseStatus? CloseStatus { get; private set; } + public string CloseStatusDescription { get; private set; } + + internal WebSocketReceiveResult Copy(int count) + { + Contract.Assert(count >= 0, "'count' MUST NOT be negative."); + Contract.Assert(count <= this.Count, "'count' MUST NOT be bigger than 'this.Count'."); + this.Count -= count; + return new WebSocketReceiveResult(count, + this.MessageType, + this.Count == 0 && this.EndOfMessage, + this.CloseStatus, + this.CloseStatusDescription); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/WebSocketState.cs b/src/Microsoft.AspNet.WebSockets/WebSocketState.cs new file mode 100644 index 0000000000..3c58d00a20 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/WebSocketState.cs @@ -0,0 +1,19 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.AspNet.WebSockets +{ + public enum WebSocketState + { + None = 0, + Connecting = 1, + Open = 2, + CloseSent = 3, // WebSocket close handshake started form local endpoint + CloseReceived = 4, // WebSocket close message received from remote endpoint. Waiting for app to call close + Closed = 5, + Aborted = 6, + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/build.cmd b/src/Microsoft.AspNet.WebSockets/build.cmd new file mode 100644 index 0000000000..110694d575 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/build.cmd @@ -0,0 +1,3 @@ +rem set TARGET_FRAMEWORK=k10 +@call ..\..\packages\ProjectK.0.0.1-pre-30121-096\tools\k build +pause \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs b/src/Microsoft.AspNet.WebSockets/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs new file mode 100644 index 0000000000..861c9e31ab --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/Microsoft/Win32/SafeHandles/SafeHandleZeroOrMinusOneIsInvalid.cs @@ -0,0 +1,31 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== + +#if !NET45 + +namespace Microsoft.Win32.SafeHandles +{ + using System; + using System.Runtime.InteropServices; + using System.Runtime.CompilerServices; + + // Class of safe handle which uses 0 or -1 as an invalid handle. + [System.Security.SecurityCritical] // auto-generated_required + internal abstract class SafeHandleZeroOrMinusOneIsInvalid : SafeHandle + { + protected SafeHandleZeroOrMinusOneIsInvalid(bool ownsHandle) + : base(IntPtr.Zero, ownsHandle) + { + } + + public override bool IsInvalid + { + [System.Security.SecurityCritical] + get { return handle == new IntPtr(0) || handle == new IntPtr(-1); } + } + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/AccessViolationException.cs b/src/Microsoft.AspNet.WebSockets/fx/System/AccessViolationException.cs new file mode 100644 index 0000000000..c32cafb4e1 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/AccessViolationException.cs @@ -0,0 +1,16 @@ +#if !NET45 + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace System +{ + internal class AccessViolationException : SystemException + { + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/ComponentModel/Win32Exception.cs b/src/Microsoft.AspNet.WebSockets/fx/System/ComponentModel/Win32Exception.cs new file mode 100644 index 0000000000..108303ee2c --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/ComponentModel/Win32Exception.cs @@ -0,0 +1,112 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +using System.Runtime.InteropServices; +using System.Text; + +namespace System.ComponentModel +{ + internal class Win32Exception : ExternalException + { + /// + /// Represents the Win32 error code associated with this exception. This + /// field is read-only. + /// + private readonly int nativeErrorCode; + + /// + /// Initializes a new instance of the class with the last Win32 error + /// that occured. + /// + public Win32Exception() + : this(Marshal.GetLastWin32Error()) + { + } + /// + /// Initializes a new instance of the class with the specified error. + /// + public Win32Exception(int error) + : this(error, GetErrorMessage(error)) + { + } + /// + /// Initializes a new instance of the class with the specified error and the + /// specified detailed description. + /// + public Win32Exception(int error, string message) + : base(message) + { + nativeErrorCode = error; + } + + /// + /// Initializes a new instance of the Exception class with a specified error message. + /// FxCop CA1032: Multiple constructors are required to correctly implement a custom exception. + /// + public Win32Exception(string message) + : this(Marshal.GetLastWin32Error(), message) + { + } + + /// + /// Initializes a new instance of the Exception class with a specified error message and a + /// reference to the inner exception that is the cause of this exception. + /// FxCop CA1032: Multiple constructors are required to correctly implement a custom exception. + /// + public Win32Exception(string message, Exception innerException) + : base(message, innerException) + { + nativeErrorCode = Marshal.GetLastWin32Error(); + } + + + /// + /// Represents the Win32 error code associated with this exception. This + /// field is read-only. + /// + public int NativeErrorCode + { + get + { + return nativeErrorCode; + } + } + + private static string GetErrorMessage(int error) + { + //get the system error message... + string errorMsg = ""; + StringBuilder sb = new StringBuilder(256); + int result = SafeNativeMethods.FormatMessage( + SafeNativeMethods.FORMAT_MESSAGE_IGNORE_INSERTS | + SafeNativeMethods.FORMAT_MESSAGE_FROM_SYSTEM | + SafeNativeMethods.FORMAT_MESSAGE_ARGUMENT_ARRAY, + IntPtr.Zero, (uint)error, 0, sb, sb.Capacity + 1, + null); + if (result != 0) + { + int i = sb.Length; + while (i > 0) + { + char ch = sb[i - 1]; + if (ch > 32 && ch != '.') break; + i--; + } + errorMsg = sb.ToString(0, i); + } + else + { + errorMsg = "Unknown error (0x" + Convert.ToString(error, 16) + ")"; + } + + return errorMsg; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/ExternDll.cs b/src/Microsoft.AspNet.WebSockets/fx/System/ExternDll.cs new file mode 100644 index 0000000000..98303d48b8 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/ExternDll.cs @@ -0,0 +1,16 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 + +namespace System +{ + internal static class ExternDll + { + public const string Kernel32 = "kernel32.dll"; + } +} +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/Runtime/InteropServices/ExternalException.cs b/src/Microsoft.AspNet.WebSockets/fx/System/Runtime/InteropServices/ExternalException.cs new file mode 100644 index 0000000000..21d82b1de4 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/Runtime/InteropServices/ExternalException.cs @@ -0,0 +1,98 @@ +// ==++== +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ==--== +/*============================================================================= +** +** Class: ExternalException +** +** +** Purpose: Exception base class for all errors from Interop or Structured +** Exception Handling code. +** +** +=============================================================================*/ + +#if !NET45 + +namespace System.Runtime.InteropServices +{ + using System; + using System.Globalization; + + // Base exception for COM Interop errors &; Structured Exception Handler + // exceptions. + // + internal class ExternalException : Exception + { + public ExternalException() + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message) + : base(message) + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message, Exception inner) + : base(message, inner) + { + SetErrorCode(__HResults.E_FAIL); + } + + public ExternalException(String message, int errorCode) + : base(message) + { + SetErrorCode(errorCode); + } + + private void SetErrorCode(int errorCode) + { + HResult = ErrorCode; + } + + private static class __HResults + { + internal const int E_FAIL = unchecked((int)0x80004005); + } + + public virtual int ErrorCode + { + get + { + return HResult; + } + } + + public override String ToString() + { + String message = Message; + String s; + String _className = GetType().ToString(); + s = _className + " (0x" + HResult.ToString("X8", CultureInfo.InvariantCulture) + ")"; + + if (!(String.IsNullOrEmpty(message))) + { + s = s + ": " + message; + } + + Exception _innerException = InnerException; + + if (_innerException != null) + { + s = s + " ---> " + _innerException.ToString(); + } + + + if (StackTrace != null) + s += Environment.NewLine + StackTrace; + + return s; + } + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/SafeNativeMethods.cs b/src/Microsoft.AspNet.WebSockets/fx/System/SafeNativeMethods.cs new file mode 100644 index 0000000000..969d273fa8 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/SafeNativeMethods.cs @@ -0,0 +1,28 @@ +//------------------------------------------------------------------------------ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//------------------------------------------------------------------------------ + +#if !NET45 +using System.Runtime.InteropServices; +using System.Text; + +namespace System +{ + internal static class SafeNativeMethods + { + public const int + FORMAT_MESSAGE_ALLOCATE_BUFFER = 0x00000100, + FORMAT_MESSAGE_IGNORE_INSERTS = 0x00000200, + FORMAT_MESSAGE_FROM_STRING = 0x00000400, + FORMAT_MESSAGE_FROM_SYSTEM = 0x00001000, + FORMAT_MESSAGE_ARGUMENT_ARRAY = 0x00002000; + + [DllImport(ExternDll.Kernel32, CharSet = System.Runtime.InteropServices.CharSet.Unicode, SetLastError = true, BestFitMapping = true)] + public static unsafe extern int FormatMessage(int dwFlags, IntPtr lpSource_mustBeNull, uint dwMessageId, + int dwLanguageId, StringBuilder lpBuffer, int nSize, IntPtr[] arguments); + + } +} +#endif diff --git a/src/Microsoft.AspNet.WebSockets/fx/System/SystemException.cs b/src/Microsoft.AspNet.WebSockets/fx/System/SystemException.cs new file mode 100644 index 0000000000..4b6c66b4a4 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/fx/System/SystemException.cs @@ -0,0 +1,16 @@ +#if !NET45 + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace System +{ + internal class SystemException : Exception + { + } +} + +#endif \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/packages.config b/src/Microsoft.AspNet.WebSockets/packages.config new file mode 100644 index 0000000000..fb47f58ced --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/src/Microsoft.AspNet.WebSockets/project.json b/src/Microsoft.AspNet.WebSockets/project.json new file mode 100644 index 0000000000..943893dbd2 --- /dev/null +++ b/src/Microsoft.AspNet.WebSockets/project.json @@ -0,0 +1,17 @@ +{ + "version": "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Abstractions" : "0.1-alpha-*", + "Microsoft.AspNet.HttpFeature" : "0.1-alpha-*" + }, + "compilationOptions" : { "allowUnsafe": true }, + "configurations": + { + "net45" : { + "dependencies": { + "Owin": "1.0", + "Microsoft.Owin": "2.1.0" + } + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/DenyAnonymous.cs b/test/Microsoft.AspNet.Security.Windows.Test/DenyAnonymous.cs new file mode 100644 index 0000000000..3bd1cc9f00 --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/DenyAnonymous.cs @@ -0,0 +1,47 @@ +// +// Copyright 2011-2012 Katana contributors +// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Security.Principal; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Security.Windows.Tests +{ + using AppFunc = Func, Task>; + + // This middleware can be placed at the end of a chain of pass-through auth schemes if at least one type of auth is required. + public class DenyAnonymous + { + private readonly AppFunc _nextApp; + + public DenyAnonymous(AppFunc nextApp) + { + _nextApp = nextApp; + } + + public async Task Invoke(IDictionary env) + { + if (env.Get("server.User") == null) + { + env["owin.ResponseStatusCode"] = 401; + return; + } + + await _nextApp(env); + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/DictionaryExtensions.cs b/test/Microsoft.AspNet.Security.Windows.Test/DictionaryExtensions.cs new file mode 100644 index 0000000000..a1a6e3577a --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/DictionaryExtensions.cs @@ -0,0 +1,51 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Linq; +using System.Text; + +namespace System.Collections.Generic +{ + internal static class DictionaryExtensions + { + internal static void Append(this IDictionary dictionary, string key, string value) + { + string[] orriginalValues; + if (dictionary.TryGetValue(key, out orriginalValues)) + { + string[] newValues = new string[orriginalValues.Length + 1]; + orriginalValues.CopyTo(newValues, 0); + newValues[newValues.Length - 1] = value; + dictionary[key] = newValues; + } + else + { + dictionary[key] = new string[] { value }; + } + } + + internal static string Get(this IDictionary dictionary, string key) + { + string[] values; + if (dictionary.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + internal static T Get(this IDictionary dictionary, string key, T fallback = default(T)) + { + object values; + if (dictionary.TryGetValue(key, out values)) + { + return (T)values; + } + return fallback; + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/DigestTests.cs b/test/Microsoft.AspNet.Security.Windows.Test/DigestTests.cs new file mode 100644 index 0000000000..e097a7e65b --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/DigestTests.cs @@ -0,0 +1,244 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Security.Authentication.ExtendedProtection; +using System.Threading.Tasks; +using Microsoft.AspNet.Server.WebListener; +using Xunit; + +namespace Microsoft.AspNet.Security.Windows.Tests +{ + using AppFunc = Func, Task>; + + public class DigestTests + { + private const string Address = "http://localhost:8080/"; + private const string SecureAddress = "https://localhost:9090/"; + private const int DefaultStatusCode = 201; + + [Fact] + public async Task Digest_PartialMatch_PassedThrough() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", "Digestion blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(DefaultStatusCode, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Fact] + public async Task Digest_BadData_400() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", "Digest blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(400, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Fact] + public async Task Digest_AppSets401_401WithChallenge() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + IDictionary emptyEnv = CreateEmptyRequest(); + await windowsAuth.Invoke(emptyEnv); + FireOnSendingHeadersActions(emptyEnv); + + Assert.Equal(401, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(1, responseHeaders.Count); + Assert.NotNull(responseHeaders.Get("www-authenticate")); + Assert.True(responseHeaders.Get("www-authenticate").StartsWith("Digest ")); + } + + [Fact] + public async Task Digest_CbtOptionalButNotPresent_401WithChallenge() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + windowsAuth.ExtendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.WhenSupported); + IDictionary emptyEnv = CreateEmptyRequest(); + emptyEnv["owin.RequestScheme"] = "https"; + await windowsAuth.Invoke(emptyEnv); + FireOnSendingHeadersActions(emptyEnv); + + Assert.Equal(401, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + Assert.Null(responseHeaders.Get("www-authenticate")); + } + + [Fact] + public async Task Digest_CbtRequiredButNotPresent_400() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + windowsAuth.ExtendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Always); + IDictionary emptyEnv = CreateEmptyRequest(); + emptyEnv["owin.RequestScheme"] = "https"; + await windowsAuth.Invoke(emptyEnv); + FireOnSendingHeadersActions(emptyEnv); + + Assert.Equal(401, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + Assert.Null(responseHeaders.Get("www-authenticate")); + } + + [Fact(Skip = "Broken")] + public async Task Digest_ClientAuthenticates_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + + using (CreateServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendAuthRequestAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + + [Fact(Skip = "Broken")] + public async Task Digest_ClientAuthenticatesMultipleTimes_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + + using (CreateServer(windowsAuth.Invoke)) + { + for (int i = 0; i < 10; i++) + { + HttpResponseMessage response = await SendAuthRequestAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + } + + [Fact] + public async Task Digest_AnonmousClient_401() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + + using (CreateServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(401, (int)response.StatusCode); + Assert.True(response.Headers.WwwAuthenticate.ToString().StartsWith("Digest ")); + } + } + + [Fact(Skip = "Broken")] + public async Task Digest_ClientAuthenticatesWithCbt_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Digest; + windowsAuth.ExtendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Always); + + using (CreateSecureServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendAuthRequestAsync(SecureAddress); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + + private IDictionary CreateEmptyRequest(string header = null, string value = null) + { + IDictionary env = new Dictionary(); + var requestHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + env["owin.RequestHeaders"] = requestHeaders; + if (header != null) + { + requestHeaders[header] = new string[] { value }; + } + env["owin.ResponseHeaders"] = new Dictionary(StringComparer.OrdinalIgnoreCase); + + var onSendingHeadersActions = new List, object>>(); + env["server.OnSendingHeaders"] = new Action, object>( + (a, b) => onSendingHeadersActions.Add(new Tuple, object>(a, b))); + + env["test.OnSendingHeadersActions"] = onSendingHeadersActions; + return env; + } + + private void FireOnSendingHeadersActions(IDictionary env) + { + var onSendingHeadersActions = env.Get, object>>>("test.OnSendingHeadersActions"); + foreach (var actionPair in onSendingHeadersActions.Reverse()) + { + actionPair.Item1(actionPair.Item2); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private IDisposable CreateSecureServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "https"; + address["host"] = "localhost"; + address["port"] = "9090"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + + private async Task SendAuthRequestAsync(string uri) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.UseDefaultCredentials = true; + handler.ServerCertificateValidationCallback = (a, b, c, d) => true; + using (HttpClient client = new HttpClient(handler)) + { + return await client.GetAsync(uri); + } + } + + private Task SimpleApp(IDictionary env) + { + env["owin.ResponseStatusCode"] = DefaultStatusCode; + return Task.FromResult(null); + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/NegotiateTests.cs b/test/Microsoft.AspNet.Security.Windows.Test/NegotiateTests.cs new file mode 100644 index 0000000000..572650b9b0 --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/NegotiateTests.cs @@ -0,0 +1,261 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Security.Authentication.ExtendedProtection; +using System.Threading.Tasks; +using Microsoft.AspNet.Server.WebListener; +using Xunit; +using Xunit.Extensions; + +namespace Microsoft.AspNet.Security.Windows.Tests +{ + using AppFunc = Func, Task>; + + public class NegotiateTests + { + private const string Address = "http://localhost:8080/"; + private const string SecureAddress = "https://localhost:9090/"; + private const int DefaultStatusCode = 201; + + [Theory] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_PartialMatch_PassedThrough(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", package + "ion blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(DefaultStatusCode, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Theory] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_BadData_400(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", package + " blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(400, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Theory] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_AppSets401_401WithChallenge(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp401); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + IDictionary emptyEnv = CreateEmptyRequest(); + await windowsAuth.Invoke(emptyEnv); + FireOnSendingHeadersActions(emptyEnv); + + Assert.Equal(401, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(1, responseHeaders.Count); + Assert.NotNull(responseHeaders.Get("www-authenticate")); + Assert.Equal(package, responseHeaders.Get("www-authenticate")); + } + + [Theory(Skip = "Broken")] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_ClientAuthenticates_Success(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + + using (CreateServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendAuthRequestAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + + [Theory(Skip = "Broken")] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_ClientAuthenticatesMultipleTimes_Success(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + + using (CreateServer(windowsAuth.Invoke)) + { + for (int i = 0; i < 10; i++) + { + HttpResponseMessage response = await SendAuthRequestAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + } + + [Theory] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_AnonmousClient_401(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + + using (CreateServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(401, (int)response.StatusCode); + Assert.Equal(package, response.Headers.WwwAuthenticate.ToString()); + } + } + + [Fact(Skip = "Broken")] + public async Task UnsafeSharedNTLM_AuthenticatedClient_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = AuthTypes.Ntlm; + windowsAuth.UnsafeConnectionNtlmAuthentication = true; + + using (CreateServer(windowsAuth.Invoke)) + { + WebRequestHandler handler = new WebRequestHandler(); + CredentialCache cache = new CredentialCache(); + cache.Add(new Uri(Address), "NTLM", CredentialCache.DefaultNetworkCredentials); + handler.Credentials = cache; + handler.UnsafeAuthenticatedConnectionSharing = true; + using (HttpClient client = new HttpClient(handler)) + { + HttpResponseMessage response = await client.GetAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + response.EnsureSuccessStatusCode(); + + // Remove the credentials before try two just to prove they aren't used. + cache.Remove(new Uri(Address), "NTLM"); + response = await client.GetAsync(Address); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + } + + [Theory(Skip = "Broken")] + [InlineData("Negotiate")] + [InlineData("NTLM")] + public async Task Negotiate_ClientAuthenticatesWithCbt_Success(string package) + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(new DenyAnonymous(SimpleApp).Invoke); + windowsAuth.AuthenticationSchemes = (AuthTypes)Enum.Parse(typeof(AuthTypes), package, true); + windowsAuth.ExtendedProtectionPolicy = new ExtendedProtectionPolicy(PolicyEnforcement.Always); + + using (CreateSecureServer(windowsAuth.Invoke)) + { + HttpResponseMessage response = await SendAuthRequestAsync(SecureAddress); + Assert.Equal(DefaultStatusCode, (int)response.StatusCode); + } + } + + private IDictionary CreateEmptyRequest(string header = null, string value = null, string connectionId = "Random") + { + IDictionary env = new Dictionary(); + var requestHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + env["owin.RequestHeaders"] = requestHeaders; + if (header != null) + { + requestHeaders[header] = new string[] { value }; + } + env["owin.ResponseHeaders"] = new Dictionary(StringComparer.OrdinalIgnoreCase); + + var onSendingHeadersActions = new List, object>>(); + env["server.OnSendingHeaders"] = new Action, object>( + (a, b) => onSendingHeadersActions.Add(new Tuple, object>(a, b))); + + env["test.OnSendingHeadersActions"] = onSendingHeadersActions; + env["server.ConnectionId"] = connectionId; + return env; + } + + private void FireOnSendingHeadersActions(IDictionary env) + { + var onSendingHeadersActions = env.Get, object>>>("test.OnSendingHeadersActions"); + foreach (var actionPair in onSendingHeadersActions.Reverse()) + { + actionPair.Item1(actionPair.Item2); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private IDisposable CreateSecureServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "https"; + address["host"] = "localhost"; + address["port"] = "9090"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + + private async Task SendAuthRequestAsync(string uri) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.UseDefaultCredentials = true; + handler.ServerCertificateValidationCallback = (a, b, c, d) => true; + using (HttpClient client = new HttpClient(handler)) + { + return await client.GetAsync(uri); + } + } + + private Task SimpleApp(IDictionary env) + { + env["owin.ResponseStatusCode"] = DefaultStatusCode; + return Task.FromResult(null); + } + + private Task SimpleApp401(IDictionary env) + { + env["owin.ResponseStatusCode"] = 401; + return Task.FromResult(null); + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/PassThroughTests.cs b/test/Microsoft.AspNet.Security.Windows.Test/PassThroughTests.cs new file mode 100644 index 0000000000..6ffcf2b8a8 --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/PassThroughTests.cs @@ -0,0 +1,64 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Security.Windows.Tests +{ + public class PassThroughTests + { + private const int DefaultStatusCode = 201; + + [Fact] + public async Task PassThrough_EmptyRequest_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest(); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(DefaultStatusCode, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + [Fact] + public async Task PassThrough_BasicAuth_Success() + { + WindowsAuthMiddleware windowsAuth = new WindowsAuthMiddleware(SimpleApp); + IDictionary emptyEnv = CreateEmptyRequest("Authorization", "Basic blablabla"); + await windowsAuth.Invoke(emptyEnv); + + Assert.Equal(DefaultStatusCode, emptyEnv.Get("owin.ResponseStatusCode")); + var responseHeaders = emptyEnv.Get>("owin.ResponseHeaders"); + Assert.Equal(0, responseHeaders.Count); + } + + private IDictionary CreateEmptyRequest(string header = null, string value = null) + { + IDictionary env = new Dictionary(); + var requestHeaders = new Dictionary(StringComparer.OrdinalIgnoreCase); + env["owin.RequestHeaders"] = requestHeaders; + if (header != null) + { + requestHeaders[header] = new string[] { value }; + } + env["owin.ResponseHeaders"] = new Dictionary(StringComparer.OrdinalIgnoreCase); + env["server.OnSendingHeaders"] = new Action, object>((a, b) => { }); + return env; + } + + private Task SimpleApp(IDictionary env) + { + env["owin.ResponseStatusCode"] = DefaultStatusCode; + return Task.FromResult(null); + } + } +} diff --git a/test/Microsoft.AspNet.Security.Windows.Test/Properties/AssemblyInfo.cs b/test/Microsoft.AspNet.Security.Windows.Test/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..681f7d66cd --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.Security.Windows.Tests")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.Security.Windows.Tests")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("334c99b0-a718-4cda-9ca0-d5a45c3a32b0")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/test/Microsoft.AspNet.Security.Windows.Test/packages.config b/test/Microsoft.AspNet.Security.Windows.Test/packages.config new file mode 100644 index 0000000000..67a23e70da --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/packages.config @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/test/Microsoft.AspNet.Security.Windows.Test/project.json b/test/Microsoft.AspNet.Security.Windows.Test/project.json new file mode 100644 index 0000000000..ecbdca9ad7 --- /dev/null +++ b/test/Microsoft.AspNet.Security.Windows.Test/project.json @@ -0,0 +1,17 @@ +{ + "version" : "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Server.WebListener" : "", + "Microsoft.AspNet.Security.Windows" : "" + }, + "configurations": { + "net45": { + "dependencies": { + "XUnit": "1.9.2", + "XUnit.Extensions": "1.9.2", + "System.Net.Http": "", + "System.Net.Http.WebRequest": "" + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/AuthenticationTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/AuthenticationTests.cs new file mode 100644 index 0000000000..d85855d66a --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/AuthenticationTests.cs @@ -0,0 +1,136 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Xunit; +using Xunit.Extensions; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class AuthenticationTests + { + private const string Address = "http://localhost:8080/"; + + [Theory] + [InlineData(AuthenticationType.Kerberos)] + [InlineData(AuthenticationType.Negotiate)] + [InlineData(AuthenticationType.Ntlm)] + [InlineData(AuthenticationType.Digest)] + [InlineData(AuthenticationType.Basic)] + [InlineData(AuthenticationType.Kerberos | AuthenticationType.Negotiate | AuthenticationType.Ntlm | AuthenticationType.Digest | AuthenticationType.Basic)] + public async Task AuthTypes_EnabledButNotChalleneged_PassThrough(AuthenticationType authType) + { + using (CreateServer(authType, env => + { + return Task.FromResult(0); + })) + { + var response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + } + } + + [Theory] + [InlineData(AuthenticationType.Kerberos)] + [InlineData(AuthenticationType.Negotiate)] + [InlineData(AuthenticationType.Ntlm)] + // [InlineData(AuthenticationType.Digest)] // TODO: Not implemented + [InlineData(AuthenticationType.Basic)] + public async Task AuthType_Specify401_ChallengesAdded(AuthenticationType authType) + { + using (CreateServer(authType, env => + { + env["owin.ResponseStatusCode"] = 401; + return Task.FromResult(0); + })) + { + var response = await SendRequestAsync(Address); + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.Equal(authType.ToString(), response.Headers.WwwAuthenticate.ToString(), StringComparer.OrdinalIgnoreCase); + } + } + + [Fact] + public async Task MultipleAuthTypes_Specify401_ChallengesAdded() + { + // TODO: Not implemented - Digest + using (CreateServer(AuthenticationType.Kerberos | AuthenticationType.Negotiate | AuthenticationType.Ntlm | /*AuthenticationType.Digest |*/ AuthenticationType.Basic, env => + { + env["owin.ResponseStatusCode"] = 401; + return Task.FromResult(0); + })) + { + var response = await SendRequestAsync(Address); + Assert.Equal(HttpStatusCode.Unauthorized, response.StatusCode); + Assert.Equal("Kerberos, Negotiate, NTLM, basic", response.Headers.WwwAuthenticate.ToString(), StringComparer.OrdinalIgnoreCase); + } + } + + [Theory] + [InlineData(AuthenticationType.Kerberos)] + [InlineData(AuthenticationType.Negotiate)] + [InlineData(AuthenticationType.Ntlm)] + // [InlineData(AuthenticationType.Digest)] // TODO: Not implemented + // [InlineData(AuthenticationType.Basic)] // Doesn't work with default creds + [InlineData(AuthenticationType.Kerberos | AuthenticationType.Negotiate | AuthenticationType.Ntlm | /*AuthenticationType.Digest |*/ AuthenticationType.Basic)] + public async Task AuthTypes_Login_Success(AuthenticationType authType) + { + int requestCount = 0; + using (CreateServer(authType, env => + { + requestCount++; + object obj; + if (env.TryGetValue("server.User", out obj) && obj != null) + { + return Task.FromResult(0); + } + env["owin.ResponseStatusCode"] = 401; + return Task.FromResult(0); + })) + { + var response = await SendRequestAsync(Address, useDefaultCredentials: true); + response.EnsureSuccessStatusCode(); + } + } + + private IDisposable CreateServer(AuthenticationType authType, AppFunc app) + { + IDictionary properties = new Dictionary(); + OwinServerFactory.Initialize(properties); + OwinWebListener listener = (OwinWebListener)properties[typeof(OwinWebListener).FullName]; + listener.AuthenticationManager.AuthenticationTypes = authType; + + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri, bool useDefaultCredentials = false) + { + HttpClientHandler handler = new HttpClientHandler(); + handler.UseDefaultCredentials = useDefaultCredentials; + using (HttpClient client = new HttpClient(handler)) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/DictionaryExtensions.cs b/test/Microsoft.AspNet.Server.WebListener.Test/DictionaryExtensions.cs new file mode 100644 index 0000000000..b999b1fc4f --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/DictionaryExtensions.cs @@ -0,0 +1,31 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +namespace System.Collections.Generic +{ + internal static class DictionaryExtensions + { + internal static string Get(this IDictionary dictionary, string key) + { + string[] values; + if (dictionary.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + internal static T Get(this IDictionary dictionary, string key) + { + object values; + if (dictionary.TryGetValue(key, out values)) + { + return (T)values; + } + return default(T); + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/HttpsTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/HttpsTests.cs new file mode 100644 index 0000000000..98661a8380 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/HttpsTests.cs @@ -0,0 +1,191 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class HttpsTests + { + private const string Address = "https://localhost:9090/"; + + [Fact] + public async Task Https_200OK_Success() + { + using (CreateServer(env => + { + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task Https_SendHelloWorld_Success() + { + using (CreateServer(env => + { + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { body.Length.ToString() }; + return env.Get("owin.ResponseBody").WriteAsync(body, 0, body.Length); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task Https_EchoHelloWorld_Success() + { + using (CreateServer(env => + { + string input = new StreamReader(env.Get("owin.RequestBody")).ReadToEnd(); + Assert.Equal("Hello World", input); + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { body.Length.ToString() }; + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task Https_ClientCertNotSent_ClientCertNotPresent() + { + X509Certificate clientCert = null; + using (CreateServer(env => + { + var loadAsync = env.Get>("ssl.LoadClientCertAsync"); + loadAsync().Wait(); + clientCert = env.Get("ssl.ClientCertificate"); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + Assert.Null(clientCert); + } + } + + [Fact] + public async Task Https_ClientCertRequested_ClientCertPresent() + { + X509Certificate clientCert = null; + using (CreateServer(env => + { + var loadAsync = env.Get>("ssl.LoadClientCertAsync"); + loadAsync().Wait(); + clientCert = env.Get("ssl.ClientCertificate"); + return Task.FromResult(0); + })) + { + X509Certificate2 cert = FindClientCert(); + Assert.NotNull(cert); + string response = await SendRequestAsync(Address, cert); + Assert.Equal(string.Empty, response); + Assert.NotNull(clientCert); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "https"; + address["host"] = "localhost"; + address["port"] = "9090"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri, + X509Certificate cert = null) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.ServerCertificateValidationCallback = (a, b, c, d) => true; + if (cert != null) + { + handler.ClientCertificates.Add(cert); + } + using (HttpClient client = new HttpClient(handler)) + { + return await client.GetStringAsync(uri); + } + } + + private async Task SendRequestAsync(string uri, string upload) + { + WebRequestHandler handler = new WebRequestHandler(); + handler.ServerCertificateValidationCallback = (a, b, c, d) => true; + using (HttpClient client = new HttpClient(handler)) + { + HttpResponseMessage response = await client.PostAsync(uri, new StringContent(upload)); + response.EnsureSuccessStatusCode(); + return await response.Content.ReadAsStringAsync(); + } + } + + private X509Certificate2 FindClientCert() + { + var store = new X509Store(); + store.Open(OpenFlags.ReadOnly); + + foreach (var cert in store.Certificates) + { + bool isClientAuth = false; + bool isSmartCard = false; + foreach (var extension in cert.Extensions) + { + var eku = extension as X509EnhancedKeyUsageExtension; + if (eku != null) + { + foreach (var oid in eku.EnhancedKeyUsages) + { + if (oid.FriendlyName == "Client Authentication") + { + isClientAuth = true; + } + else if (oid.FriendlyName == "Smart Card Logon") + { + isSmartCard = true; + break; + } + } + } + } + + if (isClientAuth && !isSmartCard) + { + return cert; + } + } + return null; + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/OpaqueUpgradeTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/OpaqueUpgradeTests.cs new file mode 100644 index 0000000000..07c8610feb --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/OpaqueUpgradeTests.cs @@ -0,0 +1,324 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Extensions; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + using OpaqueUpgrade = Action, Func, Task>>; + + public class OpaqueUpgradeTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task OpaqueUpgrade_SupportKeys_Present() + { + using (CreateServer(env => + { + try + { + IDictionary capabilities = env.Get>("server.Capabilities"); + Assert.NotNull(capabilities); + + Assert.Equal("1.0", capabilities.Get("opaque.Version")); + + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + Assert.NotNull(opaqueUpgrade); + } + catch (Exception ex) + { + byte[] body = Encoding.UTF8.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.False(response.Headers.TransferEncodingChunked.HasValue, "Chunked"); + Assert.Equal(0, response.Content.Headers.ContentLength); + Assert.Equal(string.Empty, response.Content.ReadAsStringAsync().Result); + } + } + + [Fact] + public async Task OpaqueUpgrade_NullCallback_Throws() + { + using (CreateServer(env => + { + try + { + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + opaqueUpgrade(new Dictionary(), null); + } + catch (Exception ex) + { + byte[] body = Encoding.UTF8.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Contains("callback", response.Content.ReadAsStringAsync().Result); + } + } + + [Fact] + public async Task OpaqueUpgrade_AfterHeadersSent_Throws() + { + bool? upgradeThrew = null; + using (CreateServer(env => + { + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + try + { + opaqueUpgrade(null, _ => Task.FromResult(0)); + upgradeThrew = false; + } + catch (InvalidOperationException) + { + upgradeThrew = true; + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.True(upgradeThrew.Value); + } + } + + [Fact] + public async Task OpaqueUpgrade_GetUpgrade_Success() + { + ManualResetEvent waitHandle = new ManualResetEvent(false); + bool? callbackInvoked = null; + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Upgrade"] = new string[] { "websocket" }; // Win8.1 blocks anything but WebSockets + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + opaqueUpgrade(null, opqEnv => + { + callbackInvoked = true; + waitHandle.Set(); + return Task.FromResult(0); + }); + return Task.FromResult(0); + })) + { + using (Stream stream = await SendOpaqueRequestAsync("GET", Address)) + { + Assert.True(waitHandle.WaitOne(TimeSpan.FromSeconds(1)), "Timed out"); + Assert.True(callbackInvoked.HasValue, "CallbackInvoked not set"); + Assert.True(callbackInvoked.Value, "Callback not invoked"); + } + } + } + + [Theory] + // See HTTP_VERB for known verbs + [InlineData("UNKNOWN", null)] + [InlineData("INVALID", null)] + [InlineData("OPTIONS", null)] + [InlineData("GET", null)] + [InlineData("HEAD", null)] + [InlineData("DELETE", null)] + [InlineData("TRACE", null)] + [InlineData("CONNECT", null)] + [InlineData("TRACK", null)] + [InlineData("MOVE", null)] + [InlineData("COPY", null)] + [InlineData("PROPFIND", null)] + [InlineData("PROPPATCH", null)] + [InlineData("MKCOL", null)] + [InlineData("LOCK", null)] + [InlineData("UNLOCK", null)] + [InlineData("SEARCH", null)] + [InlineData("CUSTOMVERB", null)] + [InlineData("PATCH", null)] + [InlineData("POST", "Content-Length: 0")] + [InlineData("PUT", "Content-Length: 0")] + public async Task OpaqueUpgrade_VariousMethodsUpgradeSendAndReceive_Success(string method, string extraHeader) + { + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Upgrade"] = new string[] { "WebSocket" }; // Win8.1 blocks anything but WebSockets + OpaqueUpgrade opaqueUpgrade = env.Get("opaque.Upgrade"); + opaqueUpgrade(null, async opqEnv => + { + Stream opaqueStream = opqEnv.Get("opaque.Stream"); + + byte[] buffer = new byte[100]; + int read = await opaqueStream.ReadAsync(buffer, 0, buffer.Length); + + await opaqueStream.WriteAsync(buffer, 0, read); + }); + return Task.FromResult(0); + })) + { + using (Stream stream = await SendOpaqueRequestAsync(method, Address, extraHeader)) + { + byte[] data = new byte[100]; + stream.WriteAsync(data, 0, 49).Wait(); + int read = stream.ReadAsync(data, 0, data.Length).Result; + Assert.Equal(49, read); + } + } + } + + [Theory] + // Http.Sys returns a 411 Length Required if PUT or POST does not specify content-length or chunked. + [InlineData("POST", "Content-Length: 10")] + [InlineData("POST", "Transfer-Encoding: chunked")] + [InlineData("PUT", "Content-Length: 10")] + [InlineData("PUT", "Transfer-Encoding: chunked")] + [InlineData("CUSTOMVERB", "Content-Length: 10")] + [InlineData("CUSTOMVERB", "Transfer-Encoding: chunked")] + public void OpaqueUpgrade_InvalidMethodUpgrade_Disconnected(string method, string extraHeader) + { + OpaqueUpgrade opaqueUpgrade = null; + using (CreateServer(env => + { + opaqueUpgrade = env.Get("opaque.Upgrade"); + if (opaqueUpgrade == null) + { + throw new NotImplementedException(); + } + opaqueUpgrade(null, opqEnv => Task.FromResult(0)); + return Task.FromResult(0); + })) + { + Assert.Throws(() => + { + try + { + return SendOpaqueRequestAsync(method, Address, extraHeader).Result; + } + catch (AggregateException ag) + { + throw ag.GetBaseException(); + } + }); + Assert.Null(opaqueUpgrade); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + OwinServerFactory.Initialize(properties); + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + + // Returns a bidirectional opaque stream or throws if the upgrade fails + private async Task SendOpaqueRequestAsync(string method, string address, string extraHeader = null) + { + // Connect with a socket + Uri uri = new Uri(address); + TcpClient client = new TcpClient(); + try + { + await client.ConnectAsync(uri.Host, uri.Port); + NetworkStream stream = client.GetStream(); + + // Send an HTTP GET request + byte[] requestBytes = BuildGetRequest(method, uri, extraHeader); + await stream.WriteAsync(requestBytes, 0, requestBytes.Length); + + // Read the response headers, fail if it's not a 101 + await ParseResponseAsync(stream); + + // Return the opaque network stream + return stream; + } + catch (Exception) + { + client.Close(); + throw; + } + } + + private byte[] BuildGetRequest(string method, Uri uri, string extraHeader) + { + StringBuilder builder = new StringBuilder(); + builder.Append(method); + builder.Append(" "); + builder.Append(uri.PathAndQuery); + builder.Append(" HTTP/1.1"); + builder.AppendLine(); + + builder.Append("Host: "); + builder.Append(uri.Host); + builder.Append(':'); + builder.Append(uri.Port); + builder.AppendLine(); + + if (!string.IsNullOrEmpty(extraHeader)) + { + builder.AppendLine(extraHeader); + } + + builder.AppendLine(); + return Encoding.ASCII.GetBytes(builder.ToString()); + } + + // Read the response headers, fail if it's not a 101 + private async Task ParseResponseAsync(NetworkStream stream) + { + StreamReader reader = new StreamReader(stream); + string statusLine = await reader.ReadLineAsync(); + string[] parts = statusLine.Split(' '); + if (int.Parse(parts[1]) != 101) + { + throw new InvalidOperationException("The response status code was incorrect: " + statusLine); + } + + // Scan to the end of the headers + while (!string.IsNullOrEmpty(reader.ReadLine())) + { + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/Properties/AssemblyInfo.cs b/test/Microsoft.AspNet.Server.WebListener.Test/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..486c5d20cd --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/Properties/AssemblyInfo.cs @@ -0,0 +1,42 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("Microsoft.AspNet.Server.WebListener.Tests")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("Microsoft.AspNet.Server.WebListener.Tests")] +[assembly: AssemblyCopyright("Copyright © 2012")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("a265fcd6-3542-4f59-a1dd-ad423d40ddde")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("0.5")] +[assembly: AssemblyVersion("0.5")] +[assembly: AssemblyFileVersion("0.5.40117.0")] diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/RequestBodyTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/RequestBodyTests.cs new file mode 100644 index 0000000000..3687766ca1 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/RequestBodyTests.cs @@ -0,0 +1,176 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class RequestBodyTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task RequestBody_ReadSync_Success() + { + using (CreateServer(env => + { + byte[] input = new byte[100]; + int read = env.Get("owin.RequestBody").Read(input, 0, input.Length); + + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { read.ToString() }; + env.Get("owin.ResponseBody").Write(input, 0, read); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task RequestBody_ReadAync_Success() + { + using (CreateServer(async env => + { + byte[] input = new byte[100]; + int read = await env.Get("owin.RequestBody").ReadAsync(input, 0, input.Length); + + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { read.ToString() }; + await env.Get("owin.ResponseBody").WriteAsync(input, 0, read); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task RequestBody_ReadBeginEnd_Success() + { + using (CreateServer(env => + { + Stream requestStream = env.Get("owin.RequestBody"); + byte[] input = new byte[100]; + int read = requestStream.EndRead(requestStream.BeginRead(input, 0, input.Length, null, null)); + + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { read.ToString() }; + Stream responseStream = env.Get("owin.ResponseBody"); + responseStream.EndWrite(responseStream.BeginWrite(input, 0, read, null, null)); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task RequestBody_ReadSyncPartialBody_Success() + { + StaggardContent content = new StaggardContent(); + using (CreateServer(env => + { + byte[] input = new byte[10]; + int read = env.Get("owin.RequestBody").Read(input, 0, input.Length); + Assert.Equal(5, read); + content.Block.Release(); + read = env.Get("owin.RequestBody").Read(input, 0, input.Length); + Assert.Equal(5, read); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, content); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task RequestBody_ReadAsyncPartialBody_Success() + { + StaggardContent content = new StaggardContent(); + using (CreateServer(async env => + { + byte[] input = new byte[10]; + int read = await env.Get("owin.RequestBody").ReadAsync(input, 0, input.Length); + Assert.Equal(5, read); + content.Block.Release(); + read = await env.Get("owin.RequestBody").ReadAsync(input, 0, input.Length); + Assert.Equal(5, read); + })) + { + string response = await SendRequestAsync(Address, content); + Assert.Equal(string.Empty, response); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private Task SendRequestAsync(string uri, string upload) + { + return SendRequestAsync(uri, new StringContent(upload)); + } + + private async Task SendRequestAsync(string uri, HttpContent content) + { + using (HttpClient client = new HttpClient()) + { + HttpResponseMessage response = await client.PostAsync(uri, content); + response.EnsureSuccessStatusCode(); + return await response.Content.ReadAsStringAsync(); + } + } + + private class StaggardContent : HttpContent + { + public StaggardContent() + { + Block = new SemaphoreSlim(0, 1); + } + + public SemaphoreSlim Block { get; private set; } + + protected async override Task SerializeToStreamAsync(Stream stream, TransportContext context) + { + await stream.WriteAsync(new byte[5], 0, 5); + await Block.WaitAsync(); + await stream.WriteAsync(new byte[5], 0, 5); + } + + protected override bool TryComputeLength(out long length) + { + length = 10; + return true; + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/RequestHeaderTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/RequestHeaderTests.cs new file mode 100644 index 0000000000..a322ea06d0 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/RequestHeaderTests.cs @@ -0,0 +1,120 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class RequestHeaderTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task RequestHeaders_ClientSendsDefaultHeaders_Success() + { + using (CreateServer(env => + { + var requestHeaders = env.Get>("owin.RequestHeaders"); + // NOTE: The System.Net client only sends the Connection: keep-alive header on the first connection per service-point. + // Assert.Equal(2, requestHeaders.Count); + // Assert.Equal("Keep-Alive", requestHeaders.Get("Connection")); + Assert.Equal("localhost:8080", requestHeaders.Get("Host")); + Assert.Equal(null, requestHeaders.Get("Accept")); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task RequestHeaders_ClientSendsCustomHeaders_Success() + { + using (CreateServer(env => + { + var requestHeaders = env.Get>("owin.RequestHeaders"); + Assert.Equal(4, requestHeaders.Count); + Assert.Equal("localhost:8080", requestHeaders.Get("Host")); + Assert.Equal("close", requestHeaders.Get("Connection")); + Assert.Equal(1, requestHeaders["Custom-Header"].Length); + // Apparently Http.Sys squashes request headers together. + Assert.Equal("custom1, and custom2, custom3", requestHeaders.Get("Custom-Header")); + Assert.Equal(1, requestHeaders["Spacer-Header"].Length); + Assert.Equal("spacervalue, spacervalue", requestHeaders.Get("Spacer-Header")); + return Task.FromResult(0); + })) + { + string[] customValues = new string[] { "custom1, and custom2", "custom3" }; + + await SendRequestAsync("localhost", 8080, "Custom-Header", customValues); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetStringAsync(uri); + } + } + + private async Task SendRequestAsync(string host, int port, string customHeader, string[] customValues) + { + StringBuilder builder = new StringBuilder(); + builder.AppendLine("GET / HTTP/1.1"); + builder.AppendLine("Connection: close"); + builder.Append("HOST: "); + builder.Append(host); + builder.Append(':'); + builder.AppendLine(port.ToString()); + foreach (string value in customValues) + { + builder.Append(customHeader); + builder.Append(": "); + builder.AppendLine(value); + builder.AppendLine("Spacer-Header: spacervalue"); + } + builder.AppendLine(); + + byte[] request = Encoding.ASCII.GetBytes(builder.ToString()); + + Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + socket.Connect(host, port); + + socket.Send(request); + + byte[] response = new byte[1024 * 5]; + await Task.Run(() => socket.Receive(response)); + socket.Close(); + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/RequestTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/RequestTests.cs new file mode 100644 index 0000000000..a99bdafa86 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/RequestTests.cs @@ -0,0 +1,185 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; +using Xunit.Extensions; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class RequestTests + { + private const string Address = "http://localhost:8080"; + + [Fact] + public async Task Request_SimpleGet_Success() + { + using (CreateServer(env => + { + try + { + // General keys + Assert.Equal("1.0", env.Get("owin.Version")); + Assert.True(env.Get("owin.CallCancelled").CanBeCanceled); + + // Request Keys + Assert.Equal("GET", env.Get("owin.RequestMethod")); + Assert.Equal(Stream.Null, env.Get("owin.RequestBody")); + Assert.NotNull(env.Get>("owin.RequestHeaders")); + Assert.Equal("http", env.Get("owin.RequestScheme")); + Assert.Equal("/basepath", env.Get("owin.RequestPathBase")); + Assert.Equal("/SomePath", env.Get("owin.RequestPath")); + Assert.Equal("SomeQuery", env.Get("owin.RequestQueryString")); + Assert.Equal("HTTP/1.1", env.Get("owin.RequestProtocol")); + + // Server Keys + Assert.NotNull(env.Get>("server.Capabilities")); + Assert.Equal("::1", env.Get("server.RemoteIpAddress")); + Assert.NotNull(env.Get("server.RemotePort")); + Assert.Equal("::1", env.Get("server.LocalIpAddress")); + Assert.Equal("8080", env.Get("server.LocalPort")); + Assert.True(env.Get("server.IsLocal")); + + // Note: Response keys are validated in the ResponseTests + } + catch (Exception ex) + { + byte[] body = Encoding.ASCII.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + }, "http", "localhost", "8080", "/basepath")) + { + string response = await SendRequestAsync(Address + "/basepath/SomePath?SomeQuery"); + Assert.Equal(string.Empty, response); + } + } + + [Theory] + [InlineData("http", "localhost", "8080", "/", "http://localhost:8080/", "", "/")] + [InlineData("http", "localhost", "8080", "/basepath/", "http://localhost:8080/basepath", "/basepath", "")] + [InlineData("http", "localhost", "8080", "/basepath/", "http://localhost:8080/basepath/", "/basepath", "/")] + [InlineData("http", "localhost", "8080", "/basepath/", "http://localhost:8080/basepath/subpath", "/basepath", "/subpath")] + [InlineData("http", "localhost", "8080", "/base path/", "http://localhost:8080/base%20path/sub path", "/base path", "/sub path")] + [InlineData("http", "localhost", "8080", "/base葉path/", "http://localhost:8080/base%E8%91%89path/sub%E8%91%89path", "/base葉path", "/sub葉path")] + public async Task Request_PathSplitting(string scheme, string host, string port, string pathBase, string requestUri, + string expectedPathBase, string expectedPath) + { + using (CreateServer(env => + { + try + { + Uri uri = new Uri(requestUri); + string expectedQuery = uri.Query.Length > 0 ? uri.Query.Substring(1) : string.Empty; + // Request Keys + Assert.Equal(scheme, env.Get("owin.RequestScheme")); + Assert.Equal(expectedPath, env.Get("owin.RequestPath")); + Assert.Equal(expectedPathBase, env.Get("owin.RequestPathBase")); + Assert.Equal(expectedQuery, env.Get("owin.RequestQueryString")); + Assert.Equal(port, env.Get("server.LocalPort")); + } + catch (Exception ex) + { + byte[] body = Encoding.ASCII.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + }, scheme, host, port, pathBase)) + { + string response = await SendRequestAsync(requestUri); + Assert.Equal(string.Empty, response); + } + } + + [Theory] + // The test server defines these prefixes: "/", "/11", "/2/3", "/2", "/11/2" + [InlineData("/", "", "/")] + [InlineData("/random", "", "/random")] + [InlineData("/11", "/11", "")] + [InlineData("/11/", "/11", "/")] + [InlineData("/11/random", "/11", "/random")] + [InlineData("/2", "/2", "")] + [InlineData("/2/", "/2", "/")] + [InlineData("/2/random", "/2", "/random")] + [InlineData("/2/3", "/2/3", "")] + [InlineData("/2/3/", "/2/3", "/")] + [InlineData("/2/3/random", "/2/3", "/random")] + public async Task Request_MultiplePrefixes(string requestUri, string expectedPathBase, string expectedPath) + { + using (CreateServer(env => + { + try + { + Assert.Equal(expectedPath, env.Get("owin.RequestPath")); + Assert.Equal(expectedPathBase, env.Get("owin.RequestPathBase")); + } + catch (Exception ex) + { + byte[] body = Encoding.ASCII.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address + requestUri); + Assert.Equal(string.Empty, response); + } + } + + private IDisposable CreateServer(AppFunc app, string scheme, string host, string port, string path) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = scheme; + address["host"] = host; + address["port"] = port; + address["path"] = path; + + return OwinServerFactory.Create(app, properties); + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + foreach (string path in new[] { "/", "/11", "/2/3", "/2", "/11/2" }) + { + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = path; + } + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetStringAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ResponseBodyTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseBodyTests.cs new file mode 100644 index 0000000000..bc95cb305f --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseBodyTests.cs @@ -0,0 +1,213 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class ResponseBodyTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task ResponseBody_WriteNoHeaders_DefaultsToChunked() + { + using (CreateServer(env => + { + env.Get("owin.ResponseBody").Write(new byte[10], 0, 10); + return env.Get("owin.ResponseBody").WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Equal(new byte[20], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_WriteChunked_Chunked() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["transfeR-Encoding"] = new string[] { " CHunked " }; + Stream stream = env.Get("owin.ResponseBody"); + stream.EndWrite(stream.BeginWrite(new byte[10], 0, 10, null, null)); + stream.Write(new byte[10], 0, 10); + return stream.WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Equal(new byte[30], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_WriteContentLength_PassedThrough() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 30 " }; + Stream stream = env.Get("owin.ResponseBody"); + stream.EndWrite(stream.BeginWrite(new byte[10], 0, 10, null, null)); + stream.Write(new byte[10], 0, 10); + return stream.WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal("30", contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(new byte[30], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public async Task ResponseBody_Http10WriteNoHeaders_DefaultsConnectionClose() + { + using (CreateServer(env => + { + env["owin.ResponseProtocol"] = "HTTP/1.0"; + env.Get("owin.ResponseBody").Write(new byte[10], 0, 10); + return env.Get("owin.ResponseBody").WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); // Http.Sys won't transmit 1.0 + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(new byte[20], await response.Content.ReadAsByteArrayAsync()); + } + } + + [Fact] + public void ResponseBody_WriteContentLengthNoneWritten_Throws() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 20 " }; + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + [Fact] + public void ResponseBody_WriteContentLengthNotEnoughWritten_Throws() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 20 " }; + env.Get("owin.ResponseBody").Write(new byte[5], 0, 5); + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + [Fact] + public void ResponseBody_WriteContentLengthTooMuchWritten_Throws() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 10 " }; + env.Get("owin.ResponseBody").Write(new byte[5], 0, 5); + env.Get("owin.ResponseBody").Write(new byte[6], 0, 6); + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + [Fact] + public async Task ResponseBody_WriteContentLengthExtraWritten_Throws() + { + ManualResetEvent waitHandle = new ManualResetEvent(false); + bool? appThrew = null; + using (CreateServer(env => + { + try + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { " 10 " }; + env.Get("owin.ResponseBody").Write(new byte[10], 0, 10); + env.Get("owin.ResponseBody").Write(new byte[9], 0, 9); + appThrew = false; + } + catch (Exception) + { + appThrew = true; + } + waitHandle.Set(); + return Task.FromResult(0); + })) + { + // The full response is received. + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal(new Version(1, 1), response.Version); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal("10", contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(new byte[10], await response.Content.ReadAsByteArrayAsync()); + + Assert.True(waitHandle.WaitOne(100)); + Assert.True(appThrew.HasValue, "appThrew.HasValue"); + Assert.True(appThrew.Value, "appThrew.Value"); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ResponseHeaderTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseHeaderTests.cs new file mode 100644 index 0000000000..56b3bb48f4 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseHeaderTests.cs @@ -0,0 +1,243 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class ResponseHeaderTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task ResponseHeaders_ServerSendsDefaultHeaders_Success() + { + using (CreateServer(env => + { + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(2, response.Headers.Count()); + Assert.False(response.Headers.TransferEncodingChunked.HasValue); + Assert.True(response.Headers.Date.HasValue); + Assert.Equal("Microsoft-HTTPAPI/2.0", response.Headers.Server.ToString()); + Assert.Equal(1, response.Content.Headers.Count()); + Assert.Equal(0, response.Content.Headers.ContentLength); + } + } + + [Fact] + public async Task ResponseHeaders_ServerSendsCustomHeaders_Success() + { + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Custom-Header1"] = new string[] { "custom1, and custom2", "custom3" }; + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(3, response.Headers.Count()); + Assert.False(response.Headers.TransferEncodingChunked.HasValue); + Assert.True(response.Headers.Date.HasValue); + Assert.Equal("Microsoft-HTTPAPI/2.0", response.Headers.Server.ToString()); + Assert.Equal(new string[] { "custom1, and custom2", "custom3" }, response.Headers.GetValues("Custom-Header1")); + Assert.Equal(1, response.Content.Headers.Count()); + Assert.Equal(0, response.Content.Headers.ContentLength); + } + } + + [Fact] + public async Task ResponseHeaders_ServerSendsConnectionClose_Closed() + { + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Connection"] = new string[] { "Close" }; + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + + [Fact] + public async Task ResponseHeaders_SendsHttp10_Gets11Close() + { + using (CreateServer(env => + { + env["owin.ResponseProtocol"] = "HTTP/1.0"; + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(new Version(1, 1), response.Version); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + + [Fact] + public async Task ResponseHeaders_SendsHttp10WithBody_Gets11Close() + { + using (CreateServer(env => + { + env["owin.ResponseProtocol"] = "HTTP/1.0"; + return env.Get("owin.ResponseBody").WriteAsync(new byte[10], 0, 10); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(new Version(1, 1), response.Version); + Assert.False(response.Headers.TransferEncodingChunked.HasValue); + Assert.False(response.Content.Headers.Contains("Content-Length")); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + + [Fact] + public async Task ResponseHeaders_HTTP10Request_Gets11Close() + { + using (CreateServer(env => + { + return Task.FromResult(0); + })) + { + using (HttpClient client = new HttpClient()) + { + HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Get, Address); + request.Version = new Version(1, 0); + HttpResponseMessage response = await client.SendAsync(request); + response.EnsureSuccessStatusCode(); + Assert.Equal(new Version(1, 1), response.Version); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + } + + [Fact] + public async Task ResponseHeaders_HTTP10Request_RemovesChunkedHeader() + { + using (CreateServer(env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Transfer-Encoding"] = new string[] { "chunked" }; + return env.Get("owin.ResponseBody").WriteAsync(new byte[10], 0, 10); + })) + { + using (HttpClient client = new HttpClient()) + { + HttpRequestMessage request = new HttpRequestMessage(HttpMethod.Get, Address); + request.Version = new Version(1, 0); + HttpResponseMessage response = await client.SendAsync(request); + response.EnsureSuccessStatusCode(); + Assert.Equal(new Version(1, 1), response.Version); + Assert.False(response.Headers.TransferEncodingChunked.HasValue); + Assert.False(response.Content.Headers.Contains("Content-Length")); + Assert.True(response.Headers.ConnectionClose.Value); + Assert.Equal(new string[] { "close" }, response.Headers.GetValues("Connection")); + } + } + } + + [Fact] + public async Task Headers_FlushSendsHeaders_Success() + { + using (CreateServer( + env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders.Add("Custom1", new string[] { "value1a", "value1b" }); + responseHeaders.Add("Custom2", new string[] { "value2a, value2b" }); + var body = env.Get("owin.ResponseBody"); + body.Flush(); + env["owin.ResponseStatusCode"] = 404; // Ignored + responseHeaders.Add("Custom3", new string[] { "value3a, value3b", "value3c" }); // Ignored + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(5, response.Headers.Count()); // Date, Server, Chunked + + Assert.Equal(2, response.Headers.GetValues("Custom1").Count()); + Assert.Equal("value1a", response.Headers.GetValues("Custom1").First()); + Assert.Equal("value1b", response.Headers.GetValues("Custom1").Skip(1).First()); + Assert.Equal(1, response.Headers.GetValues("Custom2").Count()); + Assert.Equal("value2a, value2b", response.Headers.GetValues("Custom2").First()); + } + } + + [Fact] + public async Task Headers_FlushAsyncSendsHeaders_Success() + { + using (CreateServer( + async env => + { + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders.Add("Custom1", new string[] { "value1a", "value1b" }); + responseHeaders.Add("Custom2", new string[] { "value2a, value2b" }); + var body = env.Get("owin.ResponseBody"); + await body.FlushAsync(); + env["owin.ResponseStatusCode"] = 404; // Ignored + responseHeaders.Add("Custom3", new string[] { "value3a, value3b", "value3c" }); // Ignored + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + response.EnsureSuccessStatusCode(); + Assert.Equal(5, response.Headers.Count()); // Date, Server, Chunked + + Assert.Equal(2, response.Headers.GetValues("Custom1").Count()); + Assert.Equal("value1a", response.Headers.GetValues("Custom1").First()); + Assert.Equal("value1b", response.Headers.GetValues("Custom1").Skip(1).First()); + Assert.Equal(1, response.Headers.GetValues("Custom2").Count()); + Assert.Equal("value2a, value2b", response.Headers.GetValues("Custom2").First()); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ResponseSendFileTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseSendFileTests.cs new file mode 100644 index 0000000000..d176dc2a1a --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseSendFileTests.cs @@ -0,0 +1,327 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + using SendFileFunc = Func; + + public class ResponseSendFileTests + { + private const string Address = "http://localhost:8080/"; + private static readonly string AbsoluteFilePath = Environment.CurrentDirectory + "\\Microsoft.AspNet.Server.WebListener.dll"; + private static readonly string RelativeFilePath = "Microsoft.AspNet.Server.WebListener.dll"; + private static readonly long FileLength = new FileInfo(AbsoluteFilePath).Length; + + [Fact] + public async Task ResponseSendFile_SupportKeys_Present() + { + using (CreateServer(env => + { + try + { + IDictionary capabilities = env.Get>("server.Capabilities"); + Assert.NotNull(capabilities); + + Assert.Equal("1.0", capabilities.Get("sendfile.Version")); + + IDictionary support = capabilities.Get>("sendfile.Support"); + Assert.NotNull(support); + + Assert.Equal("Overlapped", support.Get("sendfile.Concurrency")); + + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + Assert.NotNull(sendFileAsync); + } + catch (Exception ex) + { + byte[] body = Encoding.UTF8.GetBytes(ex.ToString()); + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable ignored; + Assert.True(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.False(response.Headers.TransferEncodingChunked.HasValue, "Chunked"); + Assert.Equal(0, response.Content.Headers.ContentLength); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task ResponseSendFile_MissingFile_Throws() + { + ManualResetEvent waitHandle = new ManualResetEvent(false); + bool? appThrew = null; + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + try + { + sendFileAsync(string.Empty, 0, null, CancellationToken.None).Wait(); + appThrew = false; + } + catch (Exception) + { + appThrew = true; + throw; + } + finally + { + waitHandle.Set(); + } + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + Assert.True(waitHandle.WaitOne(100)); + Assert.True(appThrew.HasValue, "appThrew.HasValue"); + Assert.True(appThrew.Value, "appThrew.Value"); + } + } + + [Fact] + public async Task ResponseSendFile_NoHeaders_DefaultsToChunked() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Equal(FileLength, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_RelativeFile_Success() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(RelativeFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable ignored; + Assert.False(response.Content.Headers.TryGetValues("content-length", out ignored), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value, "Chunked"); + Assert.Equal(FileLength, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_Chunked_Chunked() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Transfer-EncodinG"] = new string[] { "CHUNKED" }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.False(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value); + Assert.Equal(FileLength, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_MultipleChunks_Chunked() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Transfer-EncodinG"] = new string[] { "CHUNKED" }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None).Wait(); + return sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.False(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value); + Assert.Equal(FileLength * 2, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_ChunkedHalfOfFile_Chunked() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, FileLength / 2, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.False(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value); + Assert.Equal(FileLength / 2, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_ChunkedOffsetOutOfRange_Throws() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 1234567, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(500, (int)response.StatusCode); + } + } + + [Fact] + public async Task ResponseSendFile_ChunkedCountOutOfRange_Throws() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, 1234567, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(500, (int)response.StatusCode); + } + } + + [Fact] + public async Task ResponseSendFile_ChunkedCount0_Chunked() + { + using (CreateServer(env => + { + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, 0, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.False(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.True(response.Headers.TransferEncodingChunked.Value); + Assert.Equal(0, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_ContentLength_PassedThrough() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { FileLength.ToString() }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal(FileLength.ToString(), contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(FileLength, response.Content.ReadAsByteArrayAsync().Result.Length); + } + } + + [Fact] + public async Task ResponseSendFile_ContentLengthSpecific_PassedThrough() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { "10" }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, 10, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal("10", contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(10, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + [Fact] + public async Task ResponseSendFile_ContentLength0_PassedThrough() + { + using (CreateServer(env => + { + env.Get>("owin.ResponseHeaders")["Content-lenGth"] = new string[] { "0" }; + SendFileFunc sendFileAsync = env.Get("sendfile.SendAsync"); + return sendFileAsync(AbsoluteFilePath, 0, 0, CancellationToken.None); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + IEnumerable contentLength; + Assert.True(response.Content.Headers.TryGetValues("content-length", out contentLength), "Content-Length"); + Assert.Equal("0", contentLength.First()); + Assert.Null(response.Headers.TransferEncodingChunked); + Assert.Equal(0, (await response.Content.ReadAsByteArrayAsync()).Length); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IList> addresses = new List>(); + IDictionary properties = new Dictionary(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + OwinServerFactory.Initialize(properties); + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ResponseTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseTests.cs new file mode 100644 index 0000000000..537ad8c6c9 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ResponseTests.cs @@ -0,0 +1,142 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class ResponseTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task Response_ServerSendsDefaultResponse_ServerProvidesStatusCodeAndReasonPhrase() + { + using (CreateServer(env => + { + Assert.Equal(200, env["owin.ResponseStatusCode"]); + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(200, (int)response.StatusCode); + Assert.Equal("OK", response.ReasonPhrase); + Assert.Equal(new Version(1, 1), response.Version); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task Response_ServerSendsSpecificStatus_ServerProvidesReasonPhrase() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 201; + env["owin.ResponseProtocol"] = "HTTP/1.0"; // Http.Sys ignores this value + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(201, (int)response.StatusCode); + Assert.Equal("Created", response.ReasonPhrase); + Assert.Equal(new Version(1, 1), response.Version); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task Response_ServerSendsSpecificStatusAndReasonPhrase_PassedThrough() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 201; + env["owin.ResponseReasonPhrase"] = "CustomReasonPhrase"; + env["owin.ResponseProtocol"] = "HTTP/1.0"; // Http.Sys ignores this value + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(201, (int)response.StatusCode); + Assert.Equal("CustomReasonPhrase", response.ReasonPhrase); + Assert.Equal(new Version(1, 1), response.Version); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task Response_ServerSendsCustomStatus_NoReasonPhrase() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 901; + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(Address); + Assert.Equal(901, (int)response.StatusCode); + Assert.Equal(string.Empty, response.ReasonPhrase); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public void Response_100_Throws() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 100; + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + [Fact] + public void Response_0_Throws() + { + using (CreateServer(env => + { + env["owin.ResponseStatusCode"] = 0; + return Task.FromResult(0); + })) + { + Assert.Throws(() => SendRequestAsync(Address).Result); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + using (HttpClient client = new HttpClient()) + { + return await client.GetAsync(uri); + } + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/ServerTests.cs b/test/Microsoft.AspNet.Server.WebListener.Test/ServerTests.cs new file mode 100644 index 0000000000..84f76d9a52 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/ServerTests.cs @@ -0,0 +1,287 @@ +// ----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// ----------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Net.Http; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNet.Server.WebListener.Tests +{ + using AppFunc = Func, Task>; + + public class ServerTests + { + private const string Address = "http://localhost:8080/"; + + [Fact] + public async Task Server_200OK_Success() + { + using (CreateServer(env => + { + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + } + } + + [Fact] + public async Task Server_SendHelloWorld_Success() + { + using (CreateServer(env => + { + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { body.Length.ToString() }; + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public async Task Server_EchoHelloWorld_Success() + { + using (CreateServer(env => + { + string input = new StreamReader(env.Get("owin.RequestBody")).ReadToEnd(); + Assert.Equal("Hello World", input); + byte[] body = Encoding.UTF8.GetBytes("Hello World"); + var responseHeaders = env.Get>("owin.ResponseHeaders"); + responseHeaders["Content-Length"] = new string[] { body.Length.ToString() }; + env.Get("owin.ResponseBody").Write(body, 0, body.Length); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address, "Hello World"); + Assert.Equal("Hello World", response); + } + } + + [Fact] + public void Server_AppException_ClientReset() + { + using (CreateServer(env => + { + throw new InvalidOperationException(); + })) + { + Task requestTask = SendRequestAsync(Address); + Assert.Throws(() => requestTask.Result); + + // Do it again to make sure the server didn't crash + requestTask = SendRequestAsync(Address); + Assert.Throws(() => requestTask.Result); + } + } + + [Fact] + public void Server_MultipleOutstandingSyncRequests_Success() + { + int requestLimit = 10; + int requestCount = 0; + TaskCompletionSource tcs = new TaskCompletionSource(); + + using (CreateServer(env => + { + if (Interlocked.Increment(ref requestCount) == requestLimit) + { + tcs.TrySetResult(null); + } + else + { + tcs.Task.Wait(); + } + + return Task.FromResult(0); + })) + { + List requestTasks = new List(); + for (int i = 0; i < requestLimit; i++) + { + Task requestTask = SendRequestAsync(Address); + requestTasks.Add(requestTask); + } + + bool success = Task.WaitAll(requestTasks.ToArray(), TimeSpan.FromSeconds(5)); + if (!success) + { + Console.WriteLine(); + } + Assert.True(success, "Timed out"); + } + } + + [Fact] + public void Server_MultipleOutstandingAsyncRequests_Success() + { + int requestLimit = 10; + int requestCount = 0; + TaskCompletionSource tcs = new TaskCompletionSource(); + + using (CreateServer(async env => + { + if (Interlocked.Increment(ref requestCount) == requestLimit) + { + tcs.TrySetResult(null); + } + else + { + await tcs.Task; + } + })) + { + List requestTasks = new List(); + for (int i = 0; i < requestLimit; i++) + { + Task requestTask = SendRequestAsync(Address); + requestTasks.Add(requestTask); + } + Assert.True(Task.WaitAll(requestTasks.ToArray(), TimeSpan.FromSeconds(2)), "Timed out"); + } + } + + [Fact] + public async Task Server_ClientDisconnects_CallCancelled() + { + TimeSpan interval = TimeSpan.FromSeconds(1); + ManualResetEvent received = new ManualResetEvent(false); + ManualResetEvent aborted = new ManualResetEvent(false); + ManualResetEvent canceled = new ManualResetEvent(false); + + using (CreateServer(env => + { + CancellationToken ct = env.Get("owin.CallCancelled"); + Assert.True(ct.CanBeCanceled, "CanBeCanceled"); + Assert.False(ct.IsCancellationRequested, "IsCancellationRequested"); + ct.Register(() => canceled.Set()); + received.Set(); + Assert.True(aborted.WaitOne(interval), "Aborted"); + Assert.True(ct.WaitHandle.WaitOne(interval), "CT Wait"); + Assert.True(ct.IsCancellationRequested, "IsCancellationRequested"); + return Task.FromResult(0); + })) + { + // Note: System.Net.Sockets does not RST the connection by default, it just FINs. + // Http.Sys's disconnect notice requires a RST. + using (Socket socket = await SendHungRequestAsync("GET", Address)) + { + Assert.True(received.WaitOne(interval), "Receive Timeout"); + socket.Close(0); // Force a RST + aborted.Set(); + } + Assert.True(canceled.WaitOne(interval), "canceled"); + } + } + + [Fact] + public async Task Server_SetQueueLimit_Success() + { + using (CreateServer(env => + { + // There's no good way to validate this in code. Just execute it to make sure it doesn't crash. + // Run "netsh http show servicestate" to see the current value + var listener = env.Get("Microsoft.AspNet.Server.WebListener.OwinWebListener"); + listener.SetRequestQueueLimit(1001); + return Task.FromResult(0); + })) + { + string response = await SendRequestAsync(Address); + Assert.Equal(string.Empty, response); + } + } + + private IDisposable CreateServer(AppFunc app) + { + IDictionary properties = new Dictionary(); + IList> addresses = new List>(); + properties["host.Addresses"] = addresses; + + IDictionary address = new Dictionary(); + addresses.Add(address); + + address["scheme"] = "http"; + address["host"] = "localhost"; + address["port"] = "8080"; + address["path"] = string.Empty; + + return OwinServerFactory.Create(app, properties); + } + + private async Task SendRequestAsync(string uri) + { + ServicePointManager.DefaultConnectionLimit = 100; + using (HttpClient client = new HttpClient()) + { + return await client.GetStringAsync(uri); + } + } + + private async Task SendRequestAsync(string uri, string upload) + { + using (HttpClient client = new HttpClient()) + { + HttpResponseMessage response = await client.PostAsync(uri, new StringContent(upload)); + response.EnsureSuccessStatusCode(); + return await response.Content.ReadAsStringAsync(); + } + } + + private async Task SendHungRequestAsync(string method, string address) + { + // Connect with a socket + Uri uri = new Uri(address); + TcpClient client = new TcpClient(); + try + { + await client.ConnectAsync(uri.Host, uri.Port); + NetworkStream stream = client.GetStream(); + + // Send an HTTP GET request + byte[] requestBytes = BuildGetRequest(method, uri); + await stream.WriteAsync(requestBytes, 0, requestBytes.Length); + + // Return the opaque network stream + return client.Client; + } + catch (Exception) + { + client.Close(); + throw; + } + } + + private byte[] BuildGetRequest(string method, Uri uri) + { + StringBuilder builder = new StringBuilder(); + builder.Append(method); + builder.Append(" "); + builder.Append(uri.PathAndQuery); + builder.Append(" HTTP/1.1"); + builder.AppendLine(); + + builder.Append("Host: "); + builder.Append(uri.Host); + builder.Append(':'); + builder.Append(uri.Port); + builder.AppendLine(); + + builder.AppendLine(); + return Encoding.ASCII.GetBytes(builder.ToString()); + } + } +} diff --git a/test/Microsoft.AspNet.Server.WebListener.Test/project.json b/test/Microsoft.AspNet.Server.WebListener.Test/project.json new file mode 100644 index 0000000000..b11449f836 --- /dev/null +++ b/test/Microsoft.AspNet.Server.WebListener.Test/project.json @@ -0,0 +1,16 @@ +{ + "version" : "0.1-alpha-*", + "dependencies": { + "Microsoft.AspNet.Server.WebListener" : "" + }, + "configurations": { + "net45": { + "dependencies": { + "XUnit": "1.9.2", + "XUnit.Extensions": "1.9.2", + "System.Net.Http": "", + "System.Net.Http.WebRequest": "" + } + } + } +}