diff --git a/src/Scrutor/ServiceCollectionExtensions.Decoration.cs b/src/Scrutor/ServiceCollectionExtensions.Decoration.cs index 83c5dee7..bc34c450 100644 --- a/src/Scrutor/ServiceCollectionExtensions.Decoration.cs +++ b/src/Scrutor/ServiceCollectionExtensions.Decoration.cs @@ -78,7 +78,8 @@ public static bool TryDecorate(this IServiceCollection services, Type serviceTyp if (serviceType.IsOpenGeneric() && decoratorType.IsOpenGeneric()) { - return services.TryDecorateOpenGeneric(serviceType, decoratorType); + var openTypeTryDecorator = OpenTypeTryDecorator(services, serviceType, decoratorType); + return services.TryDecorateOpenGeneric(serviceType, openTypeTryDecorator); } return services.TryDecorateDescriptors(serviceType, x => x.Decorate(decoratorType)); @@ -170,6 +171,11 @@ public static IServiceCollection Decorate(this IServiceCollection services, Type Preconditions.NotNull(serviceType, nameof(serviceType)); Preconditions.NotNull(decorator, nameof(decorator)); + if (serviceType.IsOpenGeneric()) + { + return services.DecorateOpenGeneric(serviceType, decorator); + } + return services.DecorateDescriptors(serviceType, x => x.Decorate(decorator)); } @@ -230,7 +236,8 @@ public static bool TryDecorate(this IServiceCollection services, Type serviceTyp private static IServiceCollection DecorateOpenGeneric(this IServiceCollection services, Type serviceType, Type decoratorType) { - if (services.TryDecorateOpenGeneric(serviceType, decoratorType)) + var openTypeTryDecorator = OpenTypeTryDecorator(services, serviceType, decoratorType); + if (services.TryDecorateOpenGeneric(serviceType, openTypeTryDecorator)) { return services; } @@ -243,16 +250,19 @@ private static bool IsSameGenericType(Type t1, Type t2) return t1.IsGenericType && t2.IsGenericType && t1.GetGenericTypeDefinition() == t2.GetGenericTypeDefinition(); } - private static bool TryDecorateOpenGeneric(this IServiceCollection services, Type serviceType, Type decoratorType) + private static IServiceCollection DecorateOpenGeneric(this IServiceCollection services, Type serviceType, Func decorator) { - bool TryDecorate(Type[] typeArguments) + var openTypeTryDecorator = OpenTypeTryDecorator(services, serviceType, decorator); + if (services.TryDecorateOpenGeneric(serviceType, openTypeTryDecorator)) { - var closedServiceType = serviceType.MakeGenericType(typeArguments); - var closedDecoratorType = decoratorType.MakeGenericType(typeArguments); - - return services.TryDecorateDescriptors(closedServiceType, x => x.Decorate(closedDecoratorType)); + return services; } + throw new MissingTypeRegistrationException(serviceType); + } + + private static bool TryDecorateOpenGeneric(this IServiceCollection services, Type serviceType, Func openTypeTryDecorator) + { var arguments = services .Where(descriptor => IsSameGenericType(descriptor.ServiceType, serviceType)) .Select(descriptor => descriptor.ServiceType.GenericTypeArguments) @@ -263,7 +273,27 @@ bool TryDecorate(Type[] typeArguments) return false; } - return arguments.Aggregate(true, (result, args) => result && TryDecorate(args)); + return arguments.Aggregate(true, (result, args) => result && openTypeTryDecorator(args)); + } + + private static Func OpenTypeTryDecorator(IServiceCollection services, Type serviceType, Type decoratorType) + { + return typeArguments => + { + var closedServiceType = serviceType.MakeGenericType(typeArguments); + var closedDecoratorType = decoratorType.MakeGenericType(typeArguments); + + return services.TryDecorateDescriptors(closedServiceType, x => x.Decorate(closedDecoratorType)); + }; + } + + private static Func OpenTypeTryDecorator(IServiceCollection services, Type serviceType, Func decorator) + { + return typeArguments => + { + var closedServiceType = serviceType.MakeGenericType(typeArguments); + return services.TryDecorateDescriptors(closedServiceType, x => x.Decorate(decorator)); + }; } private static IServiceCollection DecorateDescriptors(this IServiceCollection services, Type serviceType, Func decorator) diff --git a/test/Scrutor.Tests/OpenGenericDecorationTests.cs b/test/Scrutor.Tests/OpenGenericDecorationTests.cs index c090766f..98e8e5ee 100644 --- a/test/Scrutor.Tests/OpenGenericDecorationTests.cs +++ b/test/Scrutor.Tests/OpenGenericDecorationTests.cs @@ -41,6 +41,28 @@ public void CanDecorateOpenGenericTypeBasedOnInterface() Assert.IsType(loggingDecorator.Inner); } + [Fact] + public void CanDecorateOpenGenericTypeBasedOnInterfaceByDecoratorFunc() + { + var provider = ConfigureProvider(services => + { + services.AddSingleton, MySpecialQueryHandler>(); + services.Decorate(typeof(IQueryHandler<,>), (handlerObj, serviceProvider) => + { + if (handlerObj is ISpecialInterface specialInterface) + { + specialInterface.InitSomeField(); + } + + return handlerObj; + }); + }); + + var instance = provider.GetRequiredService>(); + var myQueryHandler = Assert.IsType(instance); + Assert.True(myQueryHandler.GetSomeField()); + } + [Fact] public void DecoratingNonRegisteredOpenGenericServiceThrows() { @@ -79,6 +101,22 @@ public void DecoratingOpenGenericTypeBasedOnGrandparentInterfaceDoesNotDecorateP } } + public interface ISpecialInterface + { + void InitSomeField(); + } + + public class MySpecialQueryHandler : QueryHandler, ISpecialInterface + { + private bool _someField = false; + public void InitSomeField() + { + _someField = true; + } + + public bool GetSomeField() => _someField; + } + public class MyQuery { } public class MyResult { }