Make the use of Assembly.CodeBase more robust

This commit is contained in:
Pranav K 2018-03-22 16:09:03 -07:00
parent 56501cb8a0
commit 1d6c09ab31
2 changed files with 64 additions and 17 deletions

View File

@ -50,10 +50,14 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts
throw new ArgumentNullException(nameof(assembly));
}
return GetRelatedAssemblies(assembly, throwOnError, AssemblyLoadFileDelegate);
return GetRelatedAssemblies(assembly, throwOnError, File.Exists, AssemblyLoadFileDelegate);
}
internal static IReadOnlyList<Assembly> GetRelatedAssemblies(Assembly assembly, bool throwOnError, Func<string, Assembly> loadFile)
internal static IReadOnlyList<Assembly> GetRelatedAssemblies(
Assembly assembly,
bool throwOnError,
Func<string, bool> fileExists,
Func<string, Assembly> loadFile)
{
if (assembly == null)
{
@ -74,7 +78,8 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts
}
var assemblyName = assembly.GetName().Name;
var assemblyDirectory = Path.GetDirectoryName(assembly.CodeBase);
var assemblyLocation = GetAssemblyLocation(assembly);
var assemblyDirectory = Path.GetDirectoryName(assemblyLocation);
var relatedAssemblies = new List<Assembly>();
for (var i = 0; i < attributes.Length; i++)
@ -87,7 +92,7 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts
}
var relatedAssemblyLocation = Path.Combine(assemblyDirectory, attribute.AssemblyFileName + ".dll");
if (!File.Exists(relatedAssemblyLocation))
if (!fileExists(relatedAssemblyLocation))
{
if (throwOnError)
{
@ -107,5 +112,15 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts
return relatedAssemblies;
}
internal static string GetAssemblyLocation(Assembly assembly)
{
if (Uri.TryCreate(assembly.CodeBase, UriKind.Absolute, out var result) && result.IsFile)
{
return result.LocalPath;
}
return assembly.Location;
}
}
}

View File

@ -67,20 +67,46 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts
};
var relatedAssembly = typeof(RelatedAssemblyPartTest).Assembly;
try
var result = RelatedAssemblyAttribute.GetRelatedAssemblies(assembly, throwOnError: true, file => true, file =>
{
File.WriteAllBytes(destination, new byte[0]);
var result = RelatedAssemblyAttribute.GetRelatedAssemblies(assembly, throwOnError: true, file =>
{
Assert.Equal(file, destination);
return relatedAssembly;
});
Assert.Equal(new[] { relatedAssembly }, result);
}
finally
Assert.Equal(file, destination);
return relatedAssembly;
});
Assert.Equal(new[] { relatedAssembly }, result);
}
[Fact]
public void GetAssemblyLocation_UsesCodeBase()
{
// Arrange
var destination = Path.Combine(AssemblyDirectory, "RelatedAssembly.dll");
var codeBase = "file://x/file/Assembly.dll";
var expected = new Uri(codeBase).LocalPath;
var assembly = new TestAssembly
{
File.Delete(destination);
}
CodeBaseSettable = codeBase,
};
// Act
var actual = RelatedAssemblyAttribute.GetAssemblyLocation(assembly);
Assert.Equal(expected, actual);
}
[Fact]
public void GetAssemblyLocation_UsesLocation_IfCodeBaseIsNotLocal()
{
// Arrange
var destination = Path.Combine(AssemblyDirectory, "RelatedAssembly.dll");
var expected = Path.Combine(AssemblyDirectory, "Some-Dir", "Assembly.dll");
var assembly = new TestAssembly
{
CodeBaseSettable = "https://www.microsoft.com/test.dll",
LocationSettable = expected,
};
// Act
var actual = RelatedAssemblyAttribute.GetAssemblyLocation(assembly);
Assert.Equal(expected, actual);
}
private class TestAssembly : Assembly
@ -92,7 +118,13 @@ namespace Microsoft.AspNetCore.Mvc.ApplicationParts
public string AttributeAssembly { get; set; }
public override string CodeBase => Path.Combine(AssemblyDirectory, "MyAssembly.dll");
public string CodeBaseSettable { get; set; } = Path.Combine(AssemblyDirectory, "MyAssembly.dll");
public override string CodeBase => CodeBaseSettable;
public string LocationSettable { get; set; } = Path.Combine(AssemblyDirectory, "MyAssembly.dll");
public override string Location => LocationSettable;
public override object[] GetCustomAttributes(Type attributeType, bool inherit)
{