diff --git a/src/Scrutor/RegistrationStrategy.cs b/src/Scrutor/RegistrationStrategy.cs
index 48943849..ab5ac286 100644
--- a/src/Scrutor/RegistrationStrategy.cs
+++ b/src/Scrutor/RegistrationStrategy.cs
@@ -6,10 +6,15 @@ namespace Scrutor;
public abstract class RegistrationStrategy
{
///
- /// Skips registrations for services that already exists.
+ /// Appends a new registration when no registration exists for the same Service type.
///
public static readonly RegistrationStrategy Skip = new SkipRegistrationStrategy();
+ ///
+ /// Appends a new registration when no registration exists for the same Service and Implementation type.
+ ///
+ public static readonly RegistrationStrategy Distinct = new DistinctRegistrationStrategy();
+
///
/// Appends a new registration for existing services.
///
@@ -49,6 +54,26 @@ private sealed class SkipRegistrationStrategy : RegistrationStrategy
public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) => services.TryAdd(descriptor);
}
+ private sealed class DistinctRegistrationStrategy : RegistrationStrategy {
+ ///
+ /// Adds the service descriptor if the service collection does not contain a desriptor with the same Service and Implementation type.
+ ///
+ /// The service collection.
+ /// The descriptor to apply.
+ ///
+ /// Unable to use
+ /// TryAddEnumerable()
+ /// since it would throw an ArgumentException when used with AsSelf().
+ ///
+ public override void Apply(IServiceCollection services, ServiceDescriptor descriptor)
+ {
+ if (services.HasRegistration(descriptor)) {
+ return;
+ }
+ services.Add(descriptor);
+ }
+ }
+
private sealed class AppendRegistrationStrategy : RegistrationStrategy
{
public override void Apply(IServiceCollection services, ServiceDescriptor descriptor) => services.Add(descriptor);
@@ -85,24 +110,38 @@ public override void Apply(IServiceCollection services, ServiceDescriptor descri
behavior = ReplacementBehavior.ServiceType;
}
- if (behavior.HasFlag(ReplacementBehavior.ServiceType))
+ if (behavior == ReplacementBehavior.Both)
{
+ var implementationType = descriptor.GetImplementationType();
for (var i = services.Count - 1; i >= 0; i--)
{
- if (services[i].ServiceType == descriptor.ServiceType)
+ if (services[i].ServiceType == descriptor.ServiceType && services[i].GetImplementationType() == implementationType)
{
services.RemoveAt(i);
}
}
}
+ else {
+ if (behavior.HasFlag(ReplacementBehavior.ServiceType))
+ {
+ for (var i = services.Count - 1; i >= 0; i--)
+ {
+ if (services[i].ServiceType == descriptor.ServiceType)
+ {
+ services.RemoveAt(i);
+ }
+ }
+ }
- if (behavior.HasFlag(ReplacementBehavior.ImplementationType))
- {
- for (var i = services.Count - 1; i >= 0; i--)
+ if (behavior.HasFlag(ReplacementBehavior.ImplementationType))
{
- if (services[i].ImplementationType == descriptor.ImplementationType)
+ var implementationType = descriptor.GetImplementationType();
+ for (var i = services.Count - 1; i >= 0; i--)
{
- services.RemoveAt(i);
+ if (services[i].GetImplementationType() == implementationType)
+ {
+ services.RemoveAt(i);
+ }
}
}
}
@@ -110,4 +149,4 @@ public override void Apply(IServiceCollection services, ServiceDescriptor descri
services.Add(descriptor);
}
}
-}
\ No newline at end of file
+}
diff --git a/src/Scrutor/ReplacementBehavior.cs b/src/Scrutor/ReplacementBehavior.cs
index 23aaff4f..8db62725 100644
--- a/src/Scrutor/ReplacementBehavior.cs
+++ b/src/Scrutor/ReplacementBehavior.cs
@@ -21,7 +21,12 @@ public enum ReplacementBehavior
ImplementationType = 2,
///
- /// Replace existing services by either service- or implementation type.
+ /// Replace existing services with the same service or implementation type.
///
- All = ServiceType | ImplementationType
-}
\ No newline at end of file
+ Either = ServiceType | ImplementationType,
+
+ ///
+ /// Replace existing services with the same service and implementation type.
+ ///
+ Both = 4
+}
diff --git a/src/Scrutor/ServiceCollectionExtensions.cs b/src/Scrutor/ServiceCollectionExtensions.cs
index d4002c61..638ec415 100644
--- a/src/Scrutor/ServiceCollectionExtensions.cs
+++ b/src/Scrutor/ServiceCollectionExtensions.cs
@@ -10,4 +10,16 @@ public static bool HasRegistration(this IServiceCollection services, Type servic
{
return services.Any(x => x.ServiceType == serviceType);
}
-}
\ No newline at end of file
+
+ ///
+ /// Determines whether the service collection has a descriptor with the same Service and Implementation types.
+ ///
+ /// The service collection.
+ /// The service descriptor.
+ /// true if the service collection contains the specified service descriptor; otherwise, false.
+ public static bool HasRegistration(this IServiceCollection services, ServiceDescriptor descriptor)
+ {
+ var implementationType = descriptor.GetImplementationType();
+ return services.Any(x => x.ServiceType == descriptor.ServiceType && x.ImplementationType == implementationType);
+ }
+}
diff --git a/src/Scrutor/ServiceDescriptorExtensions.cs b/src/Scrutor/ServiceDescriptorExtensions.cs
index f734fe1d..2d433536 100644
--- a/src/Scrutor/ServiceDescriptorExtensions.cs
+++ b/src/Scrutor/ServiceDescriptorExtensions.cs
@@ -15,4 +15,37 @@ public static ServiceDescriptor WithImplementationFactory(this ServiceDescriptor
{ ImplementationInstance: not null } => new ServiceDescriptor(serviceType, descriptor.ImplementationInstance),
_ => throw new ArgumentException($"No implementation factory or instance or type found for {descriptor.ServiceType}.", nameof(descriptor))
};
+
+ ///
+ /// Gets the service descriptor's implementation type.
+ ///
+ /// The service descriptor.
+ /// System.Type?.
+ ///
+ /// Mostly replicates ServiceDescriptor.GetImplementationType()
+ ///
+ public static Type? GetImplementationType(this ServiceDescriptor descriptor)
+ {
+ if (descriptor.ImplementationType != null)
+ {
+ return descriptor.ImplementationType;
+ }
+ else if (descriptor.ImplementationInstance != null)
+ {
+ return descriptor.ImplementationInstance.GetType();
+ }
+ else if (descriptor.ImplementationFactory != null)
+ {
+ Type[]? typeArguments = descriptor.ImplementationFactory.GetType().GenericTypeArguments;
+ if (typeArguments[1] == typeof(object))
+ {
+ return descriptor.ImplementationFactory.Method.ReturnType;
+ }
+ else
+ {
+ return typeArguments[1];
+ }
+ }
+ return null;
+ }
}
diff --git a/test/Scrutor.Tests/ScanningTests.cs b/test/Scrutor.Tests/ScanningTests.cs
index 541f198f..b77a3523 100644
--- a/test/Scrutor.Tests/ScanningTests.cs
+++ b/test/Scrutor.Tests/ScanningTests.cs
@@ -48,16 +48,69 @@ public void UsingRegistrationStrategy_None()
}
[Fact]
- public void UsingRegistrationStrategy_SkipIfExists()
+ public void UsingRegistrationStrategy_Skip()
{
Collection.Scan(scan => scan
.FromAssemblyOf()
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Skip)
+ .AsImplementedInterfaces()
+ .WithTransientLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(1, services.Count(x => x.ServiceType == typeof(ITransientService)));
+ }
+
+ [Fact]
+ public void UsingRegistrationStrategy_SkipAfterNone()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ // registers 4
.AddClasses(classes => classes.AssignableTo())
.AsImplementedInterfaces()
.WithTransientLifetime()
+ // no new registrations
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Skip)
+ .AsImplementedInterfaces()
+ .WithSingletonLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService)));
+ }
+
+ [Fact]
+ public void UsingRegistrationStrategy_Distinct()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Distinct)
+ .AsImplementedInterfaces()
+ .WithTransientLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService)));
+ }
+
+ [Fact]
+ public void UsingRegistrationStrategy_DistinctAfterSkip()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ // registers 1
.AddClasses(classes => classes.AssignableTo())
.UsingRegistrationStrategy(RegistrationStrategy.Skip)
.AsImplementedInterfaces()
+ .WithTransientLifetime()
+ // registers the other three
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Distinct)
+ .AsImplementedInterfaces()
.WithSingletonLifetime());
var services = Collection.GetDescriptors();
@@ -65,6 +118,46 @@ public void UsingRegistrationStrategy_SkipIfExists()
Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService)));
}
+ [Fact]
+ public void UsingRegistrationStrategy_DistinctAfterNone()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ // register 4
+ .AddClasses(classes => classes.AssignableTo())
+ .AsImplementedInterfaces()
+ .WithTransientLifetime()
+ // no new registrations
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Distinct)
+ .AsImplementedInterfaces()
+ .WithSingletonLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(4, services.Count(x => x.ServiceType == typeof(ITransientService)));
+ }
+
+ [Fact]
+ public void UsingRegistrationStrategy_DistinctWithSelf()
+ {
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ // registers 9
+ .AddClasses(classes => classes.AssignableTo())
+ .AsImplementedInterfaces()
+ .AsSelf()
+ .WithTransientLifetime()
+ // no new registrations, and does not throw due to not using TryAddEnumerable() with AsSelf()
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Distinct)
+ .AsImplementedInterfaces()
+ .AsSelf()
+ .WithTransientLifetime());
+
+ Assert.Equal(9, Collection.Count);
+ }
+
[Fact]
public void UsingRegistrationStrategy_ReplaceDefault()
{
@@ -119,6 +212,52 @@ public void UsingRegistrationStrategy_ReplaceImplementationTypes()
Assert.Equal(3, services.Count(x => x.ServiceType == typeof(ITransientService)));
}
+ [Theory]
+ [InlineData(ReplacementBehavior.ServiceType)]
+ [InlineData(ReplacementBehavior.ImplementationType)]
+ [InlineData(ReplacementBehavior.Both)]
+ [InlineData(ReplacementBehavior.Either)]
+ public void UsingRegistrationStrategy_Replace_ReplacesInstances(ReplacementBehavior behavior)
+ {
+ var instanceToReplace = new Replacement1();
+ Collection.Add(new(typeof(IReplacement), instanceToReplace));
+
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Replace(behavior))
+ .AsImplementedInterfaces()
+ .WithSingletonLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(1, services.Count(x => x.ServiceType == typeof(IReplacement)));
+ Assert.Equal(0, services.Count(x => x.ImplementationInstance == instanceToReplace));
+ }
+
+ [Theory]
+ [InlineData(ReplacementBehavior.ServiceType)]
+ [InlineData(ReplacementBehavior.ImplementationType)]
+ [InlineData(ReplacementBehavior.Both)]
+ [InlineData(ReplacementBehavior.Either)]
+ public void UsingRegistrationStrategy_Replace_ReplacesFactories(ReplacementBehavior behavior)
+ {
+ Replacement1 factory(IServiceProvider _) => new();
+ Collection.Add(new(typeof(IReplacement), factory, ServiceLifetime.Transient));
+
+ Collection.Scan(scan => scan
+ .FromAssemblyOf()
+ .AddClasses(classes => classes.AssignableTo())
+ .UsingRegistrationStrategy(RegistrationStrategy.Replace(behavior))
+ .AsImplementedInterfaces()
+ .WithSingletonLifetime());
+
+ var services = Collection.GetDescriptors();
+
+ Assert.Equal(1, services.Count(x => x.ServiceType == typeof(IReplacement)));
+ Assert.Equal(0, services.Count(x => x.ImplementationFactory is not null));
+ }
+
[Fact]
public void UsingRegistrationStrategy_Throw()
{
@@ -614,6 +753,10 @@ public interface IMixedAttribute { }
[ServiceDescriptor(typeof(IMixedAttribute), ServiceLifetime.Scoped)]
[ServiceDescriptor(ServiceLifetime.Singleton)]
public class MixedAttribute : IMixedAttribute { }
+
+ public interface IReplacement { }
+
+ public class Replacement1 : IReplacement { }
}
namespace Scrutor.Tests.ChildNamespace