diff --git a/src/Scrutor/ServiceCollectionExtensions.Decoration.cs b/src/Scrutor/ServiceCollectionExtensions.Decoration.cs index 205c1fa8..d2dbdde8 100644 --- a/src/Scrutor/ServiceCollectionExtensions.Decoration.cs +++ b/src/Scrutor/ServiceCollectionExtensions.Decoration.cs @@ -21,7 +21,35 @@ public static IServiceCollection Decorate(this IServiceCol { Preconditions.NotNull(services, nameof(services)); - return services.DecorateDescriptors(typeof(TService), x => x.Decorate(typeof(TDecorator))); + if (typeof(TDecorator).IsInterface) + { + return DecorateUsingInterface(services); + } + else + { + return services.DecorateDescriptors(typeof(TService), x => x.Decorate(typeof(TDecorator))); + } + } + + private static IServiceCollection DecorateUsingInterface(IServiceCollection services) where TDecorator : TService + { + if (typeof(TDecorator).IsGenericType) + { + var decoratorDescriptor = services.Where(service => HasSameTypeDefinition(service.ServiceType, typeof(TDecorator))).FirstOrDefault(); + if (decoratorDescriptor == null) + throw new MissingTypeRegistrationException(typeof(TDecorator).IsGenericType ? typeof(TDecorator).GetGenericTypeDefinition() : typeof(TDecorator)); + + return services.DecorateDescriptors(typeof(TService), x => x.Decorate(decoratorDescriptor.ImplementationType.MakeGenericType(typeof(TDecorator).GetGenericArguments().First()))); + } + else + { + var decoratorDescriptor = services.Where(service => service.ServiceType == typeof(TDecorator)).FirstOrDefault(); + if (decoratorDescriptor == null) + throw new MissingTypeRegistrationException(typeof(TDecorator).IsGenericType ? typeof(TDecorator).GetGenericTypeDefinition() : typeof(TDecorator)); + + return services.DecorateDescriptors(typeof(TService), x => x.Decorate(decoratorDescriptor.ImplementationType)); + } + } ///