diff --git a/src/Scrutor/RegistrationStrategy.cs b/src/Scrutor/RegistrationStrategy.cs index 48943849..ab5ac286 100644 --- a/src/Scrutor/RegistrationStrategy.cs +++ b/src/Scrutor/RegistrationStrategy.cs @@ -6,10 +6,15 @@ namespace Scrutor; public abstract class RegistrationStrategy { /// - /// Skips registrations for services that already exists. + /// Appends a new registration when no registration exists for the same Service type. /// public static readonly RegistrationStrategy Skip = new SkipRegistrationStrategy(); + /// + /// Appends a new registration when no registration exists for the same Service and Implementation type. + /// + public static readonly RegistrationStrategy Distinct = new DistinctRegistrationStrategy(); + /// /// Appends a new registration for existing services. /// @@ -49,6 +54,26 @@ private sealed class SkipRegistrationStrategy : RegistrationStrategy public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) => services.TryAdd(descriptor); } + private sealed class DistinctRegistrationStrategy : RegistrationStrategy { + /// + /// Adds the service descriptor if the service collection does not contain a desriptor with the same Service and Implementation type. + /// + /// The service collection. + /// The descriptor to apply. + /// + /// Unable to use + /// TryAddEnumerable() + /// since it would throw an ArgumentException when used with AsSelf(). + /// + public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) + { + if (services.HasRegistration(descriptor)) { + return; + } + services.Add(descriptor); + } + } + private sealed class AppendRegistrationStrategy : RegistrationStrategy { public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) => services.Add(descriptor); @@ -85,24 +110,38 @@ public override void Apply(IServiceCollection services, ServiceDescriptor descri behavior = ReplacementBehavior.ServiceType; } - if (behavior.HasFlag(ReplacementBehavior.ServiceType)) + if (behavior == ReplacementBehavior.Both) { + var implementationType = descriptor.GetImplementationType(); for (var i = services.Count - 1; i >= 0; i--) { - if (services[i].ServiceType == descriptor.ServiceType) + if (services[i].ServiceType == descriptor.ServiceType && services[i].GetImplementationType() == implementationType) { services.RemoveAt(i); } } } + else { + if (behavior.HasFlag(ReplacementBehavior.ServiceType)) + { + for (var i = services.Count - 1; i >= 0; i--) + { + if (services[i].ServiceType == descriptor.ServiceType) + { + services.RemoveAt(i); + } + } + } - if (behavior.HasFlag(ReplacementBehavior.ImplementationType)) - { - for (var i = services.Count - 1; i >= 0; i--) + if (behavior.HasFlag(ReplacementBehavior.ImplementationType)) { - if (services[i].ImplementationType == descriptor.ImplementationType) + var implementationType = descriptor.GetImplementationType(); + for (var i = services.Count - 1; i >= 0; i--) { - services.RemoveAt(i); + if (services[i].GetImplementationType() == implementationType) + { + services.RemoveAt(i); + } } } } @@ -110,4 +149,4 @@ public override void Apply(IServiceCollection services, ServiceDescriptor descri services.Add(descriptor); } } -} \ No newline at end of file +} diff --git a/src/Scrutor/ReplacementBehavior.cs b/src/Scrutor/ReplacementBehavior.cs index 23aaff4f..8db62725 100644 --- a/src/Scrutor/ReplacementBehavior.cs +++ b/src/Scrutor/ReplacementBehavior.cs @@ -21,7 +21,12 @@ public enum ReplacementBehavior ImplementationType = 2, /// - /// Replace existing services by either service- or implementation type. + /// Replace existing services with the same service or implementation type. /// - All = ServiceType | ImplementationType -} \ No newline at end of file + Either = ServiceType | ImplementationType, + + /// + /// Replace existing services with the same service and implementation type. + /// + Both = 4 +} diff --git a/src/Scrutor/ServiceCollectionExtensions.cs b/src/Scrutor/ServiceCollectionExtensions.cs index d4002c61..638ec415 100644 --- a/src/Scrutor/ServiceCollectionExtensions.cs +++ b/src/Scrutor/ServiceCollectionExtensions.cs @@ -10,4 +10,16 @@ public static bool HasRegistration(this IServiceCollection services, Type servic { return services.Any(x => x.ServiceType == serviceType); } -} \ No newline at end of file + + /// + /// Determines whether the service collection has a descriptor with the same Service and Implementation types. + /// + /// The service collection. + /// The service descriptor. + /// true if the service collection contains the specified service descriptor; otherwise, false. + public static bool HasRegistration(this IServiceCollection services, ServiceDescriptor descriptor) + { + var implementationType = descriptor.GetImplementationType(); + return services.Any(x => x.ServiceType == descriptor.ServiceType && x.ImplementationType == implementationType); + } +} diff --git a/src/Scrutor/ServiceDescriptorExtensions.cs b/src/Scrutor/ServiceDescriptorExtensions.cs index f734fe1d..2d433536 100644 --- a/src/Scrutor/ServiceDescriptorExtensions.cs +++ b/src/Scrutor/ServiceDescriptorExtensions.cs @@ -15,4 +15,37 @@ public static ServiceDescriptor WithImplementationFactory(this ServiceDescriptor { ImplementationInstance: not null } => new ServiceDescriptor(serviceType, descriptor.ImplementationInstance), _ => throw new ArgumentException($"No implementation factory or instance or type found for {descriptor.ServiceType}.", nameof(descriptor)) }; + + /// + /// Gets the service descriptor's implementation type. + /// + /// The service descriptor. + /// System.Type?. + /// + /// Mostly replicates ServiceDescriptor.GetImplementationType() + /// + public static Type? GetImplementationType(this ServiceDescriptor descriptor) + { + if (descriptor.ImplementationType != null) + { + return descriptor.ImplementationType; + } + else if (descriptor.ImplementationInstance != null) + { + return descriptor.ImplementationInstance.GetType(); + } + else if (descriptor.ImplementationFactory != null) + { + Type[]? typeArguments = descriptor.ImplementationFactory.GetType().GenericTypeArguments; + if (typeArguments[1] == typeof(object)) + { + return descriptor.ImplementationFactory.Method.ReturnType; + } + else + { + return typeArguments[1]; + } + } + return null; + } } diff --git a/test/Scrutor.Tests/ScanningTests.cs b/test/Scrutor.Tests/ScanningTests.cs index 541f198f..b77a3523 100644 --- a/test/Scrutor.Tests/ScanningTests.cs +++ b/test/Scrutor.Tests/ScanningTests.cs @@ -48,16 +48,69 @@ public void UsingRegistrationStrategy_None() } [Fact] - public void UsingRegistrationStrategy_SkipIfExists() + public void UsingRegistrationStrategy_Skip() { Collection.Scan(scan => scan .FromAssemblyOf() + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Skip) + .AsImplementedInterfaces() + .WithTransientLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(1, services.Count(x => x.ServiceType == typeof(ITransientService))); + } + + [Fact] + public void UsingRegistrationStrategy_SkipAfterNone() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + // registers 4 .AddClasses(classes => classes.AssignableTo()) .AsImplementedInterfaces() .WithTransientLifetime() + // no new registrations + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Skip) + .AsImplementedInterfaces() + .WithSingletonLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService))); + } + + [Fact] + public void UsingRegistrationStrategy_Distinct() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Distinct) + .AsImplementedInterfaces() + .WithTransientLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService))); + } + + [Fact] + public void UsingRegistrationStrategy_DistinctAfterSkip() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + // registers 1 .AddClasses(classes => classes.AssignableTo()) .UsingRegistrationStrategy(RegistrationStrategy.Skip) .AsImplementedInterfaces() + .WithTransientLifetime() + // registers the other three + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Distinct) + .AsImplementedInterfaces() .WithSingletonLifetime()); var services = Collection.GetDescriptors(); @@ -65,6 +118,46 @@ public void UsingRegistrationStrategy_SkipIfExists() Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService))); } + [Fact] + public void UsingRegistrationStrategy_DistinctAfterNone() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + // register 4 + .AddClasses(classes => classes.AssignableTo()) + .AsImplementedInterfaces() + .WithTransientLifetime() + // no new registrations + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Distinct) + .AsImplementedInterfaces() + .WithSingletonLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService))); + } + + [Fact] + public void UsingRegistrationStrategy_DistinctWithSelf() + { + Collection.Scan(scan => scan + .FromAssemblyOf() + // registers 9 + .AddClasses(classes => classes.AssignableTo()) + .AsImplementedInterfaces() + .AsSelf() + .WithTransientLifetime() + // no new registrations, and does not throw due to not using TryAddEnumerable() with AsSelf() + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Distinct) + .AsImplementedInterfaces() + .AsSelf() + .WithTransientLifetime()); + + Assert.Equal(9, Collection.Count); + } + [Fact] public void UsingRegistrationStrategy_ReplaceDefault() { @@ -119,6 +212,52 @@ public void UsingRegistrationStrategy_ReplaceImplementationTypes() Assert.Equal(3, services.Count(x => x.ServiceType == typeof(ITransientService))); } + [Theory] + [InlineData(ReplacementBehavior.ServiceType)] + [InlineData(ReplacementBehavior.ImplementationType)] + [InlineData(ReplacementBehavior.Both)] + [InlineData(ReplacementBehavior.Either)] + public void UsingRegistrationStrategy_Replace_ReplacesInstances(ReplacementBehavior behavior) + { + var instanceToReplace = new Replacement1(); + Collection.Add(new(typeof(IReplacement), instanceToReplace)); + + Collection.Scan(scan => scan + .FromAssemblyOf() + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Replace(behavior)) + .AsImplementedInterfaces() + .WithSingletonLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(1, services.Count(x => x.ServiceType == typeof(IReplacement))); + Assert.Equal(0, services.Count(x => x.ImplementationInstance == instanceToReplace)); + } + + [Theory] + [InlineData(ReplacementBehavior.ServiceType)] + [InlineData(ReplacementBehavior.ImplementationType)] + [InlineData(ReplacementBehavior.Both)] + [InlineData(ReplacementBehavior.Either)] + public void UsingRegistrationStrategy_Replace_ReplacesFactories(ReplacementBehavior behavior) + { + Replacement1 factory(IServiceProvider _) => new(); + Collection.Add(new(typeof(IReplacement), factory, ServiceLifetime.Transient)); + + Collection.Scan(scan => scan + .FromAssemblyOf() + .AddClasses(classes => classes.AssignableTo()) + .UsingRegistrationStrategy(RegistrationStrategy.Replace(behavior)) + .AsImplementedInterfaces() + .WithSingletonLifetime()); + + var services = Collection.GetDescriptors(); + + Assert.Equal(1, services.Count(x => x.ServiceType == typeof(IReplacement))); + Assert.Equal(0, services.Count(x => x.ImplementationFactory is not null)); + } + [Fact] public void UsingRegistrationStrategy_Throw() { @@ -614,6 +753,10 @@ public interface IMixedAttribute { } [ServiceDescriptor(typeof(IMixedAttribute), ServiceLifetime.Scoped)] [ServiceDescriptor(ServiceLifetime.Singleton)] public class MixedAttribute : IMixedAttribute { } + + public interface IReplacement { } + + public class Replacement1 : IReplacement { } } namespace Scrutor.Tests.ChildNamespace