From 1d6c09ab3182d325c9b2a356623ce8670e8310d4 Mon Sep 17 00:00:00 2001 From: Pranav K Date: Thu, 22 Mar 2018 16:09:03 -0700 Subject: [PATCH] Make the use of Assembly.CodeBase more robust --- .../RelatedAssemblyAttribute.cs | 23 ++++++-- .../RelatedAssemblyPartTest.cs | 58 ++++++++++++++----- 2 files changed, 64 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.AspNetCore.Mvc.Core/ApplicationParts/RelatedAssemblyAttribute.cs b/src/Microsoft.AspNetCore.Mvc.Core/ApplicationParts/RelatedAssemblyAttribute.cs index aa796cdb55..b9f6cdb2d7 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/ApplicationParts/RelatedAssemblyAttribute.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/ApplicationParts/RelatedAssemblyAttribute.cs @@ -50,10 +50,14 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts throw new ArgumentNullException(nameof(assembly)); } - return GetRelatedAssemblies(assembly, throwOnError, AssemblyLoadFileDelegate); + return GetRelatedAssemblies(assembly, throwOnError, File.Exists, AssemblyLoadFileDelegate); } - internal static IReadOnlyList GetRelatedAssemblies(Assembly assembly, bool throwOnError, Func loadFile) + internal static IReadOnlyList GetRelatedAssemblies( + Assembly assembly, + bool throwOnError, + Func fileExists, + Func loadFile) { if (assembly == null) { @@ -74,7 +78,8 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts } var assemblyName = assembly.GetName().Name; - var assemblyDirectory = Path.GetDirectoryName(assembly.CodeBase); + var assemblyLocation = GetAssemblyLocation(assembly); + var assemblyDirectory = Path.GetDirectoryName(assemblyLocation); var relatedAssemblies = new List(); for (var i = 0; i < attributes.Length; i++) @@ -87,7 +92,7 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts } var relatedAssemblyLocation = Path.Combine(assemblyDirectory, attribute.AssemblyFileName + ".dll"); - if (!File.Exists(relatedAssemblyLocation)) + if (!fileExists(relatedAssemblyLocation)) { if (throwOnError) { @@ -107,5 +112,15 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts return relatedAssemblies; } + + internal static string GetAssemblyLocation(Assembly assembly) + { + if (Uri.TryCreate(assembly.CodeBase, UriKind.Absolute, out var result) && result.IsFile) + { + return result.LocalPath; + } + + return assembly.Location; + } } } diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/ApplicationParts/RelatedAssemblyPartTest.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/ApplicationParts/RelatedAssemblyPartTest.cs index 8f3bbdb47c..efce8dfaad 100644 --- a/test/Microsoft.AspNetCore.Mvc.Core.Test/ApplicationParts/RelatedAssemblyPartTest.cs +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/ApplicationParts/RelatedAssemblyPartTest.cs @@ -67,20 +67,46 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts }; var relatedAssembly = typeof(RelatedAssemblyPartTest).Assembly; - try + var result = RelatedAssemblyAttribute.GetRelatedAssemblies(assembly, throwOnError: true, file => true, file => { - File.WriteAllBytes(destination, new byte[0]); - var result = RelatedAssemblyAttribute.GetRelatedAssemblies(assembly, throwOnError: true, file => - { - Assert.Equal(file, destination); - return relatedAssembly; - }); - Assert.Equal(new[] { relatedAssembly }, result); - } - finally + Assert.Equal(file, destination); + return relatedAssembly; + }); + Assert.Equal(new[] { relatedAssembly }, result); + } + + [Fact] + public void GetAssemblyLocation_UsesCodeBase() + { + // Arrange + var destination = Path.Combine(AssemblyDirectory, "RelatedAssembly.dll"); + var codeBase = "file://x/file/Assembly.dll"; + var expected = new Uri(codeBase).LocalPath; + var assembly = new TestAssembly { - File.Delete(destination); - } + CodeBaseSettable = codeBase, + }; + + // Act + var actual = RelatedAssemblyAttribute.GetAssemblyLocation(assembly); + Assert.Equal(expected, actual); + } + + [Fact] + public void GetAssemblyLocation_UsesLocation_IfCodeBaseIsNotLocal() + { + // Arrange + var destination = Path.Combine(AssemblyDirectory, "RelatedAssembly.dll"); + var expected = Path.Combine(AssemblyDirectory, "Some-Dir", "Assembly.dll"); + var assembly = new TestAssembly + { + CodeBaseSettable = "https://www.microsoft.com/test.dll", + LocationSettable = expected, + }; + + // Act + var actual = RelatedAssemblyAttribute.GetAssemblyLocation(assembly); + Assert.Equal(expected, actual); } private class TestAssembly : Assembly @@ -92,7 +118,13 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts public string AttributeAssembly { get; set; } - public override string CodeBase => Path.Combine(AssemblyDirectory, "MyAssembly.dll"); + public string CodeBaseSettable { get; set; } = Path.Combine(AssemblyDirectory, "MyAssembly.dll"); + + public override string CodeBase => CodeBaseSettable; + + public string LocationSettable { get; set; } = Path.Combine(AssemblyDirectory, "MyAssembly.dll"); + + public override string Location => LocationSettable; public override object[] GetCustomAttributes(Type attributeType, bool inherit) {