diff --git a/Razor.sln b/Razor.sln index 52132dc17a..f184c02441 100644 --- a/Razor.sln +++ b/Razor.sln @@ -94,6 +94,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Razor. EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.Razor.Sdk", "src\Microsoft.AspNetCore.Razor.Sdk\Microsoft.AspNetCore.Razor.Sdk.csproj", "{7D9ECCEE-71D1-4A42-ABEE-876AFA1B4FC9}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.Razor.Tools.Test", "test\Microsoft.AspNetCore.Razor.Tools.Test\Microsoft.AspNetCore.Razor.Tools.Test.csproj", "{6EA56B2B-89EC-4C38-A384-97D203375B06}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -388,6 +390,14 @@ Global {7D9ECCEE-71D1-4A42-ABEE-876AFA1B4FC9}.Release|Any CPU.Build.0 = Release|Any CPU {7D9ECCEE-71D1-4A42-ABEE-876AFA1B4FC9}.ReleaseNoVSIX|Any CPU.ActiveCfg = Debug|Any CPU {7D9ECCEE-71D1-4A42-ABEE-876AFA1B4FC9}.ReleaseNoVSIX|Any CPU.Build.0 = Debug|Any CPU + {6EA56B2B-89EC-4C38-A384-97D203375B06}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6EA56B2B-89EC-4C38-A384-97D203375B06}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6EA56B2B-89EC-4C38-A384-97D203375B06}.DebugNoVSIX|Any CPU.ActiveCfg = Debug|Any CPU + {6EA56B2B-89EC-4C38-A384-97D203375B06}.DebugNoVSIX|Any CPU.Build.0 = Debug|Any CPU + {6EA56B2B-89EC-4C38-A384-97D203375B06}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6EA56B2B-89EC-4C38-A384-97D203375B06}.Release|Any CPU.Build.0 = Release|Any CPU + {6EA56B2B-89EC-4C38-A384-97D203375B06}.ReleaseNoVSIX|Any CPU.ActiveCfg = Release|Any CPU + {6EA56B2B-89EC-4C38-A384-97D203375B06}.ReleaseNoVSIX|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -429,6 +439,7 @@ Global {933101DA-C4CC-401A-AA01-2784E1025B7F} = {92463391-81BE-462B-AC3C-78C6C760741F} {3E7F2D49-3B45-45A8-9893-F73EC1EEBAAB} = {3C0D6505-79B3-49D0-B4C3-176F0F1836ED} {7D9ECCEE-71D1-4A42-ABEE-876AFA1B4FC9} = {3C0D6505-79B3-49D0-B4C3-176F0F1836ED} + {6EA56B2B-89EC-4C38-A384-97D203375B06} = {92463391-81BE-462B-AC3C-78C6C760741F} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {0035341D-175A-4D05-95E6-F1C2785A1E26} diff --git a/src/Microsoft.AspNetCore.Razor.Tasks/DotnetToolTask.cs b/src/Microsoft.AspNetCore.Razor.Tasks/DotnetToolTask.cs index 03f55808d1..29f3b2ec62 100644 --- a/src/Microsoft.AspNetCore.Razor.Tasks/DotnetToolTask.cs +++ b/src/Microsoft.AspNetCore.Razor.Tasks/DotnetToolTask.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using System.IO; using System.Linq; using System.Threading; +using Microsoft.AspNetCore.Razor.Tools; using Microsoft.Build.Framework; using Microsoft.Build.Utilities; using Microsoft.CodeAnalysis.CommandLine; @@ -106,37 +107,35 @@ namespace Microsoft.AspNetCore.Razor.Tasks _razorServerCts?.Cancel(); } - protected virtual bool TryExecuteOnServer(string pathToTool, string responseFileCommands, string commandLineCommands, out int result) + protected virtual bool TryExecuteOnServer( + string pathToTool, + string responseFileCommands, + string commandLineCommands, + out int result) { CompilerServerLogger.Log("Server execution started."); using (_razorServerCts = new CancellationTokenSource()) { CompilerServerLogger.Log($"CommandLine = '{commandLineCommands}'"); - CompilerServerLogger.Log($"BuildResponseFile = '{responseFileCommands}'"); + CompilerServerLogger.Log($"ServerResponseFile = '{responseFileCommands}'"); // The server contains the tools for discovering tag helpers and generating Razor code. var clientDir = Path.GetDirectoryName(ToolAssembly); var workingDir = CurrentDirectoryToUse(); - var tempDir = BuildServerConnection.GetTempPath(workingDir); - - var buildPaths = new BuildPathsAlt( + var tempDir = ServerConnection.GetTempPath(workingDir); + var serverPaths = new ServerPaths( clientDir, - // MSBuild doesn't need the .NET SDK directory - sdkDir: null, workingDir: workingDir, tempDir: tempDir); - var responseTask = BuildServerConnection.RunServerCompilation( - GetArguments(responseFileCommands), - buildPaths, - keepAlive: null, - cancellationToken: _razorServerCts.Token); + var arguments = GetArguments(responseFileCommands); + var responseTask = ServerConnection.RunOnServer(arguments, serverPaths, _razorServerCts.Token); responseTask.Wait(_razorServerCts.Token); var response = responseTask.Result; - if (response.Type == BuildResponse.ResponseType.Completed && - response is CompletedBuildResponse completedResponse) + if (response.Type == ServerResponse.ResponseType.Completed && + response is CompletedServerResponse completedResponse) { result = completedResponse.ReturnCode; @@ -159,7 +158,7 @@ namespace Microsoft.AspNetCore.Razor.Tasks { // ToolTask has a method for this. But it may return null. Use the process directory // if ToolTask didn't override. MSBuild uses the process directory. - string workingDirectory = GetWorkingDirectory(); + var workingDirectory = GetWorkingDirectory(); if (string.IsNullOrEmpty(workingDirectory)) { workingDirectory = Directory.GetCurrentDirectory(); diff --git a/src/Microsoft.AspNetCore.Razor.Tasks/Microsoft.AspNetCore.Razor.Tasks.csproj b/src/Microsoft.AspNetCore.Razor.Tasks/Microsoft.AspNetCore.Razor.Tasks.csproj index 70b5027083..ebe30d5d84 100644 --- a/src/Microsoft.AspNetCore.Razor.Tasks/Microsoft.AspNetCore.Razor.Tasks.csproj +++ b/src/Microsoft.AspNetCore.Razor.Tasks/Microsoft.AspNetCore.Razor.Tasks.csproj @@ -16,29 +16,26 @@ - - Shared\BuildServerConnection.cs - - - Shared\NativeMethods.cs - Shared\CompilerServerLogger.cs Shared\PlatformInformation.cs - - Shared\BuildProtocol.cs - Shared\CommandLineUtilities.cs + + Shared\ServerProtocol\%(FileName) + Shared\PipeName.cs Shared\MutexName.cs + + Shared\Client.cs + diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Client.cs b/src/Microsoft.AspNetCore.Razor.Tools/Client.cs index b542d653ca..9a0334b8db 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/Client.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/Client.cs @@ -12,6 +12,23 @@ namespace Microsoft.AspNetCore.Razor.Tools { internal abstract class Client : IDisposable { + private static int counter; + + public abstract Stream Stream { get; } + + public abstract string Identifier { get; } + + public void Dispose() + { + Dispose(disposing: true); + } + + public abstract Task WaitForDisconnectAsync(CancellationToken cancellationToken); + + protected virtual void Dispose(bool disposing) + { + } + // Based on: https://github.com/dotnet/roslyn/blob/14aed138a01c448143b9acf0fe77a662e3dfe2f4/src/Compilers/Shared/BuildServerConnection.cs#L290 public static async Task ConnectAsync(string pipeName, TimeSpan? timeout, CancellationToken cancellationToken) { @@ -49,7 +66,7 @@ namespace Microsoft.AspNetCore.Razor.Tools // We plan to rely on the BCL for this but it's not yet implemented: // See https://github.com/dotnet/corefx/issues/25427 - return new NamedPipeClient(stream); + return new NamedPipeClient(stream, GetNextIdentifier()); } catch (Exception e) when (!(e is TaskCanceledException || e is OperationCanceledException)) { @@ -58,25 +75,55 @@ namespace Microsoft.AspNetCore.Razor.Tools } } - public abstract Stream Stream { get; } - - public void Dispose() + private static string GetNextIdentifier() { - Dispose(disposing: true); + var id = Interlocked.Increment(ref counter); + return "clientconnection-" + id; } - protected virtual void Dispose(bool disposing) - { - } private class NamedPipeClient : Client { - public NamedPipeClient(NamedPipeClientStream stream) + public NamedPipeClient(NamedPipeClientStream stream, string identifier) { Stream = stream; + Identifier = identifier; } public override Stream Stream { get; } + public override string Identifier { get; } + + public async override Task WaitForDisconnectAsync(CancellationToken cancellationToken) + { + if (!(Stream is PipeStream pipeStream)) + { + return; + } + + // We have to poll for disconnection by reading, PipeStream.IsConnected isn't reliable unless you + // actually do a read - which will cause it to update its state. + while (!cancellationToken.IsCancellationRequested && pipeStream.IsConnected) + { + await Task.Delay(TimeSpan.FromMilliseconds(100), cancellationToken); + + try + { + CompilerServerLogger.Log($"Before poking pipe {Identifier}."); + await Stream.ReadAsync(Array.Empty(), 0, 0, cancellationToken); + CompilerServerLogger.Log($"After poking pipe {Identifier}."); + } + catch (OperationCanceledException) + { + } + catch (Exception e) + { + // It is okay for this call to fail. Errors will be reflected in the + // IsConnected property which will be read on the next iteration. + CompilerServerLogger.LogException(e, $"Error poking pipe {Identifier}."); + } + } + } + protected override void Dispose(bool disposing) { if (disposing) diff --git a/src/Microsoft.AspNetCore.Razor.Tools/CompilerHost.cs b/src/Microsoft.AspNetCore.Razor.Tools/CompilerHost.cs index e551b656ea..a00ad04aa4 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/CompilerHost.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/CompilerHost.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Linq; using System.Threading; using Microsoft.CodeAnalysis.CommandLine; -using Microsoft.Extensions.CommandLineUtils; namespace Microsoft.AspNetCore.Razor.Tools { @@ -16,15 +15,15 @@ namespace Microsoft.AspNetCore.Razor.Tools return new DefaultCompilerHost(); } - public abstract BuildResponse Execute(BuildRequest request, CancellationToken cancellationToken); + public abstract ServerResponse Execute(ServerRequest request, CancellationToken cancellationToken); private class DefaultCompilerHost : CompilerHost { - public override BuildResponse Execute(BuildRequest request, CancellationToken cancellationToken) + public override ServerResponse Execute(ServerRequest request, CancellationToken cancellationToken) { if (!TryParseArguments(request, out var parsed)) { - return new RejectedBuildResponse(); + return new RejectedServerResponse(); } var app = new Application(cancellationToken); @@ -33,10 +32,10 @@ namespace Microsoft.AspNetCore.Razor.Tools var exitCode = app.Execute(commandArgs); var output = app.Out.ToString() ?? string.Empty; - return new CompletedBuildResponse(exitCode, utf8output: false, output: output); + return new CompletedServerResponse(exitCode, utf8output: false, output: output); } - private bool TryParseArguments(BuildRequest request, out (string workingDirectory, string tempDirectory, string[] args) parsed) + private bool TryParseArguments(ServerRequest request, out (string workingDirectory, string tempDirectory, string[] args) parsed) { string workingDirectory = null; string tempDirectory = null; @@ -46,15 +45,15 @@ namespace Microsoft.AspNetCore.Razor.Tools for (var i = 0; i < request.Arguments.Count; i++) { var argument = request.Arguments[i]; - if (argument.ArgumentId == BuildProtocolConstants.ArgumentId.CurrentDirectory) + if (argument.Id == RequestArgument.ArgumentId.CurrentDirectory) { workingDirectory = argument.Value; } - else if (argument.ArgumentId == BuildProtocolConstants.ArgumentId.TempDirectory) + else if (argument.Id == RequestArgument.ArgumentId.TempDirectory) { tempDirectory = argument.Value; } - else if (argument.ArgumentId == BuildProtocolConstants.ArgumentId.CommandLineArgument) + else if (argument.Id == RequestArgument.ArgumentId.CommandLineArgument) { args.Add(argument.Value); } diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ConnectionHost.cs b/src/Microsoft.AspNetCore.Razor.Tools/ConnectionHost.cs index e08073f80d..7da0a05497 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/ConnectionHost.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/ConnectionHost.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.IO.Pipes; using System.Threading; using System.Threading.Tasks; @@ -13,23 +14,23 @@ namespace Microsoft.AspNetCore.Razor.Tools // https://github.com/dotnet/roslyn/blob/14aed138a01c448143b9acf0fe77a662e3dfe2f4/src/Compilers/Server/VBCSCompiler/NamedPipeClientConnection.cs#L17 internal abstract class ConnectionHost { - private static int counter; - - private static string GetNextIdentifier() - { - var id = Interlocked.Increment(ref counter); - return "connection-" + id; - } - // Size of the buffers to use: 64K private const int PipeBufferSize = 0x10000; + private static int counter; + + public abstract Task WaitForConnectionAsync(CancellationToken cancellationToken); + public static ConnectionHost Create(string pipeName) { return new NamedPipeConnectionHost(pipeName); } - public abstract Task WaitForConnectionAsync(CancellationToken cancellationToken); + private static string GetNextIdentifier() + { + var id = Interlocked.Increment(ref counter); + return "connection-" + id; + } private class NamedPipeConnectionHost : ConnectionHost { @@ -76,17 +77,20 @@ namespace Microsoft.AspNetCore.Razor.Tools { public NamedPipeConnection(NamedPipeServerStream stream, string identifier) { - base.Stream = stream; + Stream = stream; Identifier = identifier; } - public new NamedPipeServerStream Stream => (NamedPipeServerStream)base.Stream; - public async override Task WaitForDisconnectAsync(CancellationToken cancellationToken) { + if (!(Stream is PipeStream pipeStream)) + { + return; + } + // We have to poll for disconnection by reading, PipeStream.IsConnected isn't reliable unless you // actually do a read - which will cause it to update its state. - while (!cancellationToken.IsCancellationRequested && Stream.IsConnected) + while (!cancellationToken.IsCancellationRequested && pipeStream.IsConnected) { await Task.Delay(TimeSpan.FromMilliseconds(100), cancellationToken); @@ -102,7 +106,7 @@ namespace Microsoft.AspNetCore.Razor.Tools catch (Exception e) { // It is okay for this call to fail. Errors will be reflected in the - // IsConnected property which will be read on the next iteration of the + // IsConnected property which will be read on the next iteration. CompilerServerLogger.LogException(e, $"Error poking pipe {Identifier}."); } } diff --git a/src/Microsoft.AspNetCore.Razor.Tools/DefaultRequestDispatcher.cs b/src/Microsoft.AspNetCore.Razor.Tools/DefaultRequestDispatcher.cs new file mode 100644 index 0000000000..7725a9b3cb --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/DefaultRequestDispatcher.cs @@ -0,0 +1,471 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.CommandLine; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + // Heavily influenced by: + // https://github.com/dotnet/roslyn/blob/14aed138a01c448143b9acf0fe77a662e3dfe2f4/src/Compilers/Server/ServerShared/ServerDispatcher.cs#L15 + internal class DefaultRequestDispatcher : RequestDispatcher + { + private readonly CancellationToken _cancellationToken; + private readonly CompilerHost _compilerHost; + private readonly ConnectionHost _connectionHost; + private readonly EventBus _eventBus; + + private KeepAlive _keepAlive; + private State _state; + private Task _timeoutTask; + private Task _gcTask; + private Task _listenTask; + private CancellationTokenSource _listenCancellationTokenSource; + private List> _connections = new List>(); + + public DefaultRequestDispatcher( + ConnectionHost connectionHost, + CompilerHost compilerHost, + CancellationToken cancellationToken, + EventBus eventBus = null, + TimeSpan? keepAlive = null) + { + _connectionHost = connectionHost; + _compilerHost = compilerHost; + _cancellationToken = cancellationToken; + + _eventBus = eventBus ?? EventBus.Default; + + var keepAliveTimeout = DefaultServerKeepAlive; + if (keepAlive.HasValue) + { + keepAliveTimeout = keepAlive.Value; + } + _keepAlive = new KeepAlive(keepAliveTimeout, isDefault: true); + } + + // The server accepts connections until we reach a state that requires a shutdown. At that + // time no new connections will be accepted and the server will drain existing connections. + // + // The idea is that it's better to let clients fallback to in-proc (and slow down) than it is to keep + // running in an undesired state. + public override void Run() + { + _state = State.Running; + + try + { + Listen(); + + do + { + Debug.Assert(_listenTask != null); + + MaybeCreateTimeoutTask(); + MaybeCreateGCTask(); + WaitForAnyCompletion(_cancellationToken); + CheckCompletedTasks(_cancellationToken); + } + while (_connections.Count > 0 || _state == State.Running); + } + finally + { + _state = State.Completed; + _gcTask = null; + _timeoutTask = null; + + if (_listenTask != null) + { + CloseListenTask(); + } + } + } + + + private void CheckCompletedTasks(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + HandleCancellation(); + return; + } + + if (_listenTask.IsCompleted) + { + HandleCompletedListenTask(cancellationToken); + } + + if (_timeoutTask?.IsCompleted == true) + { + HandleCompletedTimeoutTask(); + } + + if (_gcTask?.IsCompleted == true) + { + HandleCompletedGCTask(); + } + + HandleCompletedConnections(); + } + + private void HandleCancellation() + { + Debug.Assert(_listenTask != null); + + // If cancellation has been requested then the server needs to be in the process + // of shutting down. + _state = State.ShuttingDown; + + CloseListenTask(); + + try + { + Task.WaitAll(_connections.ToArray()); + } + catch + { + // It's expected that some will throw exceptions, in particular OperationCanceledException. It's + // okay for them to throw so long as they complete. + } + + HandleCompletedConnections(); + Debug.Assert(_connections.Count == 0); + } + + /// + /// The server farms out work to Task values and this method needs to wait until at least one of them + /// has completed. + /// + private void WaitForAnyCompletion(CancellationToken cancellationToken) + { + var all = new List(); + all.AddRange(_connections); + all.Add(_timeoutTask); + all.Add(_listenTask); + all.Add(_gcTask); + + try + { + var waitArray = all.Where(x => x != null).ToArray(); + Task.WaitAny(waitArray, cancellationToken); + } + catch (OperationCanceledException) + { + // Thrown when the provided cancellationToken is cancelled. This is handled in the caller, + // here it just serves to break out of the WaitAny call. + } + } + + private void Listen() + { + Debug.Assert(_listenTask == null); + Debug.Assert(_timeoutTask == null); + + _listenCancellationTokenSource = new CancellationTokenSource(); + _listenTask = _connectionHost.WaitForConnectionAsync(_listenCancellationTokenSource.Token); + _eventBus.ConnectionListening(); + } + + private void CloseListenTask() + { + Debug.Assert(_listenTask != null); + + _listenCancellationTokenSource.Cancel(); + _listenCancellationTokenSource = null; + _listenTask = null; + } + + private void HandleCompletedListenTask(CancellationToken cancellationToken) + { + _eventBus.ConnectionReceived(); + + // Don't accept any new connections once we're in shutdown mode, instead gracefully reject the request. + // This should cause the client to run in process. + var accept = _state == State.Running; + var connectionTask = AcceptConnection(_listenTask, accept, cancellationToken); + _connections.Add(connectionTask); + + // Timeout and GC are only done when there are no active connections. Now that we have a new + // connection cancel out these tasks. + _timeoutTask = null; + _gcTask = null; + + // Begin listening again for new connections. + _listenTask = null; + Listen(); + } + + private void HandleCompletedTimeoutTask() + { + _eventBus.KeepAliveReached(); + _listenCancellationTokenSource.Cancel(); + _timeoutTask = null; + _state = State.ShuttingDown; + } + + private void HandleCompletedGCTask() + { + _gcTask = null; + for (var i = 0; i < 10; i++) + { + GC.Collect(); + GC.WaitForPendingFinalizers(); + } + + GCSettings.LargeObjectHeapCompactionMode = GCLargeObjectHeapCompactionMode.CompactOnce; + GC.Collect(); + } + + private void MaybeCreateTimeoutTask() + { + // If there are no active clients running then the server needs to be in a timeout mode. + if (_connections.Count == 0 && _timeoutTask == null) + { + Debug.Assert(_listenTask != null); + _timeoutTask = Task.Delay(_keepAlive.TimeSpan); + } + } + + private void MaybeCreateGCTask() + { + if (_connections.Count == 0 && _gcTask == null) + { + _gcTask = Task.Delay(GCTimeout); + } + } + + /// + /// Checks the completed connection objects. + /// + /// False if the server needs to begin shutting down + private void HandleCompletedConnections() + { + var shutdown = false; + var processedCount = 0; + var i = 0; + while (i < _connections.Count) + { + var current = _connections[i]; + if (!current.IsCompleted) + { + i++; + continue; + } + + _connections.RemoveAt(i); + processedCount++; + + var result = current.Result; + if (result.KeepAlive.HasValue) + { + var updated = _keepAlive.Update(result.KeepAlive.Value); + if (updated.Equals(_keepAlive)) + { + _eventBus.UpdateKeepAlive(updated.TimeSpan); + } + } + + switch (result.CloseReason) + { + case ConnectionResult.Reason.CompilationCompleted: + case ConnectionResult.Reason.CompilationNotStarted: + // These are all normal end states. Nothing to do here. + break; + + case ConnectionResult.Reason.ClientDisconnect: + // Have to assume the worst here which is user pressing Ctrl+C at the command line and + // hence wanting all compilation to end. + _eventBus.ConnectionRudelyEnded(); + shutdown = true; + break; + + case ConnectionResult.Reason.ClientException: + case ConnectionResult.Reason.ClientShutdownRequest: + _eventBus.ConnectionRudelyEnded(); + shutdown = true; + break; + + default: + throw new InvalidOperationException($"Unexpected enum value {result.CloseReason}"); + } + } + + if (processedCount > 0) + { + _eventBus.ConnectionCompleted(processedCount); + } + + if (shutdown) + { + _state = State.ShuttingDown; + } + } + + internal async Task AcceptConnection(Task task, bool accept, CancellationToken cancellationToken) + { + Connection connection; + try + { + connection = await task; + } + catch (Exception ex) + { + // Unable to establish a connection with the client. The client is responsible for + // handling this case. Nothing else for us to do here. + CompilerServerLogger.LogException(ex, "Error creating client named pipe"); + return new ConnectionResult(ConnectionResult.Reason.CompilationNotStarted); + } + + try + { + using (connection) + { + ServerRequest request; + try + { + CompilerServerLogger.Log("Begin reading request."); + request = await ServerRequest.ReadAsync(connection.Stream, cancellationToken).ConfigureAwait(false); + CompilerServerLogger.Log("End reading request."); + } + catch (Exception e) + { + CompilerServerLogger.LogException(e, "Error reading build request."); + return new ConnectionResult(ConnectionResult.Reason.CompilationNotStarted); + } + + if (request.IsShutdownRequest()) + { + // Reply with the PID of this process so that the client can wait for it to exit. + var response = new ShutdownServerResponse(Process.GetCurrentProcess().Id); + await response.WriteAsync(connection.Stream, cancellationToken); + + // We can safely disconnect the client, then when this connection gets cleaned up by the event loop + // the server will go to a shutdown state. + return new ConnectionResult(ConnectionResult.Reason.ClientShutdownRequest); + } + else if (!accept) + { + // We're already in shutdown mode, respond gracefully so the client can run in-process. + var response = new RejectedServerResponse(); + await response.WriteAsync(connection.Stream, cancellationToken).ConfigureAwait(false); + + return new ConnectionResult(ConnectionResult.Reason.CompilationNotStarted); + } + else + { + // If we get here then this is a real request that we will accept and process. + // + // Kick off both the compilation and a task to monitor the pipe for closing. + var buildCancelled = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + var watcher = connection.WaitForDisconnectAsync(buildCancelled.Token); + var worker = ExecuteRequestAsync(request, buildCancelled.Token); + + // await will end when either the work is complete or the connection is closed. + await Task.WhenAny(worker, watcher); + + // Do an 'await' on the completed task, preference being compilation, to force + // any exceptions to be realized in this method for logging. + ConnectionResult.Reason reason; + if (worker.IsCompleted) + { + var response = await worker; + + try + { + CompilerServerLogger.Log("Begin writing response."); + await response.WriteAsync(connection.Stream, cancellationToken); + CompilerServerLogger.Log("End writing response."); + + reason = ConnectionResult.Reason.CompilationCompleted; + } + catch + { + reason = ConnectionResult.Reason.ClientDisconnect; + } + } + else + { + await watcher; + reason = ConnectionResult.Reason.ClientDisconnect; + } + + // Begin the tear down of the Task which didn't complete. + buildCancelled.Cancel(); + + return new ConnectionResult(reason, request.KeepAlive); + } + } + } + catch (Exception ex) + { + CompilerServerLogger.LogException(ex, "Error handling connection"); + return new ConnectionResult(ConnectionResult.Reason.ClientException); + } + } + + private Task ExecuteRequestAsync(ServerRequest buildRequest, CancellationToken cancellationToken) + { + Func func = () => + { + CompilerServerLogger.Log("Begin processing request"); + + var response = _compilerHost.Execute(buildRequest, cancellationToken); + + CompilerServerLogger.Log("End processing request"); + return response; + }; + + var task = new Task(func, cancellationToken, TaskCreationOptions.LongRunning); + task.Start(); + return task; + } + + private enum State + { + /// + /// Server running and accepting all requests + /// + Running, + + /// + /// Server processing existing requests, responding to shutdown commands but is not accepting + /// new build requests. + /// + ShuttingDown, + + /// + /// Server is done. + /// + Completed, + } + + private struct KeepAlive + { + public TimeSpan TimeSpan; + public bool IsDefault; + + public KeepAlive(TimeSpan timeSpan, bool isDefault) + { + TimeSpan = timeSpan; + IsDefault = isDefault; + } + + public KeepAlive Update(TimeSpan timeSpan) + { + if (IsDefault || timeSpan > TimeSpan) + { + return new KeepAlive(timeSpan, isDefault: false); + } + + return this; + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/PipeName.cs b/src/Microsoft.AspNetCore.Razor.Tools/PipeName.cs index 7a575cb6b9..0b897bf05c 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/PipeName.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/PipeName.cs @@ -30,7 +30,7 @@ namespace Microsoft.AspNetCore.Razor.Tools var baseName = ComputeBaseName("Razor:" + AppDomain.CurrentDomain.BaseDirectory); // Prefix with username and elevation - bool isAdmin = false; + var isAdmin = false; if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { #if WINDOWS_HACK_LOL diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Properties/AssemblyInfo.cs b/src/Microsoft.AspNetCore.Razor.Tools/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..96dc897437 --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/Properties/AssemblyInfo.cs @@ -0,0 +1,8 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.Razor.Tools.Test, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] +[assembly: InternalsVisibleTo("DynamicProxyGenAssembly2, PublicKey=0024000004800000940000000602000000240000525341310004000001000100c547cac37abd99c8db225ef2f6c8a3602f3b3606cc9891605d02baa56104f4cfc0734aa39b93bf7852f7d9266654753cc297e7d2edfe0bac1cdcf9f717241550e0a7b191195b7667bb4f64bcb8e2121380fd1d9d46ad2d92d2d15605093924cceaf74c4861eff62abf69b9291ed0a340e113be11e6a7d3113e92484cf7045cc7")] + diff --git a/src/Microsoft.AspNetCore.Razor.Tools/RequestDispatcher.cs b/src/Microsoft.AspNetCore.Razor.Tools/RequestDispatcher.cs index ea488b52d9..1b46d1cbf7 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/RequestDispatcher.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/RequestDispatcher.cs @@ -2,26 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Runtime; using System.Threading; -using System.Threading.Tasks; -using Microsoft.CodeAnalysis.CommandLine; -using Microsoft.CodeAnalysis.CompilerServer; namespace Microsoft.AspNetCore.Razor.Tools { - // Heavily influenced by: - // https://github.com/dotnet/roslyn/blob/14aed138a01c448143b9acf0fe77a662e3dfe2f4/src/Compilers/Server/ServerShared/ServerDispatcher.cs#L15 internal abstract class RequestDispatcher { - public static RequestDispatcher Create(ConnectionHost connectionHost, CompilerHost compilerHost, CancellationToken cancellationToken) - { - return new DefaultRequestDispatcher(connectionHost, compilerHost, cancellationToken); - } - /// /// Default time the server will stay alive after the last request disconnects. /// @@ -35,448 +21,9 @@ namespace Microsoft.AspNetCore.Razor.Tools public abstract void Run(); - private enum State + public static RequestDispatcher Create(ConnectionHost connectionHost, CompilerHost compilerHost, CancellationToken cancellationToken, EventBus eventBus, TimeSpan? keepAlive = null) { - /// - /// Server running and accepting all requests - /// - Running, - - /// - /// Server processing existing requests, responding to shutdown commands but is not accepting - /// new build requests. - /// - ShuttingDown, - - /// - /// Server is done. - /// - Completed, - } - - private class DefaultRequestDispatcher : RequestDispatcher - { - private readonly CancellationToken _cancellationToken; - private readonly CompilerHost _compilerHost; - private readonly ConnectionHost _connectionHost; - private readonly EventBus _eventBus; - - private KeepAlive _keepAlive; - private State _state; - private Task _timeoutTask; - private Task _gcTask; - private Task _listenTask; - private CancellationTokenSource _listenCancellationTokenSource; - private List> _connections = new List>(); - - public DefaultRequestDispatcher(ConnectionHost connectionHost, CompilerHost compilerHost, CancellationToken cancellationToken) - { - _connectionHost = connectionHost; - _compilerHost = compilerHost; - _cancellationToken = cancellationToken; - - _eventBus = EventBus.Default; - _keepAlive = new KeepAlive(DefaultServerKeepAlive, isDefault: true); - } - - // The server accepts connections until we reach a state that requires a shutdown. At that - // time no new connections will be accepted and the server will drain existing connections. - // - // The idea is that it's better to let clients fallback to in-proc (and slow down) than it is to keep - // running in an undesired state. - public override void Run() - { - _state = State.Running; - - try - { - Listen(); - - do - { - Debug.Assert(_listenTask != null); - - MaybeCreateTimeoutTask(); - MaybeCreateGCTask(); - WaitForAnyCompletion(_cancellationToken); - CheckCompletedTasks(_cancellationToken); - } - while (_connections.Count > 0 || _state == State.Running); - } - finally - { - _state = State.Completed; - _gcTask = null; - _timeoutTask = null; - - if (_listenTask != null) - { - CloseListenTask(); - } - } - } - - - private void CheckCompletedTasks(CancellationToken cancellationToken) - { - if (cancellationToken.IsCancellationRequested) - { - HandleCancellation(); - return; - } - - if (_listenTask.IsCompleted) - { - HandleCompletedListenTask(cancellationToken); - } - - if (_timeoutTask?.IsCompleted == true) - { - HandleCompletedTimeoutTask(); - } - - if (_gcTask?.IsCompleted == true) - { - HandleCompletedGCTask(); - } - - HandleCompletedConnections(); - } - - private void HandleCancellation() - { - Debug.Assert(_listenTask != null); - - // If cancellation has been requested then the server needs to be in the process - // of shutting down. - _state = State.ShuttingDown; - - CloseListenTask(); - - try - { - Task.WaitAll(_connections.ToArray()); - } - catch - { - // It's expected that some will throw exceptions, in particular OperationCanceledException. It's - // okay for them to throw so long as they complete. - } - - HandleCompletedConnections(); - Debug.Assert(_connections.Count == 0); - } - - /// - /// The server farms out work to Task values and this method needs to wait until at least one of them - /// has completed. - /// - private void WaitForAnyCompletion(CancellationToken cancellationToken) - { - var all = new List(); - all.AddRange(_connections); - all.Add(_timeoutTask); - all.Add(_listenTask); - all.Add(_gcTask); - - try - { - var waitArray = all.Where(x => x != null).ToArray(); - Task.WaitAny(waitArray, cancellationToken); - } - catch (OperationCanceledException) - { - // Thrown when the provided cancellationToken is cancelled. This is handled in the caller, - // here it just serves to break out of the WaitAny call. - } - } - - private void Listen() - { - Debug.Assert(_listenTask == null); - Debug.Assert(_timeoutTask == null); - - _listenCancellationTokenSource = new CancellationTokenSource(); - _listenTask = _connectionHost.WaitForConnectionAsync(_listenCancellationTokenSource.Token); - _eventBus.ConnectionListening(); - } - - private void CloseListenTask() - { - Debug.Assert(_listenTask != null); - - _listenCancellationTokenSource.Cancel(); - _listenCancellationTokenSource = null; - _listenTask = null; - } - - private void HandleCompletedListenTask(CancellationToken cancellationToken) - { - _eventBus.ConnectionReceived(); - - // Don't accept any new connections once we're in shutdown mode, instead gracefully reject the request. - // This should cause the client to run in process. - var accept = _state == State.Running; - var connectionTask = AcceptConnection(_listenTask, accept, cancellationToken); - _connections.Add(connectionTask); - - // Timeout and GC are only done when there are no active connections. Now that we have a new - // connection cancel out these tasks. - _timeoutTask = null; - _gcTask = null; - - // Begin listening again for new connections. - _listenTask = null; - Listen(); - } - - private void HandleCompletedTimeoutTask() - { - _eventBus.KeepAliveReached(); - _listenCancellationTokenSource.Cancel(); - _timeoutTask = null; - _state = State.ShuttingDown; - } - - private void HandleCompletedGCTask() - { - _gcTask = null; - for (int i = 0; i < 10; i++) - { - GC.Collect(); - GC.WaitForPendingFinalizers(); - } - - GCSettings.LargeObjectHeapCompactionMode = GCLargeObjectHeapCompactionMode.CompactOnce; - GC.Collect(); - } - - private void MaybeCreateTimeoutTask() - { - // If there are no active clients running then the server needs to be in a timeout mode. - if (_connections.Count == 0 && _timeoutTask == null) - { - Debug.Assert(_listenTask != null); - _timeoutTask = Task.Delay(_keepAlive.TimeSpan); - } - } - - private void MaybeCreateGCTask() - { - if (_connections.Count == 0 && _gcTask == null) - { - _gcTask = Task.Delay(GCTimeout); - } - } - - /// - /// Checks the completed connection objects. - /// - /// False if the server needs to begin shutting down - private void HandleCompletedConnections() - { - var shutdown = false; - var processedCount = 0; - var i = 0; - while (i < _connections.Count) - { - var current = _connections[i]; - if (!current.IsCompleted) - { - i++; - continue; - } - - _connections.RemoveAt(i); - processedCount++; - - var result = current.Result; - if (result.KeepAlive.HasValue) - { - var updated = _keepAlive.Update(result.KeepAlive.Value); - if (updated.Equals(_keepAlive)) - { - _eventBus.UpdateKeepAlive(updated.TimeSpan); - } - } - - switch (result.CloseReason) - { - case ConnectionResult.Reason.CompilationCompleted: - case ConnectionResult.Reason.CompilationNotStarted: - // These are all normal end states. Nothing to do here. - break; - - case ConnectionResult.Reason.ClientDisconnect: - // Have to assume the worst here which is user pressing Ctrl+C at the command line and - // hence wanting all compilation to end. - _eventBus.ConnectionRudelyEnded(); - shutdown = true; - break; - - case ConnectionResult.Reason.ClientException: - case ConnectionResult.Reason.ClientShutdownRequest: - _eventBus.ConnectionRudelyEnded(); - shutdown = true; - break; - - default: - throw new InvalidOperationException($"Unexpected enum value {result.CloseReason}"); - } - } - - if (processedCount > 0) - { - _eventBus.ConnectionCompleted(processedCount); - } - - if (shutdown) - { - _state = State.ShuttingDown; - } - } - - internal async Task AcceptConnection(Task task, bool accept, CancellationToken cancellationToken) - { - Connection connection; - try - { - connection = await task; - } - catch (Exception ex) - { - // Unable to establish a connection with the client. The client is responsible for - // handling this case. Nothing else for us to do here. - CompilerServerLogger.LogException(ex, "Error creating client named pipe"); - return new ConnectionResult(ConnectionResult.Reason.CompilationNotStarted); - } - - try - { - using (connection) - { - BuildRequest request; - try - { - CompilerServerLogger.Log("Begin reading request."); - request = await BuildRequest.ReadAsync(connection.Stream, cancellationToken).ConfigureAwait(false); - CompilerServerLogger.Log("End reading request."); - } - catch (Exception e) - { - CompilerServerLogger.LogException(e, "Error reading build request."); - return new ConnectionResult(ConnectionResult.Reason.CompilationNotStarted); - } - - if (request.IsShutdownRequest()) - { - // Reply with the PID of this process so that the client can wait for it to exit. - var response = new ShutdownBuildResponse(Process.GetCurrentProcess().Id); - await response.WriteAsync(connection.Stream, cancellationToken); - - // We can safely disconnect the client, then when this connection gets cleaned up by the event loop - // the server will go to a shutdown state. - return new ConnectionResult(ConnectionResult.Reason.ClientShutdownRequest); - } - else if (!accept) - { - // We're already in shutdown mode, respond gracefully so the client can run in-process. - var response = new RejectedBuildResponse(); - await response.WriteAsync(connection.Stream, cancellationToken).ConfigureAwait(false); - - return new ConnectionResult(ConnectionResult.Reason.CompilationNotStarted); - } - else - { - // If we get here then this is a real request that we will accept and process. - // - // Kick off both the compilation and a task to monitor the pipe for closing. - var buildCancelled = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - - var watcher = connection.WaitForDisconnectAsync(buildCancelled.Token); - var worker = ExecuteRequestAsync(request, buildCancelled.Token); - - // await will end when either the work is complete or the connection is closed. - await Task.WhenAny(worker, watcher); - - // Do an 'await' on the completed task, preference being compilation, to force - // any exceptions to be realized in this method for logging. - ConnectionResult.Reason reason; - if (worker.IsCompleted) - { - var response = await worker; - - try - { - CompilerServerLogger.Log("Begin writing response."); - await response.WriteAsync(connection.Stream, cancellationToken); - CompilerServerLogger.Log("End writing response."); - - reason = ConnectionResult.Reason.CompilationCompleted; - } - catch - { - reason = ConnectionResult.Reason.ClientDisconnect; - } - } - else - { - await watcher; - reason = ConnectionResult.Reason.ClientDisconnect; - } - - // Begin the tear down of the Task which didn't complete. - buildCancelled.Cancel(); - - return new ConnectionResult(reason, request.KeepAlive); - } - } - } - catch (Exception ex) - { - CompilerServerLogger.LogException(ex, "Error handling connection"); - return new ConnectionResult(ConnectionResult.Reason.ClientException); - } - } - - private Task ExecuteRequestAsync(BuildRequest buildRequest, CancellationToken cancellationToken) - { - Func func = () => - { - CompilerServerLogger.Log("Begin processing request"); - - var response = _compilerHost.Execute(buildRequest, cancellationToken); - - CompilerServerLogger.Log("End processing request"); - return response; - }; - - var task = new Task(func, cancellationToken, TaskCreationOptions.LongRunning); - task.Start(); - return task; - } - } - - private struct KeepAlive - { - public TimeSpan TimeSpan; - public bool IsDefault; - - public KeepAlive(TimeSpan timeSpan, bool isDefault) - { - TimeSpan = timeSpan; - IsDefault = isDefault; - } - - public KeepAlive Update(TimeSpan timeSpan) - { - if (IsDefault || timeSpan > TimeSpan) - { - return new KeepAlive(timeSpan, isDefault: false); - } - - return this; - } + return new DefaultRequestDispatcher(connectionHost, compilerHost, cancellationToken, eventBus, keepAlive); } } } diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/BuildProtocol.cs b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/BuildProtocol.cs deleted file mode 100644 index f2745e3e58..0000000000 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/BuildProtocol.cs +++ /dev/null @@ -1,586 +0,0 @@ -// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.IO; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using static Microsoft.CodeAnalysis.CommandLine.BuildProtocolConstants; -using static Microsoft.CodeAnalysis.CommandLine.CompilerServerLogger; - -// This file describes data structures about the protocol from client program to server that is -// used. The basic protocol is this. -// -// After the server pipe is connected, it forks off a thread to handle the connection, and creates -// a new instance of the pipe to listen for new clients. When it gets a request, it validates -// the security and elevation level of the client. If that fails, it disconnects the client. Otherwise, -// it handles the request, sends a response (described by Response class) back to the client, then -// disconnects the pipe and ends the thread. - -namespace Microsoft.CodeAnalysis.CommandLine -{ - /// - /// Represents a request from the client. A request is as follows. - /// - /// Field Name Type Size (bytes) - /// ---------------------------------------------------- - /// Length Integer 4 - /// Argument Count UInteger 4 - /// Arguments Argument[] Variable - /// - /// See for the format of an - /// Argument. - /// - /// - internal class BuildRequest - { - public readonly uint ProtocolVersion; - public readonly ReadOnlyCollection Arguments; - - public BuildRequest(uint protocolVersion, IEnumerable arguments) - { - ProtocolVersion = protocolVersion; - Arguments = new ReadOnlyCollection(arguments.ToList()); - - if (Arguments.Count > ushort.MaxValue) - { - throw new ArgumentOutOfRangeException( - nameof(arguments), - $"Too many arguments: maximum of {ushort.MaxValue} arguments allowed."); - } - } - - public TimeSpan? KeepAlive - { - get - { - TimeSpan? keepAlive = null; - foreach (var argument in Arguments) - { - if (argument.ArgumentId == BuildProtocolConstants.ArgumentId.KeepAlive) - { - // If the value is not a valid integer for any reason,ignore it and continue with the current timeout. - // The client is responsible for validating the argument. - if (int.TryParse(argument.Value, out var result)) - { - // Keep alive times are specified in seconds - keepAlive = TimeSpan.FromSeconds(result); - } - } - } - - return keepAlive; - } - } - - public static BuildRequest Create( - string workingDirectory, - string tempDirectory, - IList args, - string keepAlive = null, - string libDirectory = null) - { - Log("Creating BuildRequest"); - Log($"Working directory: {workingDirectory}"); - Log($"Temp directory: {tempDirectory}"); - Log($"Lib directory: {libDirectory ?? "null"}"); - - var requestLength = args.Count + 1 + (libDirectory == null ? 0 : 1); - var requestArgs = new List(requestLength); - - requestArgs.Add(new Argument(ArgumentId.CurrentDirectory, 0, workingDirectory)); - requestArgs.Add(new Argument(ArgumentId.TempDirectory, 0, tempDirectory)); - - if (keepAlive != null) - { - requestArgs.Add(new Argument(ArgumentId.KeepAlive, 0, keepAlive)); - } - - if (libDirectory != null) - { - requestArgs.Add(new Argument(ArgumentId.LibEnvVariable, 0, libDirectory)); - } - - for (int i = 0; i < args.Count; ++i) - { - var arg = args[i]; - Log($"argument[{i}] = {arg}"); - requestArgs.Add(new Argument(ArgumentId.CommandLineArgument, i, arg)); - } - - return new BuildRequest(BuildProtocolConstants.ProtocolVersion, requestArgs); - } - - public static BuildRequest CreateShutdown() - { - var requestArgs = new[] - { - new Argument(ArgumentId.Shutdown, argumentIndex: 0, value: ""), - new Argument(ArgumentId.CommandLineArgument, argumentIndex: 1, value: "shutdown"), - }; - return new BuildRequest(BuildProtocolConstants.ProtocolVersion, requestArgs); - } - - public bool IsShutdownRequest() - { - return Arguments.Count >= 1 && Arguments[0].ArgumentId == ArgumentId.Shutdown; - } - - /// - /// Read a Request from the given stream. - /// - /// The total request size must be less than 1MB. - /// - /// null if the Request was too large, the Request otherwise. - public static async Task ReadAsync(Stream inStream, CancellationToken cancellationToken) - { - // Read the length of the request - var lengthBuffer = new byte[4]; - Log("Reading length of request"); - await ReadAllAsync(inStream, lengthBuffer, 4, cancellationToken).ConfigureAwait(false); - var length = BitConverter.ToInt32(lengthBuffer, 0); - - // Back out if the request is > 1MB - if (length > 0x100000) - { - Log("Request is over 1MB in length, cancelling read."); - return null; - } - - cancellationToken.ThrowIfCancellationRequested(); - - // Read the full request - var requestBuffer = new byte[length]; - await ReadAllAsync(inStream, requestBuffer, length, cancellationToken).ConfigureAwait(false); - - cancellationToken.ThrowIfCancellationRequested(); - - Log("Parsing request"); - // Parse the request into the Request data structure. - using (var reader = new BinaryReader(new MemoryStream(requestBuffer), Encoding.Unicode)) - { - var protocolVersion = reader.ReadUInt32(); - uint argumentCount = reader.ReadUInt32(); - - var argumentsBuilder = new List((int)argumentCount); - - for (int i = 0; i < argumentCount; i++) - { - cancellationToken.ThrowIfCancellationRequested(); - argumentsBuilder.Add(BuildRequest.Argument.ReadFromBinaryReader(reader)); - } - - return new BuildRequest(protocolVersion, argumentsBuilder); - } - } - - /// - /// Write a Request to the stream. - /// - public async Task WriteAsync(Stream outStream, CancellationToken cancellationToken = default(CancellationToken)) - { - using (var memoryStream = new MemoryStream()) - using (var writer = new BinaryWriter(memoryStream, Encoding.Unicode)) - { - // Format the request. - Log("Formatting request"); - writer.Write(ProtocolVersion); - writer.Write(Arguments.Count); - foreach (Argument arg in Arguments) - { - cancellationToken.ThrowIfCancellationRequested(); - arg.WriteToBinaryWriter(writer); - } - writer.Flush(); - - cancellationToken.ThrowIfCancellationRequested(); - - // Write the length of the request - int length = checked((int)memoryStream.Length); - - // Back out if the request is > 1 MB - if (memoryStream.Length > 0x100000) - { - Log("Request is over 1MB in length, cancelling write"); - throw new ArgumentOutOfRangeException(); - } - - // Send the request to the server - Log("Writing length of request."); - await outStream.WriteAsync(BitConverter.GetBytes(length), 0, 4, - cancellationToken).ConfigureAwait(false); - - Log("Writing request of size {0}", length); - // Write the request - memoryStream.Position = 0; - await memoryStream.CopyToAsync(outStream, bufferSize: length, cancellationToken: cancellationToken).ConfigureAwait(false); - } - } - - /// - /// A command line argument to the compilation. - /// An argument is formatted as follows: - /// - /// Field Name Type Size (bytes) - /// -------------------------------------------------- - /// ID UInteger 4 - /// Index UInteger 4 - /// Value String Variable - /// - /// Strings are encoded via a length prefix as a signed - /// 32-bit integer, followed by an array of characters. - /// - public struct Argument - { - public readonly ArgumentId ArgumentId; - public readonly int ArgumentIndex; - public readonly string Value; - - public Argument(ArgumentId argumentId, - int argumentIndex, - string value) - { - ArgumentId = argumentId; - ArgumentIndex = argumentIndex; - Value = value; - } - - public static Argument ReadFromBinaryReader(BinaryReader reader) - { - var argId = (ArgumentId)reader.ReadInt32(); - var argIndex = reader.ReadInt32(); - string value = ReadLengthPrefixedString(reader); - return new Argument(argId, argIndex, value); - } - - public void WriteToBinaryWriter(BinaryWriter writer) - { - writer.Write((int)ArgumentId); - writer.Write(ArgumentIndex); - WriteLengthPrefixedString(writer, Value); - } - } - } - - /// - /// Base class for all possible responses to a request. - /// The ResponseType enum should list all possible response types - /// and ReadResponse creates the appropriate response subclass based - /// on the response type sent by the client. - /// The format of a response is: - /// - /// Field Name Field Type Size (bytes) - /// ------------------------------------------------- - /// responseLength int (positive) 4 - /// responseType enum ResponseType 4 - /// responseBody Response subclass variable - /// - internal abstract class BuildResponse - { - public enum ResponseType - { - // The client and server are using incompatible protocol versions. - MismatchedVersion, - - // The build request completed on the server and the results are contained - // in the message. - Completed, - - // The build request could not be run on the server due because it created - // an unresolvable inconsistency with analyzers. - AnalyzerInconsistency, - - // The shutdown request completed and the server process information is - // contained in the message. - Shutdown, - - // The request was rejected by the server. - Rejected, - } - - public abstract ResponseType Type { get; } - - public async Task WriteAsync(Stream outStream, - CancellationToken cancellationToken) - { - using (var memoryStream = new MemoryStream()) - using (var writer = new BinaryWriter(memoryStream, Encoding.Unicode)) - { - // Format the response - Log("Formatting Response"); - writer.Write((int)Type); - - AddResponseBody(writer); - writer.Flush(); - - cancellationToken.ThrowIfCancellationRequested(); - - // Send the response to the client - - // Write the length of the response - int length = checked((int)memoryStream.Length); - - Log("Writing response length"); - // There is no way to know the number of bytes written to - // the pipe stream. We just have to assume all of them are written. - await outStream.WriteAsync(BitConverter.GetBytes(length), - 0, - 4, - cancellationToken).ConfigureAwait(false); - - // Write the response - Log("Writing response of size {0}", length); - memoryStream.Position = 0; - await memoryStream.CopyToAsync(outStream, bufferSize: length, cancellationToken: cancellationToken).ConfigureAwait(false); - } - } - - protected abstract void AddResponseBody(BinaryWriter writer); - - /// - /// May throw exceptions if there are pipe problems. - /// - /// - /// - /// - public static async Task ReadAsync(Stream stream, CancellationToken cancellationToken = default(CancellationToken)) - { - Log("Reading response length"); - // Read the response length - var lengthBuffer = new byte[4]; - await ReadAllAsync(stream, lengthBuffer, 4, cancellationToken).ConfigureAwait(false); - var length = BitConverter.ToUInt32(lengthBuffer, 0); - - // Read the response - Log("Reading response of length {0}", length); - var responseBuffer = new byte[length]; - await ReadAllAsync(stream, - responseBuffer, - responseBuffer.Length, - cancellationToken).ConfigureAwait(false); - - using (var reader = new BinaryReader(new MemoryStream(responseBuffer), Encoding.Unicode)) - { - var responseType = (ResponseType)reader.ReadInt32(); - - switch (responseType) - { - case ResponseType.Completed: - return CompletedBuildResponse.Create(reader); - case ResponseType.MismatchedVersion: - return new MismatchedVersionBuildResponse(); - case ResponseType.AnalyzerInconsistency: - return new AnalyzerInconsistencyBuildResponse(); - case ResponseType.Shutdown: - return ShutdownBuildResponse.Create(reader); - case ResponseType.Rejected: - return new RejectedBuildResponse(); - default: - throw new InvalidOperationException("Received invalid response type from server."); - } - } - } - } - - /// - /// Represents a Response from the server. A response is as follows. - /// - /// Field Name Type Size (bytes) - /// -------------------------------------------------- - /// Length UInteger 4 - /// ReturnCode Integer 4 - /// Output String Variable - /// ErrorOutput String Variable - /// - /// Strings are encoded via a character count prefix as a - /// 32-bit integer, followed by an array of characters. - /// - /// - internal sealed class CompletedBuildResponse : BuildResponse - { - public readonly int ReturnCode; - public readonly bool Utf8Output; - public readonly string Output; - public readonly string ErrorOutput; - - public CompletedBuildResponse(int returnCode, - bool utf8output, - string output) - { - ReturnCode = returnCode; - Utf8Output = utf8output; - Output = output; - - // This field existed to support writing to Console.Error. The compiler doesn't ever write to - // this field or Console.Error. This field is only kept around in order to maintain the existing - // protocol semantics. - ErrorOutput = string.Empty; - } - - public override ResponseType Type => ResponseType.Completed; - - public static CompletedBuildResponse Create(BinaryReader reader) - { - var returnCode = reader.ReadInt32(); - var utf8Output = reader.ReadBoolean(); - var output = ReadLengthPrefixedString(reader); - var errorOutput = ReadLengthPrefixedString(reader); - if (!string.IsNullOrEmpty(errorOutput)) - { - throw new InvalidOperationException(); - } - - return new CompletedBuildResponse(returnCode, utf8Output, output); - } - - protected override void AddResponseBody(BinaryWriter writer) - { - writer.Write(ReturnCode); - writer.Write(Utf8Output); - WriteLengthPrefixedString(writer, Output); - WriteLengthPrefixedString(writer, ErrorOutput); - } - } - - internal sealed class ShutdownBuildResponse : BuildResponse - { - public readonly int ServerProcessId; - - public ShutdownBuildResponse(int serverProcessId) - { - ServerProcessId = serverProcessId; - } - - public override ResponseType Type => ResponseType.Shutdown; - - protected override void AddResponseBody(BinaryWriter writer) - { - writer.Write(ServerProcessId); - } - - public static ShutdownBuildResponse Create(BinaryReader reader) - { - var serverProcessId = reader.ReadInt32(); - return new ShutdownBuildResponse(serverProcessId); - } - } - - internal sealed class MismatchedVersionBuildResponse : BuildResponse - { - public override ResponseType Type => ResponseType.MismatchedVersion; - - /// - /// MismatchedVersion has no body. - /// - protected override void AddResponseBody(BinaryWriter writer) { } - } - - internal sealed class AnalyzerInconsistencyBuildResponse : BuildResponse - { - public override ResponseType Type => ResponseType.AnalyzerInconsistency; - - /// - /// AnalyzerInconsistency has no body. - /// - /// - protected override void AddResponseBody(BinaryWriter writer) { } - } - - internal sealed class RejectedBuildResponse : BuildResponse - { - public override ResponseType Type => ResponseType.Rejected; - - /// - /// AnalyzerInconsistency has no body. - /// - /// - protected override void AddResponseBody(BinaryWriter writer) { } - } - - /// - /// Constants about the protocol. - /// - internal static class BuildProtocolConstants - { - /// - /// The version number for this protocol. - /// - public const uint ProtocolVersion = 2; - - // Arguments for CSharp and VB Compiler - public enum ArgumentId - { - // The current directory of the client - CurrentDirectory = 0x51147221, - - // A comment line argument. The argument index indicates which one (0 .. N) - CommandLineArgument, - - // The "LIB" environment variable of the client - LibEnvVariable, - - // Request a longer keep alive time for the server - KeepAlive, - - // Request a server shutdown from the client - Shutdown, - - // The directory to use for temporary operations. - TempDirectory, - } - - /// - /// Read a string from the Reader where the string is encoded - /// as a length prefix (signed 32-bit integer) followed by - /// a sequence of characters. - /// - public static string ReadLengthPrefixedString(BinaryReader reader) - { - var length = reader.ReadInt32(); - return new String(reader.ReadChars(length)); - } - - /// - /// Write a string to the Writer where the string is encoded - /// as a length prefix (signed 32-bit integer) follows by - /// a sequence of characters. - /// - public static void WriteLengthPrefixedString(BinaryWriter writer, string value) - { - writer.Write(value.Length); - writer.Write(value.ToCharArray()); - } - - /// - /// This task does not complete until we are completely done reading. - /// - internal static async Task ReadAllAsync( - Stream stream, - byte[] buffer, - int count, - CancellationToken cancellationToken) - { - int totalBytesRead = 0; - do - { - Log("Attempting to read {0} bytes from the stream", - count - totalBytesRead); - int bytesRead = await stream.ReadAsync(buffer, - totalBytesRead, - count - totalBytesRead, - cancellationToken).ConfigureAwait(false); - if (bytesRead == 0) - { - Log("Unexpected -- read 0 bytes from the stream."); - throw new EndOfStreamException("Reached end of stream before end of read."); - } - Log("Read {0} bytes", bytesRead); - totalBytesRead += bytesRead; - } while (totalBytesRead < count); - Log("Finished read"); - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/BuildServerConnection.cs b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/BuildServerConnection.cs deleted file mode 100644 index 2a431d27e8..0000000000 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/BuildServerConnection.cs +++ /dev/null @@ -1,520 +0,0 @@ -// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.ComponentModel; -using System.Diagnostics; -using System.IO; -using System.IO.Pipes; -using System.Reflection; -using System.Runtime.InteropServices; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Razor.Tools; -using Microsoft.Win32.SafeHandles; -using Roslyn.Utilities; -using static Microsoft.CodeAnalysis.CommandLine.CompilerServerLogger; -using static Microsoft.CodeAnalysis.CommandLine.NativeMethods; - -namespace Microsoft.CodeAnalysis.CommandLine -{ - internal struct BuildPathsAlt - { - /// - /// The path which contains the compiler binaries and response files. - /// - internal string ClientDirectory { get; } - - /// - /// The path in which the compilation takes place. - /// - internal string WorkingDirectory { get; } - - /// - /// The path which contains mscorlib. This can be null when specified by the user or running in a - /// CoreClr environment. - /// - internal string SdkDirectory { get; } - - /// - /// The temporary directory a compilation should use instead of . The latter - /// relies on global state individual compilations should ignore. - /// - internal string TempDirectory { get; } - - internal BuildPathsAlt(string clientDir, string workingDir, string sdkDir, string tempDir) - { - ClientDirectory = clientDir; - WorkingDirectory = workingDir; - SdkDirectory = sdkDir; - TempDirectory = tempDir; - } - } - - internal sealed class BuildServerConnection - { - internal const string ServerNameCoreClr = "rzc.dll"; - - // Spend up to 1s connecting to existing process (existing processes should be always responsive). - internal const int TimeOutMsExistingProcess = 1000; - - // Spend up to 20s connecting to a new process, to allow time for it to start. - internal const int TimeOutMsNewProcess = 20000; - - public static Task RunServerCompilation( - List arguments, - BuildPathsAlt buildPaths, - string keepAlive, - CancellationToken cancellationToken) - { - var pipeName = PipeName.ComputeDefault(); - - return RunServerCompilationCore( - arguments, - buildPaths, - pipeName: pipeName, - keepAlive: keepAlive, - libEnvVariable: null, - timeoutOverride: null, - tryCreateServerFunc: TryCreateServerCore, - cancellationToken: cancellationToken); - } - - internal static async Task RunServerCompilationCore( - List arguments, - BuildPathsAlt buildPaths, - string pipeName, - string keepAlive, - string libEnvVariable, - int? timeoutOverride, - Func tryCreateServerFunc, - CancellationToken cancellationToken) - { - if (pipeName == null) - { - return new RejectedBuildResponse(); - } - - if (buildPaths.TempDirectory == null) - { - return new RejectedBuildResponse(); - } - - var clientDir = buildPaths.ClientDirectory; - var timeoutNewProcess = timeoutOverride ?? TimeOutMsNewProcess; - var timeoutExistingProcess = timeoutOverride ?? TimeOutMsExistingProcess; - var clientMutexName = MutexName.GetClientMutexName(pipeName); - Task pipeTask = null; - using (var clientMutex = new Mutex(initiallyOwned: true, - name: clientMutexName, - createdNew: out var holdsMutex)) - { - try - { - if (!holdsMutex) - { - try - { - holdsMutex = clientMutex.WaitOne(timeoutNewProcess); - - if (!holdsMutex) - { - return new RejectedBuildResponse(); - } - } - catch (AbandonedMutexException) - { - holdsMutex = true; - } - } - - // Check for an already running server - var serverMutexName = MutexName.GetServerMutexName(pipeName); - bool wasServerRunning = WasServerMutexOpen(serverMutexName); - var timeout = wasServerRunning ? timeoutExistingProcess : timeoutNewProcess; - - if (wasServerRunning || tryCreateServerFunc(clientDir, pipeName)) - { - pipeTask = TryConnectToServerAsync(pipeName, timeout, cancellationToken); - } - } - finally - { - if (holdsMutex) - { - clientMutex.ReleaseMutex(); - } - } - } - - if (pipeTask != null) - { - var pipe = await pipeTask.ConfigureAwait(false); - if (pipe != null) - { - var request = BuildRequest.Create( - buildPaths.WorkingDirectory, - buildPaths.TempDirectory, - arguments, - keepAlive, - libEnvVariable); - - return await TryCompile(pipe, request, cancellationToken).ConfigureAwait(false); - } - } - - return new RejectedBuildResponse(); - } - - internal static bool WasServerMutexOpen(string mutexName) - { - Mutex mutex; - var open = Mutex.TryOpenExisting(mutexName, out mutex); - if (open) - { - mutex.Dispose(); - return true; - } - return false; - } - - /// - /// Try to compile using the server. Returns a null-containing Task if a response - /// from the server cannot be retrieved. - /// - private static async Task TryCompile(NamedPipeClientStream pipeStream, - BuildRequest request, - CancellationToken cancellationToken) - { - BuildResponse response; - using (pipeStream) - { - // Write the request - try - { - Log("Begin writing request"); - await request.WriteAsync(pipeStream, cancellationToken).ConfigureAwait(false); - Log("End writing request"); - } - catch (Exception e) - { - LogException(e, "Error writing build request."); - return new RejectedBuildResponse(); - } - - // Wait for the compilation and a monitor to detect if the server disconnects - var serverCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - - Log("Begin reading response"); - - var responseTask = BuildResponse.ReadAsync(pipeStream, serverCts.Token); - var monitorTask = CreateMonitorDisconnectTask(pipeStream, "client", serverCts.Token); - await Task.WhenAny(responseTask, monitorTask).ConfigureAwait(false); - - Log("End reading response"); - - if (responseTask.IsCompleted) - { - // await the task to log any exceptions - try - { - response = await responseTask.ConfigureAwait(false); - } - catch (Exception e) - { - LogException(e, "Error reading response"); - response = new RejectedBuildResponse(); - } - } - else - { - Log("Server disconnect"); - response = new RejectedBuildResponse(); - } - - // Cancel whatever task is still around - serverCts.Cancel(); - Debug.Assert(response != null); - return response; - } - } - - /// - /// The IsConnected property on named pipes does not detect when the client has disconnected - /// if we don't attempt any new I/O after the client disconnects. We start an async I/O here - /// which serves to check the pipe for disconnection. - /// - internal static async Task CreateMonitorDisconnectTask( - PipeStream pipeStream, - string identifier = null, - CancellationToken cancellationToken = default(CancellationToken)) - { - var buffer = Array.Empty(); - - while (!cancellationToken.IsCancellationRequested && pipeStream.IsConnected) - { - // Wait a tenth of a second before trying again - await Task.Delay(100, cancellationToken).ConfigureAwait(false); - - try - { - Log($"Before poking pipe {identifier}."); - await pipeStream.ReadAsync(buffer, 0, 0, cancellationToken).ConfigureAwait(false); - Log($"After poking pipe {identifier}."); - } - catch (OperationCanceledException) - { - } - catch (Exception e) - { - // It is okay for this call to fail. Errors will be reflected in the - // IsConnected property which will be read on the next iteration of the - LogException(e, $"Error poking pipe {identifier}."); - } - } - } - - /// - /// Connect to the pipe for a given directory and return it. - /// Throws on cancellation. - /// - /// Name of the named pipe to connect to. - /// Timeout to allow in connecting to process. - /// Cancellation token to cancel connection to server. - /// - /// An open to the server process or null on failure. - /// - internal static async Task TryConnectToServerAsync( - string pipeName, - int timeoutMs, - CancellationToken cancellationToken) - { - NamedPipeClientStream pipeStream; - try - { - // Machine-local named pipes are named "\\.\pipe\". - // We use the SHA1 of the directory the compiler exes live in as the pipe name. - // The NamedPipeClientStream class handles the "\\.\pipe\" part for us. - Log("Attempt to open named pipe '{0}'", pipeName); - - pipeStream = new NamedPipeClientStream(".", pipeName, PipeDirection.InOut, PipeOptions.Asynchronous); - cancellationToken.ThrowIfCancellationRequested(); - - Log("Attempt to connect named pipe '{0}'", pipeName); - try - { - await pipeStream.ConnectAsync(timeoutMs, cancellationToken).ConfigureAwait(false); - } - catch (Exception e) when (e is IOException || e is TimeoutException) - { - // Note: IOException can also indicate timeout. From docs: - // TimeoutException: Could not connect to the server within the - // specified timeout period. - // IOException: The server is connected to another client and the - // time-out period has expired. - - Log($"Connecting to server timed out after {timeoutMs} ms"); - return null; - } - Log("Named pipe '{0}' connected", pipeName); - - cancellationToken.ThrowIfCancellationRequested(); - - // Verify that we own the pipe. - if (!CheckPipeConnectionOwnership(pipeStream)) - { - Log("Owner of named pipe is incorrect"); - return null; - } - - return pipeStream; - } - catch (Exception e) when (!(e is TaskCanceledException || e is OperationCanceledException)) - { - LogException(e, "Exception while connecting to process"); - return null; - } - } - - internal static bool TryCreateServerCore(string clientDir, string pipeName) - { - string expectedPath; - string processArguments; - - // The server should be in the same directory as the client - var expectedCompilerPath = Path.Combine(clientDir, ServerNameCoreClr); - expectedPath = Environment.GetEnvironmentVariable("DOTNET_HOST_PATH") ?? "dotnet"; - processArguments = $@"""{expectedCompilerPath}"" server -p {pipeName}"; - - if (!File.Exists(expectedCompilerPath)) - { - return false; - } - - if (PlatformInformation.IsWindows) - { - // As far as I can tell, there isn't a way to use the Process class to - // create a process with no stdin/stdout/stderr, so we use P/Invoke. - // This code was taken from MSBuild task starting code. - - STARTUPINFO startInfo = new STARTUPINFO(); - startInfo.cb = Marshal.SizeOf(startInfo); - startInfo.hStdError = InvalidIntPtr; - startInfo.hStdInput = InvalidIntPtr; - startInfo.hStdOutput = InvalidIntPtr; - startInfo.dwFlags = STARTF_USESTDHANDLES; - uint dwCreationFlags = NORMAL_PRIORITY_CLASS | CREATE_NO_WINDOW; - - PROCESS_INFORMATION processInfo; - - Log("Attempting to create process '{0}'", expectedPath); - - var builder = new StringBuilder($@"""{expectedPath}"" {processArguments}"); - - bool success = CreateProcess( - lpApplicationName: null, - lpCommandLine: builder, - lpProcessAttributes: NullPtr, - lpThreadAttributes: NullPtr, - bInheritHandles: false, - dwCreationFlags: dwCreationFlags, - lpEnvironment: NullPtr, // Inherit environment - lpCurrentDirectory: clientDir, - lpStartupInfo: ref startInfo, - lpProcessInformation: out processInfo); - - if (success) - { - Log("Successfully created process with process id {0}", processInfo.dwProcessId); - CloseHandle(processInfo.hProcess); - CloseHandle(processInfo.hThread); - } - else - { - Log("Failed to create process. GetLastError={0}", Marshal.GetLastWin32Error()); - } - return success; - } - else - { - try - { - var startInfo = new ProcessStartInfo() - { - FileName = expectedPath, - Arguments = processArguments, - UseShellExecute = false, - WorkingDirectory = clientDir, - RedirectStandardInput = true, - RedirectStandardOutput = true, - RedirectStandardError = true, - CreateNoWindow = true - }; - - Process.Start(startInfo); - return true; - } - catch - { - return false; - } - } - } - - /// - /// Check to ensure that the named pipe server we connected to is owned by the same - /// user. - /// - /// - /// The type is embedded in assemblies that need to run cross platform. While this particular - /// code will never be hit when running on non-Windows platforms it does need to work when - /// on Windows. To facilitate that we use reflection to make the check here to enable it to - /// compile into our cross plat assemblies. - /// - private static bool CheckPipeConnectionOwnership(NamedPipeClientStream pipeStream) - { - return true; - } - -#if NETSTANDARD1_3 - internal static bool CheckIdentityUnix(PipeStream stream) - { - // Identity verification is unavailable in the MSBuild task, - // but verification is not needed client-side so that's okay. - // (unavailable due to lack of internal reflection capabilities in netstandard1.3) - return true; - } -#else - [DllImport("System.Native", EntryPoint = "SystemNative_GetEUid")] - private static extern uint GetEUid(); - - [DllImport("System.Native", EntryPoint = "SystemNative_GetPeerID", SetLastError = true)] - private static extern int GetPeerID(SafeHandle socket, out uint euid); - - internal static bool CheckIdentityUnix(PipeStream stream) - { - var flags = BindingFlags.Instance | BindingFlags.NonPublic; - var handle = (SafePipeHandle)typeof(PipeStream).GetField("_handle", flags).GetValue(stream); - var handle2 = (SafeHandle)typeof(SafePipeHandle).GetField("_namedPipeSocketHandle", flags).GetValue(handle); - - uint myID = GetEUid(); - - if (GetPeerID(handle, out uint peerID) == -1) - { - throw new Win32Exception(Marshal.GetLastWin32Error()); - } - - return myID == peerID; - } -#endif - - /// - /// Gets the value of the temporary path for the current environment assuming the working directory - /// is . This function must emulate as - /// closely as possible. - /// - public static string GetTempPath(string workingDir) - { - if (PlatformInformation.IsUnix) - { - // Unix temp path is fine: it does not use the working directory - // (it uses ${TMPDIR} if set, otherwise, it returns /tmp) - return Path.GetTempPath(); - } - - var tmp = Environment.GetEnvironmentVariable("TMP"); - if (Path.IsPathRooted(tmp)) - { - return tmp; - } - - var temp = Environment.GetEnvironmentVariable("TEMP"); - if (Path.IsPathRooted(temp)) - { - return temp; - } - - if (!string.IsNullOrEmpty(workingDir)) - { - if (!string.IsNullOrEmpty(tmp)) - { - return Path.Combine(workingDir, tmp); - } - - if (!string.IsNullOrEmpty(temp)) - { - return Path.Combine(workingDir, temp); - } - } - - var userProfile = Environment.GetEnvironmentVariable("USERPROFILE"); - if (Path.IsPathRooted(userProfile)) - { - return userProfile; - } - - return Environment.GetEnvironmentVariable("SYSTEMROOT"); - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CommandLineUtilities.cs b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CommandLineUtilities.cs index caaa38d52f..441322aaab 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CommandLineUtilities.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CommandLineUtilities.cs @@ -48,8 +48,7 @@ namespace Roslyn.Utilities /// public static IEnumerable SplitCommandLineIntoArguments(string commandLine, bool removeHashComments) { - char? unused; - return SplitCommandLineIntoArguments(commandLine, removeHashComments, out unused); + return SplitCommandLineIntoArguments(commandLine, removeHashComments, out var unused); } public static IEnumerable SplitCommandLineIntoArguments(string commandLine, bool removeHashComments, out char? illegalChar) diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CommonCompiler.cs b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CommonCompiler.cs deleted file mode 100644 index ccfbddc826..0000000000 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CommonCompiler.cs +++ /dev/null @@ -1,12 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.CodeAnalysis -{ - class CommonCompiler - { - internal const int Failed = 1; - internal const int Succeeded = 0; - } -} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CompilerRequestHandler.cs b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CompilerRequestHandler.cs deleted file mode 100644 index 855f5e5bb2..0000000000 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CompilerRequestHandler.cs +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Diagnostics; -using System.Globalization; -using System.IO; -using System.Runtime.InteropServices; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.CodeAnalysis.CommandLine; - -using static Microsoft.CodeAnalysis.CommandLine.CompilerServerLogger; - -namespace Microsoft.CodeAnalysis.CompilerServer -{ - internal struct RunRequest - { - public string Language { get; } - public string CurrentDirectory { get; } - public string TempDirectory { get; } - public string LibDirectory { get; } - public string[] Arguments { get; } - - public RunRequest(string language, string currentDirectory, string tempDirectory, string libDirectory, string[] arguments) - { - Language = language; - CurrentDirectory = currentDirectory; - TempDirectory = tempDirectory; - LibDirectory = libDirectory; - Arguments = arguments; - } - } - - internal abstract class CompilerServerHost : ICompilerServerHost - { - public abstract IAnalyzerAssemblyLoader AnalyzerAssemblyLoader { get; } - - public abstract Func AssemblyReferenceProvider { get; } - - /// - /// Directory that contains the compiler executables and the response files. - /// - public string ClientDirectory { get; } - - /// - /// Directory that contains mscorlib. Can be null when the host is executing in a CoreCLR context. - /// - public string SdkDirectory { get; } - - protected CompilerServerHost(string clientDirectory, string sdkDirectory) - { - ClientDirectory = clientDirectory; - SdkDirectory = sdkDirectory; - } - - public abstract bool CheckAnalyzers(string baseDirectory, ImmutableArray analyzers); - - public bool TryCreateCompiler(RunRequest request, out CommonCompiler compiler) - { - compiler = null; - return false; - } - - public BuildResponse RunCompilation(RunRequest request, CancellationToken cancellationToken) - { - return null; - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CompilerServerLogger.cs b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CompilerServerLogger.cs index d2871a8ba1..c2c95f00bf 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CompilerServerLogger.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/CompilerServerLogger.cs @@ -22,7 +22,7 @@ namespace Microsoft.CodeAnalysis.CommandLine internal class CompilerServerLogger { // Environment variable, if set, to enable logging and set the file to log to. - private const string environmentVariable = "RoslynCommandLineLogFile"; + private const string EnvironmentVariable = "RAZORBUILDSERVER_LOG"; private static readonly Stream s_loggingStream; private static string s_prefix = "---"; @@ -37,7 +37,7 @@ namespace Microsoft.CodeAnalysis.CommandLine try { // Check if the environment - string loggingFileName = Environment.GetEnvironmentVariable(environmentVariable); + string loggingFileName = Environment.GetEnvironmentVariable(EnvironmentVariable); if (loggingFileName != null) { diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/DiagnosticListener.cs b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/DiagnosticListener.cs deleted file mode 100644 index a5c6bf3126..0000000000 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/DiagnosticListener.cs +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Microsoft.CodeAnalysis.CompilerServer -{ - internal interface IDiagnosticListener - { - /// - /// Called when the server updates the keep alive value. - /// - void UpdateKeepAlive(TimeSpan timeSpan); - - /// - /// Called each time the server listens for new connections. - /// - void ConnectionListening(); - - /// - /// Called when a connection to the server occurs. - /// - void ConnectionReceived(); - - /// - /// Called when one or more connections have completed processing. The number of connections - /// processed is provided in . - /// - void ConnectionCompleted(int count); - - /// - /// Called when a bad client connection was detected and the server will be shutting down as a - /// result. - /// - void ConnectionRudelyEnded(); - - /// - /// Called when the server is shutting down because the keep alive timeout was reached. - /// - void KeepAliveReached(); - } - - internal sealed class EmptyDiagnosticListener : IDiagnosticListener - { - public void UpdateKeepAlive(TimeSpan timeSpan) - { - } - - public void ConnectionListening() - { - } - - public void ConnectionReceived() - { - } - - public void ConnectionCompleted(int count) - { - } - - public void ConnectionRudelyEnded() - { - } - - public void KeepAliveReached() - { - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/ICompilerServerHost.cs b/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/ICompilerServerHost.cs deleted file mode 100644 index 710a86ba96..0000000000 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/ICompilerServerHost.cs +++ /dev/null @@ -1,10 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.CodeAnalysis.CompilerServer -{ - class ICompilerServerHost - { - } -} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerCommand.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerCommand.cs index 3e8f6e3602..f44487ab25 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/ServerCommand.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerCommand.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.CommandLineUtils; @@ -13,10 +14,13 @@ namespace Microsoft.AspNetCore.Razor.Tools : base(parent, "server") { Pipe = Option("-p|--pipe", "name of named pipe", CommandOptionType.SingleValue); + KeepAlive = Option("-k|--keep-alive", "sets the default idle timeout for the server in seconds", CommandOptionType.SingleValue); } public CommandOption Pipe { get; } + public CommandOption KeepAlive { get; } + protected override bool ValidateArguments() { if (string.IsNullOrEmpty(Pipe.Value())) @@ -41,10 +45,20 @@ namespace Microsoft.AspNetCore.Razor.Tools try { + TimeSpan? keepAlive = null; + if (KeepAlive.HasValue()) + { + var value = KeepAlive.Value(); + if (int.TryParse(value, out var result)) + { + // Keep alive times are specified in seconds + keepAlive = TimeSpan.FromSeconds(result); + } + } + var host = ConnectionHost.Create(Pipe.Value()); var compilerHost = CompilerHost.Create(); - var dispatcher = RequestDispatcher.Create(host, compilerHost, Cancelled); - dispatcher.Run(); + ExecuteServerCore(host, compilerHost, Cancelled, eventBus: null, keepAlive: keepAlive); } finally { @@ -54,5 +68,11 @@ namespace Microsoft.AspNetCore.Razor.Tools return Task.FromResult(0); } + + protected virtual void ExecuteServerCore(ConnectionHost host, CompilerHost compilerHost, CancellationToken cancellationToken, EventBus eventBus, TimeSpan? keepAlive) + { + var dispatcher = RequestDispatcher.Create(host, compilerHost, cancellationToken, eventBus, keepAlive); + dispatcher.Run(); + } } } diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/CompletedServerResponse.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/CompletedServerResponse.cs new file mode 100644 index 0000000000..49a9c5a33a --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/CompletedServerResponse.cs @@ -0,0 +1,66 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + /// + /// Represents a Response from the server. A response is as follows. + /// + /// Field Name Type Size (bytes) + /// -------------------------------------------------- + /// Length UInteger 4 + /// ReturnCode Integer 4 + /// Output String Variable + /// ErrorOutput String Variable + /// + /// Strings are encoded via a character count prefix as a + /// 32-bit integer, followed by an array of characters. + /// + /// + internal sealed class CompletedServerResponse : ServerResponse + { + public readonly int ReturnCode; + public readonly bool Utf8Output; + public readonly string Output; + public readonly string ErrorOutput; + + public CompletedServerResponse(int returnCode, bool utf8output, string output) + { + ReturnCode = returnCode; + Utf8Output = utf8output; + Output = output; + + // This field existed to support writing to Console.Error. The compiler doesn't ever write to + // this field or Console.Error. This field is only kept around in order to maintain the existing + // protocol semantics. + ErrorOutput = string.Empty; + } + + public override ResponseType Type => ResponseType.Completed; + + public static CompletedServerResponse Create(BinaryReader reader) + { + var returnCode = reader.ReadInt32(); + var utf8Output = reader.ReadBoolean(); + var output = ServerProtocol.ReadLengthPrefixedString(reader); + var errorOutput = ServerProtocol.ReadLengthPrefixedString(reader); + if (!string.IsNullOrEmpty(errorOutput)) + { + throw new InvalidOperationException(); + } + + return new CompletedServerResponse(returnCode, utf8Output, output); + } + + protected override void AddResponseBody(BinaryWriter writer) + { + writer.Write(ReturnCode); + writer.Write(Utf8Output); + ServerProtocol.WriteLengthPrefixedString(writer, Output); + ServerProtocol.WriteLengthPrefixedString(writer, ErrorOutput); + } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/MismatchedVersionServerResponse.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/MismatchedVersionServerResponse.cs new file mode 100644 index 0000000000..57293797e9 --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/MismatchedVersionServerResponse.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal sealed class MismatchedVersionServerResponse : ServerResponse + { + public override ResponseType Type => ResponseType.MismatchedVersion; + + /// + /// MismatchedVersion has no body. + /// + protected override void AddResponseBody(BinaryWriter writer) { } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/NativeMethods.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/NativeMethods.cs similarity index 98% rename from src/Microsoft.AspNetCore.Razor.Tools/Roslyn/NativeMethods.cs rename to src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/NativeMethods.cs index f76afa8580..00f0610842 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/Roslyn/NativeMethods.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/NativeMethods.cs @@ -4,7 +4,7 @@ using System; using System.Runtime.InteropServices; using System.Text; -namespace Microsoft.CodeAnalysis.CommandLine +namespace Microsoft.AspNetCore.Razor.Tools { [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] internal struct STARTUPINFO diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/RejectedServerResponse.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/RejectedServerResponse.cs new file mode 100644 index 0000000000..2b6e3e894b --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/RejectedServerResponse.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal sealed class RejectedServerResponse : ServerResponse + { + public override ResponseType Type => ResponseType.Rejected; + + /// + /// RejectedResponse has no body. + /// + protected override void AddResponseBody(BinaryWriter writer) { } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/RequestArgument.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/RequestArgument.cs new file mode 100644 index 0000000000..fd15c40686 --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/RequestArgument.cs @@ -0,0 +1,67 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + /// + /// A command line argument to the compilation. + /// An argument is formatted as follows: + /// + /// Field Name Type Size (bytes) + /// -------------------------------------------------- + /// ID UInteger 4 + /// Index UInteger 4 + /// Value String Variable + /// + /// Strings are encoded via a length prefix as a signed + /// 32-bit integer, followed by an array of characters. + /// + internal struct RequestArgument + { + public readonly ArgumentId Id; + public readonly int ArgumentIndex; + public readonly string Value; + + public RequestArgument(ArgumentId argumentId, int argumentIndex, string value) + { + Id = argumentId; + ArgumentIndex = argumentIndex; + Value = value; + } + + public static RequestArgument ReadFromBinaryReader(BinaryReader reader) + { + var argId = (ArgumentId)reader.ReadInt32(); + var argIndex = reader.ReadInt32(); + var value = ServerProtocol.ReadLengthPrefixedString(reader); + return new RequestArgument(argId, argIndex, value); + } + + public void WriteToBinaryWriter(BinaryWriter writer) + { + writer.Write((int)Id); + writer.Write(ArgumentIndex); + ServerProtocol.WriteLengthPrefixedString(writer, Value); + } + + public enum ArgumentId + { + // The current directory of the client + CurrentDirectory = 0x51147221, + + // A comment line argument. The argument index indicates which one (0 .. N) + CommandLineArgument, + + // Request a longer keep alive time for the server + KeepAlive, + + // Request a server shutdown from the client + Shutdown, + + // The directory to use for temporary operations. + TempDirectory, + } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerConnection.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerConnection.cs new file mode 100644 index 0000000000..aec7e72be9 --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerConnection.cs @@ -0,0 +1,330 @@ +// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Roslyn.Utilities; +using static Microsoft.CodeAnalysis.CommandLine.CompilerServerLogger; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal static class ServerConnection + { + private const string ServerName = "rzc.dll"; + + // Spend up to 1s connecting to existing process (existing processes should be always responsive). + private const int TimeOutMsExistingProcess = 1000; + + // Spend up to 20s connecting to a new process, to allow time for it to start. + private const int TimeOutMsNewProcess = 20000; + + public static bool WasServerMutexOpen(string mutexName) + { + var open = Mutex.TryOpenExisting(mutexName, out var mutex); + if (open) + { + mutex.Dispose(); + return true; + } + return false; + } + + /// + /// Gets the value of the temporary path for the current environment assuming the working directory + /// is . This function must emulate as + /// closely as possible. + /// + public static string GetTempPath(string workingDir) + { + if (PlatformInformation.IsUnix) + { + // Unix temp path is fine: it does not use the working directory + // (it uses ${TMPDIR} if set, otherwise, it returns /tmp) + return Path.GetTempPath(); + } + + var tmp = Environment.GetEnvironmentVariable("TMP"); + if (Path.IsPathRooted(tmp)) + { + return tmp; + } + + var temp = Environment.GetEnvironmentVariable("TEMP"); + if (Path.IsPathRooted(temp)) + { + return temp; + } + + if (!string.IsNullOrEmpty(workingDir)) + { + if (!string.IsNullOrEmpty(tmp)) + { + return Path.Combine(workingDir, tmp); + } + + if (!string.IsNullOrEmpty(temp)) + { + return Path.Combine(workingDir, temp); + } + } + + var userProfile = Environment.GetEnvironmentVariable("USERPROFILE"); + if (Path.IsPathRooted(userProfile)) + { + return userProfile; + } + + return Environment.GetEnvironmentVariable("SYSTEMROOT"); + } + + public static Task RunOnServer( + List arguments, + ServerPaths buildPaths, + CancellationToken cancellationToken, + string keepAlive = null) + { + var pipeName = PipeName.ComputeDefault(); + + return RunOnServerCore( + arguments, + buildPaths, + pipeName: pipeName, + keepAlive: keepAlive, + timeoutOverride: null, + tryCreateServerFunc: TryCreateServerCore, + cancellationToken: cancellationToken); + } + + private static async Task RunOnServerCore( + List arguments, + ServerPaths buildPaths, + string pipeName, + string keepAlive, + int? timeoutOverride, + Func tryCreateServerFunc, + CancellationToken cancellationToken) + { + if (pipeName == null) + { + return new RejectedServerResponse(); + } + + if (buildPaths.TempDirectory == null) + { + return new RejectedServerResponse(); + } + + var clientDir = buildPaths.ClientDirectory; + var timeoutNewProcess = timeoutOverride ?? TimeOutMsNewProcess; + var timeoutExistingProcess = timeoutOverride ?? TimeOutMsExistingProcess; + var clientMutexName = MutexName.GetClientMutexName(pipeName); + Task pipeTask = null; + using (var clientMutex = new Mutex(initiallyOwned: true, name: clientMutexName, createdNew: out var holdsMutex)) + { + try + { + if (!holdsMutex) + { + try + { + holdsMutex = clientMutex.WaitOne(timeoutNewProcess); + + if (!holdsMutex) + { + return new RejectedServerResponse(); + } + } + catch (AbandonedMutexException) + { + holdsMutex = true; + } + } + + // Check for an already running server + var serverMutexName = MutexName.GetServerMutexName(pipeName); + var wasServerRunning = WasServerMutexOpen(serverMutexName); + var timeout = wasServerRunning ? timeoutExistingProcess : timeoutNewProcess; + + if (wasServerRunning || tryCreateServerFunc(clientDir, pipeName)) + { + pipeTask = Client.ConnectAsync(pipeName, TimeSpan.FromMilliseconds(timeout), cancellationToken); + } + } + finally + { + if (holdsMutex) + { + clientMutex.ReleaseMutex(); + } + } + } + + if (pipeTask != null) + { + var client = await pipeTask.ConfigureAwait(false); + if (client != null) + { + var request = ServerRequest.Create( + buildPaths.WorkingDirectory, + buildPaths.TempDirectory, + arguments, + keepAlive); + + return await TryProcessRequest(client, request, cancellationToken).ConfigureAwait(false); + } + } + + return new RejectedServerResponse(); + } + + /// + /// Try to process the request using the server. Returns a null-containing Task if a response + /// from the server cannot be retrieved. + /// + private static async Task TryProcessRequest( + Client client, + ServerRequest request, + CancellationToken cancellationToken) + { + ServerResponse response; + using (client) + { + // Write the request + try + { + Log("Begin writing request"); + await request.WriteAsync(client.Stream, cancellationToken).ConfigureAwait(false); + Log("End writing request"); + } + catch (Exception e) + { + LogException(e, "Error writing build request."); + return new RejectedServerResponse(); + } + + // Wait for the compilation and a monitor to detect if the server disconnects + var serverCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + + Log("Begin reading response"); + + var responseTask = ServerResponse.ReadAsync(client.Stream, serverCts.Token); + var monitorTask = client.WaitForDisconnectAsync(serverCts.Token); + await Task.WhenAny(responseTask, monitorTask).ConfigureAwait(false); + + Log("End reading response"); + + if (responseTask.IsCompleted) + { + // await the task to log any exceptions + try + { + response = await responseTask.ConfigureAwait(false); + } + catch (Exception e) + { + LogException(e, "Error reading response"); + response = new RejectedServerResponse(); + } + } + else + { + Log("Server disconnect"); + response = new RejectedServerResponse(); + } + + // Cancel whatever task is still around + serverCts.Cancel(); + Debug.Assert(response != null); + return response; + } + } + + private static bool TryCreateServerCore(string clientDir, string pipeName) + { + string expectedPath; + string processArguments; + + // The server should be in the same directory as the client + var expectedCompilerPath = Path.Combine(clientDir, ServerName); + expectedPath = Environment.GetEnvironmentVariable("DOTNET_HOST_PATH") ?? "dotnet"; + processArguments = $@"""{expectedCompilerPath}"" server -p {pipeName}"; + + if (!File.Exists(expectedCompilerPath)) + { + return false; + } + + if (PlatformInformation.IsWindows) + { + // As far as I can tell, there isn't a way to use the Process class to + // create a process with no stdin/stdout/stderr, so we use P/Invoke. + // This code was taken from MSBuild task starting code. + + var startInfo = new STARTUPINFO(); + startInfo.cb = Marshal.SizeOf(startInfo); + startInfo.hStdError = NativeMethods.InvalidIntPtr; + startInfo.hStdInput = NativeMethods.InvalidIntPtr; + startInfo.hStdOutput = NativeMethods.InvalidIntPtr; + startInfo.dwFlags = NativeMethods.STARTF_USESTDHANDLES; + var dwCreationFlags = NativeMethods.NORMAL_PRIORITY_CLASS | NativeMethods.CREATE_NO_WINDOW; + + Log("Attempting to create process '{0}'", expectedPath); + + var builder = new StringBuilder($@"""{expectedPath}"" {processArguments}"); + + var success = NativeMethods.CreateProcess( + lpApplicationName: null, + lpCommandLine: builder, + lpProcessAttributes: NativeMethods.NullPtr, + lpThreadAttributes: NativeMethods.NullPtr, + bInheritHandles: false, + dwCreationFlags: dwCreationFlags, + lpEnvironment: NativeMethods.NullPtr, // Inherit environment + lpCurrentDirectory: clientDir, + lpStartupInfo: ref startInfo, + lpProcessInformation: out var processInfo); + + if (success) + { + Log("Successfully created process with process id {0}", processInfo.dwProcessId); + NativeMethods.CloseHandle(processInfo.hProcess); + NativeMethods.CloseHandle(processInfo.hThread); + } + else + { + Log("Failed to create process. GetLastError={0}", Marshal.GetLastWin32Error()); + } + return success; + } + else + { + try + { + var startInfo = new ProcessStartInfo() + { + FileName = expectedPath, + Arguments = processArguments, + UseShellExecute = false, + WorkingDirectory = clientDir, + RedirectStandardInput = true, + RedirectStandardOutput = true, + RedirectStandardError = true, + CreateNoWindow = true + }; + + Process.Start(startInfo); + return true; + } + catch + { + return false; + } + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerPaths.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerPaths.cs new file mode 100644 index 0000000000..37cd4a58c4 --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerPaths.cs @@ -0,0 +1,33 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal struct ServerPaths + { + internal ServerPaths(string clientDir, string workingDir, string tempDir) + { + ClientDirectory = clientDir; + WorkingDirectory = workingDir; + TempDirectory = tempDir; + } + + /// + /// The path which contains the Razor compiler binaries and response files. + /// + internal string ClientDirectory { get; } + + /// + /// The path in which the Razor compilation takes place. + /// + internal string WorkingDirectory { get; } + + /// + /// The temporary directory a compilation should use instead of . The latter + /// relies on global state individual compilations should ignore. + /// + internal string TempDirectory { get; } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerProtocol.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerProtocol.cs new file mode 100644 index 0000000000..807690d7f4 --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerProtocol.cs @@ -0,0 +1,71 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.CommandLine; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal static class ServerProtocol + { + /// + /// The version number for this protocol. + /// + public static readonly uint ProtocolVersion = 2; + + /// + /// Read a string from the Reader where the string is encoded + /// as a length prefix (signed 32-bit integer) followed by + /// a sequence of characters. + /// + public static string ReadLengthPrefixedString(BinaryReader reader) + { + var length = reader.ReadInt32(); + return new string(reader.ReadChars(length)); + } + + /// + /// Write a string to the Writer where the string is encoded + /// as a length prefix (signed 32-bit integer) follows by + /// a sequence of characters. + /// + public static void WriteLengthPrefixedString(BinaryWriter writer, string value) + { + writer.Write(value.Length); + writer.Write(value.ToCharArray()); + } + + /// + /// This task does not complete until we are completely done reading. + /// + internal static async Task ReadAllAsync( + Stream stream, + byte[] buffer, + int count, + CancellationToken cancellationToken) + { + var totalBytesRead = 0; + do + { + CompilerServerLogger.Log("Attempting to read {0} bytes from the stream", count - totalBytesRead); + var bytesRead = await stream.ReadAsync( + buffer, + totalBytesRead, + count - totalBytesRead, + cancellationToken) + .ConfigureAwait(false); + + if (bytesRead == 0) + { + CompilerServerLogger.Log("Unexpected -- read 0 bytes from the stream."); + throw new EndOfStreamException("Reached end of stream before end of read."); + } + CompilerServerLogger.Log("Read {0} bytes", bytesRead); + totalBytesRead += bytesRead; + } while (totalBytesRead < count); + CompilerServerLogger.Log("Finished read"); + } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerRequest.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerRequest.cs new file mode 100644 index 0000000000..5fa861717c --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerRequest.cs @@ -0,0 +1,217 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using static Microsoft.CodeAnalysis.CommandLine.CompilerServerLogger; + +// After the server pipe is connected, it forks off a thread to handle the connection, and creates +// a new instance of the pipe to listen for new clients. When it gets a request, it validates +// the security and elevation level of the client. If that fails, it disconnects the client. Otherwise, +// it handles the request, sends a response (described by Response class) back to the client, then +// disconnects the pipe and ends the thread. +namespace Microsoft.AspNetCore.Razor.Tools +{ + /// + /// Represents a request from the client. A request is as follows. + /// + /// Field Name Type Size (bytes) + /// ---------------------------------------------------- + /// Length Integer 4 + /// Argument Count UInteger 4 + /// Arguments Argument[] Variable + /// + /// See for the format of an + /// Argument. + /// + /// + internal class ServerRequest + { + public ServerRequest(uint protocolVersion, IEnumerable arguments) + { + ProtocolVersion = protocolVersion; + Arguments = new ReadOnlyCollection(arguments.ToList()); + + if (Arguments.Count > ushort.MaxValue) + { + throw new ArgumentOutOfRangeException( + nameof(arguments), + $"Too many arguments: maximum of {ushort.MaxValue} arguments allowed."); + } + } + + public uint ProtocolVersion { get; } + + public ReadOnlyCollection Arguments { get; } + + public TimeSpan? KeepAlive + { + get + { + TimeSpan? keepAlive = null; + foreach (var argument in Arguments) + { + if (argument.Id == RequestArgument.ArgumentId.KeepAlive) + { + // If the value is not a valid integer for any reason, ignore it and continue with the current timeout. + // The client is responsible for validating the argument. + if (int.TryParse(argument.Value, out var result)) + { + // Keep alive times are specified in seconds + keepAlive = TimeSpan.FromSeconds(result); + } + } + } + + return keepAlive; + } + } + + public bool IsShutdownRequest() + { + return Arguments.Count >= 1 && Arguments[0].Id == RequestArgument.ArgumentId.Shutdown; + } + + public static ServerRequest Create( + string workingDirectory, + string tempDirectory, + IList args, + string keepAlive = null) + { + Log("Creating ServerRequest"); + Log($"Working directory: {workingDirectory}"); + Log($"Temp directory: {tempDirectory}"); + + var requestLength = args.Count + 1; + var requestArgs = new List(requestLength) + { + new RequestArgument(RequestArgument.ArgumentId.CurrentDirectory, 0, workingDirectory), + new RequestArgument(RequestArgument.ArgumentId.TempDirectory, 0, tempDirectory) + }; + + if (keepAlive != null) + { + requestArgs.Add(new RequestArgument(RequestArgument.ArgumentId.KeepAlive, 0, keepAlive)); + } + + for (var i = 0; i < args.Count; ++i) + { + var arg = args[i]; + Log($"argument[{i}] = {arg}"); + requestArgs.Add(new RequestArgument(RequestArgument.ArgumentId.CommandLineArgument, i, arg)); + } + + return new ServerRequest(ServerProtocol.ProtocolVersion, requestArgs); + } + + public static ServerRequest CreateShutdown() + { + var requestArgs = new[] + { + new RequestArgument(RequestArgument.ArgumentId.Shutdown, argumentIndex: 0, value: ""), + new RequestArgument(RequestArgument.ArgumentId.CommandLineArgument, argumentIndex: 1, value: "shutdown"), + }; + return new ServerRequest(ServerProtocol.ProtocolVersion, requestArgs); + } + + /// + /// Read a Request from the given stream. + /// + /// The total request size must be less than 1MB. + /// + /// null if the Request was too large, the Request otherwise. + public static async Task ReadAsync(Stream inStream, CancellationToken cancellationToken) + { + // Read the length of the request + var lengthBuffer = new byte[4]; + Log("Reading length of request"); + await ServerProtocol.ReadAllAsync(inStream, lengthBuffer, 4, cancellationToken).ConfigureAwait(false); + var length = BitConverter.ToInt32(lengthBuffer, 0); + + // Back out if the request is > 1MB + if (length > 0x100000) + { + Log("Request is over 1MB in length, cancelling read."); + return null; + } + + cancellationToken.ThrowIfCancellationRequested(); + + // Read the full request + var requestBuffer = new byte[length]; + await ServerProtocol.ReadAllAsync(inStream, requestBuffer, length, cancellationToken).ConfigureAwait(false); + + cancellationToken.ThrowIfCancellationRequested(); + + Log("Parsing request"); + // Parse the request into the Request data structure. + using (var reader = new BinaryReader(new MemoryStream(requestBuffer), Encoding.Unicode)) + { + var protocolVersion = reader.ReadUInt32(); + var argumentCount = reader.ReadUInt32(); + + var argumentsBuilder = new List((int)argumentCount); + + for (var i = 0; i < argumentCount; i++) + { + cancellationToken.ThrowIfCancellationRequested(); + argumentsBuilder.Add(RequestArgument.ReadFromBinaryReader(reader)); + } + + return new ServerRequest(protocolVersion, argumentsBuilder); + } + } + + /// + /// Write a Request to the stream. + /// + public async Task WriteAsync(Stream outStream, CancellationToken cancellationToken = default) + { + using (var memoryStream = new MemoryStream()) + using (var writer = new BinaryWriter(memoryStream, Encoding.Unicode)) + { + // Format the request. + Log("Formatting request"); + writer.Write(ProtocolVersion); + writer.Write(Arguments.Count); + foreach (var arg in Arguments) + { + cancellationToken.ThrowIfCancellationRequested(); + arg.WriteToBinaryWriter(writer); + } + writer.Flush(); + + cancellationToken.ThrowIfCancellationRequested(); + + // Write the length of the request + var length = checked((int)memoryStream.Length); + + // Back out if the request is > 1 MB + if (memoryStream.Length > 0x100000) + { + Log("Request is over 1MB in length, cancelling write"); + throw new ArgumentOutOfRangeException(); + } + + // Send the request to the server + Log("Writing length of request."); + await outStream + .WriteAsync(BitConverter.GetBytes(length), 0, 4, cancellationToken) + .ConfigureAwait(false); + + Log("Writing request of size {0}", length); + // Write the request + memoryStream.Position = 0; + await memoryStream + .CopyToAsync(outStream, bufferSize: length, cancellationToken: cancellationToken) + .ConfigureAwait(false); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerResponse.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerResponse.cs new file mode 100644 index 0000000000..ea2c32ce44 --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ServerResponse.cs @@ -0,0 +1,133 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using static Microsoft.CodeAnalysis.CommandLine.CompilerServerLogger; + +// After the server pipe is connected, it forks off a thread to handle the connection, and creates +// a new instance of the pipe to listen for new clients. When it gets a request, it validates +// the security and elevation level of the client. If that fails, it disconnects the client. Otherwise, +// it handles the request, sends a response (described by Response class) back to the client, then +// disconnects the pipe and ends the thread. +namespace Microsoft.AspNetCore.Razor.Tools +{ + /// + /// Base class for all possible responses to a request. + /// The ResponseType enum should list all possible response types + /// and ReadResponse creates the appropriate response subclass based + /// on the response type sent by the client. + /// The format of a response is: + /// + /// Field Name Field Type Size (bytes) + /// ------------------------------------------------- + /// responseLength int (positive) 4 + /// responseType enum ResponseType 4 + /// responseBody Response subclass variable + /// + internal abstract class ServerResponse + { + public enum ResponseType + { + // The client and server are using incompatible protocol versions. + MismatchedVersion, + + // The build request completed on the server and the results are contained + // in the message. + Completed, + + // The shutdown request completed and the server process information is + // contained in the message. + Shutdown, + + // The request was rejected by the server. + Rejected, + } + + public abstract ResponseType Type { get; } + + public async Task WriteAsync(Stream outStream, CancellationToken cancellationToken) + { + using (var memoryStream = new MemoryStream()) + using (var writer = new BinaryWriter(memoryStream, Encoding.Unicode)) + { + // Format the response + Log("Formatting Response"); + writer.Write((int)Type); + + AddResponseBody(writer); + writer.Flush(); + + cancellationToken.ThrowIfCancellationRequested(); + + // Send the response to the client + + // Write the length of the response + var length = checked((int)memoryStream.Length); + + Log("Writing response length"); + // There is no way to know the number of bytes written to + // the pipe stream. We just have to assume all of them are written. + await outStream + .WriteAsync(BitConverter.GetBytes(length), 0, 4, cancellationToken) + .ConfigureAwait(false); + + // Write the response + Log("Writing response of size {0}", length); + memoryStream.Position = 0; + await memoryStream + .CopyToAsync(outStream, bufferSize: length, cancellationToken: cancellationToken) + .ConfigureAwait(false); + } + } + + protected abstract void AddResponseBody(BinaryWriter writer); + + /// + /// May throw exceptions if there are pipe problems. + /// + /// + /// + /// + public static async Task ReadAsync(Stream stream, CancellationToken cancellationToken = default(CancellationToken)) + { + Log("Reading response length"); + // Read the response length + var lengthBuffer = new byte[4]; + await ServerProtocol.ReadAllAsync(stream, lengthBuffer, 4, cancellationToken).ConfigureAwait(false); + var length = BitConverter.ToUInt32(lengthBuffer, 0); + + // Read the response + Log("Reading response of length {0}", length); + var responseBuffer = new byte[length]; + await ServerProtocol.ReadAllAsync( + stream, + responseBuffer, + responseBuffer.Length, + cancellationToken) + .ConfigureAwait(false); + + using (var reader = new BinaryReader(new MemoryStream(responseBuffer), Encoding.Unicode)) + { + var responseType = (ResponseType)reader.ReadInt32(); + + switch (responseType) + { + case ResponseType.Completed: + return CompletedServerResponse.Create(reader); + case ResponseType.MismatchedVersion: + return new MismatchedVersionServerResponse(); + case ResponseType.Shutdown: + return ShutdownServerResponse.Create(reader); + case ResponseType.Rejected: + return new RejectedServerResponse(); + default: + throw new InvalidOperationException("Received invalid response type from server."); + } + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ShutdownServerResponse.cs b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ShutdownServerResponse.cs new file mode 100644 index 0000000000..dbf94e1628 --- /dev/null +++ b/src/Microsoft.AspNetCore.Razor.Tools/ServerProtocol/ShutdownServerResponse.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal sealed class ShutdownServerResponse : ServerResponse + { + public readonly int ServerProcessId; + + public ShutdownServerResponse(int serverProcessId) + { + ServerProcessId = serverProcessId; + } + + public override ResponseType Type => ResponseType.Shutdown; + + protected override void AddResponseBody(BinaryWriter writer) + { + writer.Write(ServerProcessId); + } + + public static ShutdownServerResponse Create(BinaryReader reader) + { + var serverProcessId = reader.ReadInt32(); + return new ShutdownServerResponse(serverProcessId); + } + } +} diff --git a/src/Microsoft.AspNetCore.Razor.Tools/ShutdownCommand.cs b/src/Microsoft.AspNetCore.Razor.Tools/ShutdownCommand.cs index 114795331e..f77b30f5b2 100644 --- a/src/Microsoft.AspNetCore.Razor.Tools/ShutdownCommand.cs +++ b/src/Microsoft.AspNetCore.Razor.Tools/ShutdownCommand.cs @@ -5,7 +5,6 @@ using System; using System.Diagnostics; using System.Threading; using System.Threading.Tasks; -using Microsoft.CodeAnalysis.CommandLine; using Microsoft.Extensions.CommandLineUtils; namespace Microsoft.AspNetCore.Razor.Tools @@ -46,10 +45,10 @@ namespace Microsoft.AspNetCore.Razor.Tools { using (var client = await Client.ConnectAsync(Pipe.Value(), timeout: null, cancellationToken: Cancelled)) { - var request = BuildRequest.CreateShutdown(); + var request = ServerRequest.CreateShutdown(); await request.WriteAsync(client.Stream, Cancelled).ConfigureAwait(false); - var response = ((ShutdownBuildResponse)await BuildResponse.ReadAsync(client.Stream, Cancelled)); + var response = ((ShutdownServerResponse)await ServerResponse.ReadAsync(client.Stream, Cancelled)); if (Wait.HasValue()) { diff --git a/test/Microsoft.AspNetCore.Razor.Tools.Test/DefaultRequestDispatcherTest.cs b/test/Microsoft.AspNetCore.Razor.Tools.Test/DefaultRequestDispatcherTest.cs new file mode 100644 index 0000000000..0cca396bf8 --- /dev/null +++ b/test/Microsoft.AspNetCore.Razor.Tools.Test/DefaultRequestDispatcherTest.cs @@ -0,0 +1,551 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + public class DefaultRequestDispatcherTest + { + private static ServerRequest EmptyServerRequest => new ServerRequest(1, Array.Empty()); + + private static ServerResponse EmptyServerResponse => new CompletedServerResponse( + returnCode: 0, + utf8output: false, + output: string.Empty); + + [Fact] + public async Task AcceptConnection_ReadingRequestFails_ClosesConnection() + { + // Arrange + var stream = Mock.Of(); + var compilerHost = CreateCompilerHost(); + var connectionHost = CreateConnectionHost(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None); + var connection = CreateConnection(stream); + + // Act + var result = await dispatcher.AcceptConnection( + Task.FromResult(connection), accept: true, cancellationToken: CancellationToken.None); + + // Assert + Assert.Equal(ConnectionResult.Reason.CompilationNotStarted, result.CloseReason); + } + + /// + /// A failure to write the results to the client is considered a client disconnection. Any error + /// from when the build starts to when the write completes should be handled this way. + /// + [Fact] + public async Task AcceptConnection_WritingResultsFails_ClosesConnection() + { + // Arrange + var memoryStream = new MemoryStream(); + await EmptyServerRequest.WriteAsync(memoryStream, CancellationToken.None).ConfigureAwait(true); + memoryStream.Position = 0; + + var stream = new Mock(MockBehavior.Strict); + stream + .Setup(x => x.ReadAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((byte[] array, int start, int length, CancellationToken ct) => memoryStream.ReadAsync(array, start, length, ct)); + + var connection = CreateConnection(stream.Object); + var compilerHost = CreateCompilerHost(c => + { + c.ExecuteFunc = (req, ct) => + { + return EmptyServerResponse; + }; + }); + var connectionHost = CreateConnectionHost(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None); + + // Act + // We expect WriteAsync to fail because the mock stream doesn't have a corresponding setup. + var connectionResult = await dispatcher.AcceptConnection( + Task.FromResult(connection), accept: true, cancellationToken: CancellationToken.None); + + // Assert + Assert.Equal(ConnectionResult.Reason.ClientDisconnect, connectionResult.CloseReason); + Assert.Null(connectionResult.KeepAlive); + } + + /// + /// Ensure the Connection correctly handles the case where a client disconnects while in the + /// middle of executing a request. + /// + [Fact] + public async Task AcceptConnection_ClientDisconnectsWhenExecutingRequest_ClosesConnection() + { + // Arrange + var connectionHost = Mock.Of(); + + // Fake a long running task here that we can validate later on. + var buildTaskSource = new TaskCompletionSource(); + var buildTaskCancellationToken = default(CancellationToken); + var compilerHost = CreateCompilerHost(c => + { + c.ExecuteFunc = (req, ct) => + { + Task.WaitAll(buildTaskSource.Task); + return EmptyServerResponse; + }; + }); + + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None); + var readyTaskSource = new TaskCompletionSource(); + var disconnectTaskSource = new TaskCompletionSource(); + var connectionTask = CreateConnectionWithEmptyServerRequest(c => + { + c.WaitForDisconnectAsyncFunc = (ct) => + { + buildTaskCancellationToken = ct; + readyTaskSource.SetResult(true); + return disconnectTaskSource.Task; + }; + }); + + var handleTask = dispatcher.AcceptConnection( + connectionTask, accept: true, cancellationToken: CancellationToken.None); + + // Wait until WaitForDisconnectAsync task is actually created and running. + await readyTaskSource.Task.ConfigureAwait(false); + + // Act + // Now simulate a disconnect by the client. + disconnectTaskSource.SetResult(true); + var connectionResult = await handleTask; + buildTaskSource.SetResult(true); + + // Assert + Assert.Equal(ConnectionResult.Reason.ClientDisconnect, connectionResult.CloseReason); + Assert.Null(connectionResult.KeepAlive); + Assert.True(buildTaskCancellationToken.IsCancellationRequested); + } + + [Fact] + public async Task AcceptConnection_AcceptFalse_RejectsBuildRequest() + { + // Arrange + var stream = new TestableStream(); + await EmptyServerRequest.WriteAsync(stream.ReadStream, CancellationToken.None); + stream.ReadStream.Position = 0; + + var connection = CreateConnection(stream); + var connectionHost = CreateConnectionHost(); + var compilerHost = CreateCompilerHost(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None); + + // Act + var connectionResult = await dispatcher.AcceptConnection( + Task.FromResult(connection), accept: false, cancellationToken: CancellationToken.None); + + // Assert + Assert.Equal(ConnectionResult.Reason.CompilationNotStarted, connectionResult.CloseReason); + stream.WriteStream.Position = 0; + var response = await ServerResponse.ReadAsync(stream.WriteStream).ConfigureAwait(false); + Assert.Equal(ServerResponse.ResponseType.Rejected, response.Type); + } + + [Fact] + public async Task AcceptConnection_ShutdownRequest_ReturnsShutdownResponse() + { + // Arrange + var stream = new TestableStream(); + await ServerRequest.CreateShutdown().WriteAsync(stream.ReadStream, CancellationToken.None); + stream.ReadStream.Position = 0; + + var connection = CreateConnection(stream); + var connectionHost = CreateConnectionHost(); + var compilerHost = CreateCompilerHost(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None); + + // Act + var connectionResult = await dispatcher.AcceptConnection( + Task.FromResult(connection), accept: true, cancellationToken: CancellationToken.None); + + // Assert + Assert.Equal(ConnectionResult.Reason.ClientShutdownRequest, connectionResult.CloseReason); + stream.WriteStream.Position = 0; + var response = await ServerResponse.ReadAsync(stream.WriteStream).ConfigureAwait(false); + Assert.Equal(ServerResponse.ResponseType.Shutdown, response.Type); + } + + [Fact] + public async Task AcceptConnection_ConnectionHostThrowsWhenConnecting_ClosesConnection() + { + // Arrange + var connectionHost = new Mock(MockBehavior.Strict); + connectionHost.Setup(c => c.WaitForConnectionAsync(It.IsAny())).Throws(new Exception()); + var compilerHost = CreateCompilerHost(); + var dispatcher = new DefaultRequestDispatcher(connectionHost.Object, compilerHost, CancellationToken.None); + var connection = CreateConnection(Mock.Of()); + + // Act + var connectionResult = await dispatcher.AcceptConnection( + Task.FromResult(connection), accept: true, cancellationToken: CancellationToken.None); + + // Assert + Assert.Equal(ConnectionResult.Reason.CompilationNotStarted, connectionResult.CloseReason); + Assert.Null(connectionResult.KeepAlive); + } + + [Fact] + public async Task AcceptConnection_ClientConnectionThrowsWhenConnecting_ClosesConnection() + { + // Arrange + var compilerHost = CreateCompilerHost(); + var connectionHost = CreateConnectionHost(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None); + var connectionTask = Task.FromException(new Exception()); + + // Act + var connectionResult = await dispatcher.AcceptConnection( + connectionTask, accept: true, cancellationToken: CancellationToken.None); + + // Assert + Assert.Equal(ConnectionResult.Reason.CompilationNotStarted, connectionResult.CloseReason); + Assert.Null(connectionResult.KeepAlive); + } + + [Fact] + public async Task Dispatcher_ClientConnectionThrowsWhenExecutingRequest_ClosesConnection() + { + // Arrange + var called = false; + var connectionTask = CreateConnectionWithEmptyServerRequest(c => + { + c.WaitForDisconnectAsyncFunc = (ct) => + { + called = true; + throw new Exception(); + }; + }); + + var compilerHost = CreateCompilerHost(); + var connectionHost = CreateConnectionHost(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None); + + // Act + var connectionResult = await dispatcher.AcceptConnection( + connectionTask, accept: true, cancellationToken: CancellationToken.None); + + // Assert + Assert.True(called); + Assert.Equal(ConnectionResult.Reason.ClientException, connectionResult.CloseReason); + Assert.Null(connectionResult.KeepAlive); + } + + [Fact] + public void Dispatcher_NoConnections_HitsKeepAliveTimeout() + { + // Arrange + var keepAlive = TimeSpan.FromSeconds(3); + var compilerHost = CreateCompilerHost(); + var connectionHost = new Mock(); + connectionHost + .Setup(x => x.WaitForConnectionAsync(It.IsAny())) + .Returns(new TaskCompletionSource().Task); + + var eventBus = new TestableEventBus(); + var dispatcher = new DefaultRequestDispatcher(connectionHost.Object, compilerHost, CancellationToken.None, eventBus, keepAlive); + var startTime = DateTime.Now; + + // Act + dispatcher.Run(); + + // Assert + Assert.True(eventBus.HitKeepAliveTimeout); + } + + /// + /// Ensure server respects keep alive and shuts down after processing a single connection. + /// + [Fact] + public void Dispatcher_ProcessSingleConnection_HitsKeepAliveTimeout() + { + // Arrange + var connectionTask = CreateConnectionWithEmptyServerRequest(); + var keepAlive = TimeSpan.FromSeconds(1); + var compilerHost = CreateCompilerHost(c => + { + c.ExecuteFunc = (req, ct) => + { + return EmptyServerResponse; + }; + }); + var connectionHost = CreateConnectionHost(connectionTask, new TaskCompletionSource().Task); + + var eventBus = new TestableEventBus(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None, eventBus, keepAlive); + + // Act + dispatcher.Run(); + + // Assert + Assert.Equal(1, eventBus.CompletedCount); + Assert.True(eventBus.LastProcessedTime.HasValue); + Assert.True(eventBus.HitKeepAliveTimeout); + } + + /// + /// Ensure server respects keep alive and shuts down after processing multiple connections. + /// + [Fact] + public void Dispatcher_ProcessMultipleConnections_HitsKeepAliveTimeout() + { + // Arrange + var count = 5; + var list = new List>(); + for (var i = 0; i < count; i++) + { + var connectionTask = CreateConnectionWithEmptyServerRequest(); + list.Add(connectionTask); + } + + list.Add(new TaskCompletionSource().Task); + var connectionHost = CreateConnectionHost(list.ToArray()); + var compilerHost = CreateCompilerHost(c => + { + c.ExecuteFunc = (req, ct) => + { + return EmptyServerResponse; + }; + }); + + var keepAlive = TimeSpan.FromSeconds(1); + var eventBus = new TestableEventBus(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None, eventBus, keepAlive); + + // Act + dispatcher.Run(); + + // Assert + Assert.Equal(count, eventBus.CompletedCount); + Assert.True(eventBus.LastProcessedTime.HasValue); + Assert.True(eventBus.HitKeepAliveTimeout); + } + + /// + /// Ensure server respects keep alive and shuts down after processing simultaneous connections. + /// + [Fact] + public async Task Dispatcher_ProcessSimultaneousConnections_HitsKeepAliveTimeout() + { + // Arrange + var totalCount = 2; + var readySource = new TaskCompletionSource(); + var list = new List>(); + var connectionHost = new Mock(); + connectionHost + .Setup(x => x.WaitForConnectionAsync(It.IsAny())) + .Returns((CancellationToken ct) => + { + if (list.Count < totalCount) + { + var source = new TaskCompletionSource(); + var connectionTask = CreateConnectionWithEmptyServerRequest(c => + { + c.WaitForDisconnectAsyncFunc = _ => source.Task; + }); + list.Add(source); + return connectionTask; + } + + readySource.SetResult(true); + return new TaskCompletionSource().Task; + }); + + var compilerHost = CreateCompilerHost(c => + { + c.ExecuteFunc = (req, ct) => + { + return EmptyServerResponse; + }; + }); + + var keepAlive = TimeSpan.FromSeconds(1); + var eventBus = new TestableEventBus(); + var dispatcherTask = Task.Run(() => + { + var dispatcher = new DefaultRequestDispatcher(connectionHost.Object, compilerHost, CancellationToken.None, eventBus, keepAlive); + dispatcher.Run(); + }); + + await readySource.Task; + foreach (var source in list) + { + source.SetResult(true); + } + + // Act + await dispatcherTask; + + // Assert + Assert.Equal(totalCount, eventBus.CompletedCount); + Assert.True(eventBus.LastProcessedTime.HasValue); + Assert.True(eventBus.HitKeepAliveTimeout); + } + + [Fact] + public void Dispatcher_ClientConnectionThrows_BeginsShutdown() + { + // Arrange + var listenCancellationToken = default(CancellationToken); + var firstConnectionTask = CreateConnectionWithEmptyServerRequest(c => + { + c.WaitForDisconnectAsyncFunc = (ct) => + { + listenCancellationToken = ct; + return Task.Delay(Timeout.Infinite, ct).ContinueWith(_ => null); + }; + }); + var secondConnectionTask = CreateConnectionWithEmptyServerRequest(c => + { + c.WaitForDisconnectAsyncFunc = (ct) => throw new Exception(); + }); + + var compilerHost = CreateCompilerHost(); + var connectionHost = CreateConnectionHost( + firstConnectionTask, + secondConnectionTask, + new TaskCompletionSource().Task); + var keepAlive = TimeSpan.FromSeconds(10); + var eventBus = new TestableEventBus(); + var dispatcher = new DefaultRequestDispatcher(connectionHost, compilerHost, CancellationToken.None, eventBus, keepAlive); + + // Act + dispatcher.Run(); + + // Assert + Assert.True(eventBus.HasDetectedBadConnection); + Assert.True(listenCancellationToken.IsCancellationRequested); + } + + private static TestableConnection CreateConnection(Stream stream, string identifier = null) + { + return new TestableConnection(stream, identifier ?? "identifier"); + } + + private static async Task CreateConnectionWithEmptyServerRequest(Action configureConnection = null) + { + var memoryStream = new MemoryStream(); + await EmptyServerRequest.WriteAsync(memoryStream, CancellationToken.None); + memoryStream.Position = 0; + var connection = CreateConnection(memoryStream); + configureConnection?.Invoke(connection); + + return connection; + } + + private static ConnectionHost CreateConnectionHost(params Task[] connections) + { + var host = new Mock(); + if (connections.Length > 0) + { + var index = 0; + host + .Setup(x => x.WaitForConnectionAsync(It.IsAny())) + .Returns((CancellationToken ct) => connections[index++]); + } + + return host.Object; + } + + private static TestableCompilerHost CreateCompilerHost(Action configureCompilerHost = null) + { + var compilerHost = new TestableCompilerHost(); + configureCompilerHost?.Invoke(compilerHost); + + return compilerHost; + } + + private class TestableCompilerHost : CompilerHost + { + internal Func ExecuteFunc; + + public override ServerResponse Execute(ServerRequest request, CancellationToken cancellationToken) + { + if (ExecuteFunc != null) + { + return ExecuteFunc(request, cancellationToken); + } + + return EmptyServerResponse; + } + } + + private class TestableConnection : Connection + { + internal Func WaitForDisconnectAsyncFunc; + + public TestableConnection(Stream stream, string identifier) + { + Stream = stream; + Identifier = identifier; + WaitForDisconnectAsyncFunc = ct => Task.Delay(Timeout.Infinite, ct); + } + + public override Task WaitForDisconnectAsync(CancellationToken cancellationToken) + { + return WaitForDisconnectAsyncFunc(cancellationToken); + } + } + + private class TestableStream : Stream + { + internal readonly MemoryStream ReadStream = new MemoryStream(); + internal readonly MemoryStream WriteStream = new MemoryStream(); + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length { get { throw new NotImplementedException(); } } + public override long Position + { + get { throw new NotImplementedException(); } + set { throw new NotImplementedException(); } + } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + { + return ReadStream.Read(buffer, offset, count); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return ReadStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + WriteStream.Write(buffer, offset, count); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return WriteStream.WriteAsync(buffer, offset, count, cancellationToken); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerData.cs b/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerData.cs new file mode 100644 index 0000000000..37a9d0180a --- /dev/null +++ b/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerData.cs @@ -0,0 +1,49 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal sealed class ServerData : IDisposable + { + internal CancellationTokenSource CancellationTokenSource { get; } + internal Task ServerTask { get; } + internal Task ListenTask { get; } + internal string PipeName { get; } + + internal ServerData(CancellationTokenSource cancellationTokenSource, string pipeName, Task serverTask, Task listenTask) + { + CancellationTokenSource = cancellationTokenSource; + PipeName = pipeName; + ServerTask = serverTask; + ListenTask = listenTask; + } + + internal async Task CancelAndCompleteAsync() + { + CancellationTokenSource.Cancel(); + return await ServerTask; + } + + internal async Task Verify(int connections, int completed) + { + var stats = await CancelAndCompleteAsync().ConfigureAwait(false); + Assert.Equal(connections, stats.Connections); + Assert.Equal(completed, stats.CompletedConnections); + } + + public void Dispose() + { + if (!CancellationTokenSource.IsCancellationRequested) + { + CancellationTokenSource.Cancel(); + } + + ServerTask.Wait(); + } + } +} diff --git a/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerStats.cs b/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerStats.cs new file mode 100644 index 0000000000..ce4064528f --- /dev/null +++ b/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerStats.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal struct ServerStats + { + internal readonly int Connections; + internal readonly int CompletedConnections; + + internal ServerStats(int connections, int completedConnections) + { + Connections = connections; + CompletedConnections = completedConnections; + } + } +} diff --git a/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerUtilities.cs b/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerUtilities.cs new file mode 100644 index 0000000000..77c30cd60a --- /dev/null +++ b/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/ServerUtilities.cs @@ -0,0 +1,145 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.CodeAnalysis.CommandLine; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal static class ServerUtilities + { + internal static string DefaultClientDirectory { get; } = Path.GetDirectoryName(typeof(ServerUtilities).Assembly.Location); + + internal static ServerPaths CreateBuildPaths(string workingDir, string tempDir) + { + return new ServerPaths( + clientDir: DefaultClientDirectory, + workingDir: workingDir, + tempDir: tempDir); + } + + internal static ServerData CreateServer( + string pipeName = null, + CompilerHost compilerHost = null, + ConnectionHost connectionHost = null) + { + pipeName = pipeName ?? Guid.NewGuid().ToString(); + compilerHost = compilerHost ?? CompilerHost.Create(); + connectionHost = connectionHost ?? ConnectionHost.Create(pipeName); + + var serverStatsSource = new TaskCompletionSource(); + var serverListenSource = new TaskCompletionSource(); + var cts = new CancellationTokenSource(); + var mutexName = MutexName.GetServerMutexName(pipeName); + var thread = new Thread(_ => + { + var eventBus = new TestableEventBus(); + eventBus.Listening += (sender, e) => { serverListenSource.TrySetResult(true); }; + try + { + RunServer( + pipeName, + connectionHost, + compilerHost, + cts.Token, + eventBus, + Timeout.InfiniteTimeSpan); + } + finally + { + var serverStats = new ServerStats(connections: eventBus.ConnectionCount, completedConnections: eventBus.CompletedCount); + serverStatsSource.SetResult(serverStats); + } + }); + + thread.Start(); + + // The contract of this function is that it will return once the server has started. Spin here until + // we can verify the server has started or simply failed to start. + while (ServerConnection.WasServerMutexOpen(mutexName) != true && thread.IsAlive) + { + Thread.Yield(); + } + + return new ServerData(cts, pipeName, serverStatsSource.Task, serverListenSource.Task); + } + + internal static async Task Send(string pipeName, ServerRequest request) + { + using (var client = await Client.ConnectAsync(pipeName, timeout: null, cancellationToken: default).ConfigureAwait(false)) + { + await request.WriteAsync(client.Stream).ConfigureAwait(false); + return await ServerResponse.ReadAsync(client.Stream).ConfigureAwait(false); + } + } + + internal static async Task SendShutdown(string pipeName) + { + var response = await Send(pipeName, ServerRequest.CreateShutdown()); + return ((ShutdownServerResponse)response).ServerProcessId; + } + + internal static int RunServer( + string pipeName, + ConnectionHost host, + CompilerHost compilerHost, + CancellationToken cancellationToken = default, + EventBus eventBus = null, + TimeSpan? keepAlive = null) + { + var command = new TestableServerCommand(host, compilerHost, cancellationToken, eventBus, keepAlive); + var args = new List + { + "-p", + pipeName + }; + + var result = command.Execute(args.ToArray()); + return result; + } + + private class TestableServerCommand : ServerCommand + { + private readonly ConnectionHost _host; + private readonly CompilerHost _compilerHost; + private readonly EventBus _eventBus; + private readonly CancellationToken _cancellationToken; + private readonly TimeSpan? _keepAlive; + + + public TestableServerCommand( + ConnectionHost host, + CompilerHost compilerHost, + CancellationToken ct, + EventBus eventBus, + TimeSpan? keepAlive) + : base(new Application(ct)) + { + _host = host; + _compilerHost = compilerHost; + _cancellationToken = ct; + _eventBus = eventBus; + _keepAlive = keepAlive; + } + + protected override void ExecuteServerCore( + ConnectionHost host, + CompilerHost compilerHost, + CancellationToken cancellationToken, + EventBus eventBus, + TimeSpan? keepAlive = null) + { + base.ExecuteServerCore( + _host ?? host, + _compilerHost ?? compilerHost, + _cancellationToken, + _eventBus ?? eventBus, + _keepAlive ?? keepAlive); + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/TestableEventBus.cs b/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/TestableEventBus.cs new file mode 100644 index 0000000000..4bc2465476 --- /dev/null +++ b/test/Microsoft.AspNetCore.Razor.Tools.Test/Infrastructure/TestableEventBus.cs @@ -0,0 +1,53 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + internal class TestableEventBus : EventBus + { + public int ListeningCount; + public int ConnectionCount; + public int CompletedCount; + public DateTime? LastProcessedTime; + public TimeSpan? KeepAlive; + public bool HasDetectedBadConnection; + public bool HitKeepAliveTimeout; + public event EventHandler Listening; + + public override void ConnectionListening() + { + ListeningCount++; + Listening?.Invoke(this, EventArgs.Empty); + } + + public override void ConnectionReceived() + { + ConnectionCount++; + } + + public override void ConnectionCompleted(int count) + { + CompletedCount += count; + LastProcessedTime = DateTime.Now; + } + + public override void UpdateKeepAlive(TimeSpan timeSpan) + { + KeepAlive = timeSpan; + } + + public override void ConnectionRudelyEnded() + { + HasDetectedBadConnection = true; + } + + public override void KeepAliveReached() + { + HitKeepAliveTimeout = true; + } + } +} diff --git a/test/Microsoft.AspNetCore.Razor.Tools.Test/Microsoft.AspNetCore.Razor.Tools.Test.csproj b/test/Microsoft.AspNetCore.Razor.Tools.Test/Microsoft.AspNetCore.Razor.Tools.Test.csproj new file mode 100644 index 0000000000..d2fdc4f4a1 --- /dev/null +++ b/test/Microsoft.AspNetCore.Razor.Tools.Test/Microsoft.AspNetCore.Razor.Tools.Test.csproj @@ -0,0 +1,18 @@ + + + + netcoreapp2.0 + + + + + + + + + + + + + + diff --git a/test/Microsoft.AspNetCore.Razor.Tools.Test/ServerLifecycleTest.cs b/test/Microsoft.AspNetCore.Razor.Tools.Test/ServerLifecycleTest.cs new file mode 100644 index 0000000000..e617299f76 --- /dev/null +++ b/test/Microsoft.AspNetCore.Razor.Tools.Test/ServerLifecycleTest.cs @@ -0,0 +1,260 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + public class ServerLifecycleTest + { + private static ServerRequest EmptyServerRequest => new ServerRequest(1, Array.Empty()); + + private static ServerResponse EmptyServerResponse => new CompletedServerResponse( + returnCode: 0, + utf8output: false, + output: string.Empty); + + [Fact] + public void ServerStartup_MutexAlreadyAcquired_Fails() + { + // Arrange + var pipeName = Guid.NewGuid().ToString("N"); + var mutexName = MutexName.GetServerMutexName(pipeName); + var compilerHost = new Mock(MockBehavior.Strict); + var host = new Mock(MockBehavior.Strict); + + // Act & Assert + using (var mutex = new Mutex(initiallyOwned: true, name: mutexName, createdNew: out var holdsMutex)) + { + Assert.True(holdsMutex); + try + { + var result = ServerUtilities.RunServer(pipeName, host.Object, compilerHost.Object); + + // Assert failure + Assert.Equal(1, result); + } + finally + { + mutex.ReleaseMutex(); + } + } + } + + [Fact] + public void ServerStartup_SuccessfullyAcquiredMutex() + { + // Arrange, Act & Assert + var pipeName = Guid.NewGuid().ToString("N"); + var mutexName = MutexName.GetServerMutexName(pipeName); + var compilerHost = new Mock(MockBehavior.Strict); + var host = new Mock(MockBehavior.Strict); + host + .Setup(x => x.WaitForConnectionAsync(It.IsAny())) + .Returns(() => + { + // Use a thread instead of Task to guarantee this code runs on a different + // thread and we can validate the mutex state. + var source = new TaskCompletionSource(); + var thread = new Thread(_ => + { + Mutex mutex = null; + try + { + Assert.True(Mutex.TryOpenExisting(mutexName, out mutex)); + Assert.False(mutex.WaitOne(millisecondsTimeout: 0)); + source.SetResult(true); + } + catch (Exception ex) + { + source.SetException(ex); + throw; + } + finally + { + mutex?.Dispose(); + } + }); + + // Synchronously wait here. Don't returned a Task value because we need to + // ensure the above check completes before the server hits a timeout and + // releases the mutex. + thread.Start(); + source.Task.Wait(); + + return new TaskCompletionSource().Task; + }); + + var result = ServerUtilities.RunServer(pipeName, host.Object, compilerHost.Object, keepAlive: TimeSpan.FromSeconds(1)); + Assert.Equal(0, result); + } + + [Fact] + public async Task ServerRunning_ShutdownRequest_processesSuccessfully() + { + // Arrange + using (var serverData = ServerUtilities.CreateServer()) + { + // Act + var serverProcessId = await ServerUtilities.SendShutdown(serverData.PipeName); + + // Assert + Assert.Equal(Process.GetCurrentProcess().Id, serverProcessId); + await serverData.Verify(connections: 1, completed: 1); + } + } + + /// + /// A shutdown request should not abort an existing compilation. It should be allowed to run to + /// completion. + /// + [Fact] + public async Task ServerRunning_ShutdownRequest_DoesNotAbortCompilation() + { + // Arrange + var completionSource = new TaskCompletionSource(); + var host = CreateCompilerHost(c => c.ExecuteFunc = (req, ct) => + { + // We want this to keep running even after the shutdown is seen. + completionSource.Task.Wait(); + return EmptyServerResponse; + }); + + using (var serverData = ServerUtilities.CreateServer(compilerHost: host)) + { + var compileTask = ServerUtilities.Send(serverData.PipeName, EmptyServerRequest); + + // Act + // The compilation is now in progress, send the shutdown. + await ServerUtilities.SendShutdown(serverData.PipeName); + Assert.False(compileTask.IsCompleted); + + // Now let the task complete. + completionSource.SetResult(true); + + // Assert + var response = await compileTask; + Assert.Equal(ServerResponse.ResponseType.Completed, response.Type); + Assert.Equal(0, ((CompletedServerResponse)response).ReturnCode); + + await serverData.Verify(connections: 2, completed: 2); + } + } + + /// + /// Multiple clients should be able to send shutdown requests to the server. + /// + [Fact] + public async Task ServerRunning_MultipleShutdownRequests_HandlesSuccessfully() + { + // Arrange + var completionSource = new TaskCompletionSource(); + var host = CreateCompilerHost(c => c.ExecuteFunc = (req, ct) => + { + // We want this to keep running even after the shutdown is seen. + completionSource.Task.Wait(); + return EmptyServerResponse; + }); + + using (var serverData = ServerUtilities.CreateServer(compilerHost: host)) + { + var compileTask = ServerUtilities.Send(serverData.PipeName, EmptyServerRequest); + + // Act + for (var i = 0; i < 10; i++) + { + // The compilation is now in progress, send the shutdown. + var processId = await ServerUtilities.SendShutdown(serverData.PipeName); + Assert.Equal(Process.GetCurrentProcess().Id, processId); + Assert.False(compileTask.IsCompleted); + } + + // Now let the task complete. + completionSource.SetResult(true); + + // Assert + var response = await compileTask; + Assert.Equal(ServerResponse.ResponseType.Completed, response.Type); + Assert.Equal(0, ((CompletedServerResponse)response).ReturnCode); + + await serverData.Verify(connections: 11, completed: 11); + } + } + + [Fact] + public async Task ServerRunning_CancelCompilation_CancelsSuccessfully() + { + // Arrange + const int requestCount = 5; + var count = 0; + var completionSource = new TaskCompletionSource(); + var host = CreateCompilerHost(c => c.ExecuteFunc = (req, ct) => + { + if (Interlocked.Increment(ref count) == requestCount) + { + completionSource.SetResult(true); + } + + ct.WaitHandle.WaitOne(); + return new RejectedServerResponse(); + }); + + using (var serverData = ServerUtilities.CreateServer(compilerHost: host)) + { + var tasks = new List>(); + for (var i = 0; i < requestCount; i++) + { + var task = ServerUtilities.Send(serverData.PipeName, EmptyServerRequest); + tasks.Add(task); + } + + // Act + // Wait until all of the connections are being processed by the server. + completionSource.Task.Wait(); + + // Now cancel + var stats = await serverData.CancelAndCompleteAsync(); + + // Assert + Assert.Equal(requestCount, stats.Connections); + Assert.Equal(requestCount, count); + + foreach (var task in tasks) + { + // We expect this to throw because the stream is already closed. + await Assert.ThrowsAsync(() => task); + } + } + } + + private static TestableCompilerHost CreateCompilerHost(Action configureCompilerHost = null) + { + var compilerHost = new TestableCompilerHost(); + configureCompilerHost?.Invoke(compilerHost); + + return compilerHost; + } + + private class TestableCompilerHost : CompilerHost + { + internal Func ExecuteFunc; + + public override ServerResponse Execute(ServerRequest request, CancellationToken cancellationToken) + { + if (ExecuteFunc != null) + { + return ExecuteFunc(request, cancellationToken); + } + + return EmptyServerResponse; + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Razor.Tools.Test/ServerProtocol/ServerProtocolTest.cs b/test/Microsoft.AspNetCore.Razor.Tools.Test/ServerProtocol/ServerProtocolTest.cs new file mode 100644 index 0000000000..9ce69a6f54 --- /dev/null +++ b/test/Microsoft.AspNetCore.Razor.Tools.Test/ServerProtocol/ServerProtocolTest.cs @@ -0,0 +1,128 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Immutable; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNetCore.Razor.Tools +{ + public class ServerProtocolTest + { + [Fact] + public async Task ServerResponse_WriteRead_RoundtripsProperly() + { + // Arrange + var response = new CompletedServerResponse(42, utf8output: false, output: "a string"); + var memoryStream = new MemoryStream(); + + // Act + await response.WriteAsync(memoryStream, CancellationToken.None); + + // Assert + Assert.True(memoryStream.Position > 0); + memoryStream.Position = 0; + var result = (CompletedServerResponse)await ServerResponse.ReadAsync(memoryStream, CancellationToken.None); + Assert.Equal(42, result.ReturnCode); + Assert.False(result.Utf8Output); + Assert.Equal("a string", result.Output); + Assert.Equal("", result.ErrorOutput); + } + + [Fact] + public async Task ServerRequest_WriteRead_RoundtripsProperly() + { + // Arrange + var request = new ServerRequest( + ServerProtocol.ProtocolVersion, + ImmutableArray.Create( + new RequestArgument(RequestArgument.ArgumentId.CurrentDirectory, argumentIndex: 0, value: "directory"), + new RequestArgument(RequestArgument.ArgumentId.CommandLineArgument, argumentIndex: 1, value: "file"))); + var memoryStream = new MemoryStream(); + + // Act + await request.WriteAsync(memoryStream, CancellationToken.None); + + // Assert + Assert.True(memoryStream.Position > 0); + memoryStream.Position = 0; + var read = await ServerRequest.ReadAsync(memoryStream, CancellationToken.None); + Assert.Equal(ServerProtocol.ProtocolVersion, read.ProtocolVersion); + Assert.Equal(2, read.Arguments.Count); + Assert.Equal(RequestArgument.ArgumentId.CurrentDirectory, read.Arguments[0].Id); + Assert.Equal(0, read.Arguments[0].ArgumentIndex); + Assert.Equal("directory", read.Arguments[0].Value); + Assert.Equal(RequestArgument.ArgumentId.CommandLineArgument, read.Arguments[1].Id); + Assert.Equal(1, read.Arguments[1].ArgumentIndex); + Assert.Equal("file", read.Arguments[1].Value); + } + + [Fact] + public void CreateShutdown_CreatesCorrectShutdownRequest() + { + // Arrange & Act + var request = ServerRequest.CreateShutdown(); + + // Assert + Assert.Equal(2, request.Arguments.Count); + + var argument1 = request.Arguments[0]; + Assert.Equal(RequestArgument.ArgumentId.Shutdown, argument1.Id); + Assert.Equal(0, argument1.ArgumentIndex); + Assert.Equal("", argument1.Value); + + var argument2 = request.Arguments[1]; + Assert.Equal(RequestArgument.ArgumentId.CommandLineArgument, argument2.Id); + Assert.Equal(1, argument2.ArgumentIndex); + Assert.Equal("shutdown", argument2.Value); + } + + [Fact] + public async Task ShutdownRequest_WriteRead_RoundtripsProperly() + { + // Arrange + var memoryStream = new MemoryStream(); + var request = ServerRequest.CreateShutdown(); + + // Act + await request.WriteAsync(memoryStream, CancellationToken.None); + + // Assert + memoryStream.Position = 0; + var read = await ServerRequest.ReadAsync(memoryStream, CancellationToken.None); + + var argument1 = request.Arguments[0]; + Assert.Equal(RequestArgument.ArgumentId.Shutdown, argument1.Id); + Assert.Equal(0, argument1.ArgumentIndex); + Assert.Equal("", argument1.Value); + + var argument2 = request.Arguments[1]; + Assert.Equal(RequestArgument.ArgumentId.CommandLineArgument, argument2.Id); + Assert.Equal(1, argument2.ArgumentIndex); + Assert.Equal("shutdown", argument2.Value); + } + + [Fact] + public async Task ShutdownResponse_WriteRead_RoundtripsProperly() + { + // Arrange & Act 1 + var memoryStream = new MemoryStream(); + var response = new ShutdownServerResponse(42); + + // Assert 1 + Assert.Equal(ServerResponse.ResponseType.Shutdown, response.Type); + + // Act 2 + await response.WriteAsync(memoryStream, CancellationToken.None); + + // Assert 2 + memoryStream.Position = 0; + var read = await ServerResponse.ReadAsync(memoryStream, CancellationToken.None); + Assert.Equal(ServerResponse.ResponseType.Shutdown, read.Type); + var typed = (ShutdownServerResponse)read; + Assert.Equal(42, typed.ServerProcessId); + } + } +}