From 14f0ee691e9aca8cd3932c2f66028d875630e81d Mon Sep 17 00:00:00 2001 From: Chris Gillum Date: Mon, 22 Dec 2025 13:50:21 -0800 Subject: [PATCH] Added overloads for simplified func-based activity and orchestrator registration --- misc/misc.csproj | 1 + src/Abstractions/DurableTaskAttribute.cs | 6 +- .../DurableTaskRegistry.Activities.cs | 176 ++++++++++++++++- .../DurableTaskRegistry.Orchestrators.cs | 181 +++++++++++++++++- src/Abstractions/Entities/TaskEntity.cs | 6 + .../DurableTaskRegistryTests.Activities.cs | 103 ++++++++++ .../DurableTaskRegistryTests.Orchestrators.cs | 103 ++++++++++ 7 files changed, 572 insertions(+), 4 deletions(-) diff --git a/misc/misc.csproj b/misc/misc.csproj index 0e70dd0f2..b2c1120e9 100644 --- a/misc/misc.csproj +++ b/misc/misc.csproj @@ -18,6 +18,7 @@ + diff --git a/src/Abstractions/DurableTaskAttribute.cs b/src/Abstractions/DurableTaskAttribute.cs index 1b7caa4f8..38cc18e68 100644 --- a/src/Abstractions/DurableTaskAttribute.cs +++ b/src/Abstractions/DurableTaskAttribute.cs @@ -4,16 +4,18 @@ namespace Microsoft.DurableTask; /// -/// Indicates that the attributed class represents a durable task. +/// Indicates that the attributed class or method represents a durable task. /// /// /// This attribute is meant to be used on class definitions that derive from /// , , /// or TaskEntity{TState} from the Microsoft.DurableTask.Entities namespace. +/// It can also be applied to methods used with +/// or similar overloads to specify a custom name for the orchestrator. /// It is used specifically by build-time source generators to generate type-safe methods for invoking /// orchestrations, activities, or registering entities. /// -[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)] +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = false)] public sealed class DurableTaskAttribute : Attribute { /// diff --git a/src/Abstractions/DurableTaskRegistry.Activities.cs b/src/Abstractions/DurableTaskRegistry.Activities.cs index ac525147a..8f559cdb5 100644 --- a/src/Abstractions/DurableTaskRegistry.Activities.cs +++ b/src/Abstractions/DurableTaskRegistry.Activities.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Reflection; using Microsoft.Extensions.DependencyInjection; namespace Microsoft.DurableTask; @@ -21,7 +22,17 @@ TaskName and TActivity generic parameter ITaskActivity singleton TaskName ITaskActivity singleton - by func/action: + by func/action (with explicit name): + Func{Context, Input, Task{Output}} + Func{Context, Input, Task} + Func{Context, Input, Output} + Func{Context, Task{Output}} + Func{Context, Task} + Func{Context, Output} + Action{Context, TInput} + Action{Context} + + by func/action (name inferred from method or [DurableTask] attribute): Func{Context, Input, Task{Output}} Func{Context, Input, Task} Func{Context, Input, Output} @@ -219,4 +230,167 @@ public DurableTaskRegistry AddActivityFunc(TaskName name, Action + /// Registers an activity factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The activity input type. + /// The activity output type. + /// The activity implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddActivityFunc( + Func> activity) + { + Check.NotNull(activity); + return this.AddActivityFunc(GetActivityNameFromDelegate(activity), activity); + } + + /// + /// Registers an activity factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The activity input type. + /// The activity output type. + /// The activity implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddActivityFunc( + Func activity) + { + Check.NotNull(activity); + return this.AddActivityFunc(GetActivityNameFromDelegate(activity), activity); + } + + /// + /// Registers an activity factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The activity input type. + /// The activity implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddActivityFunc(Func activity) + { + Check.NotNull(activity); + return this.AddActivityFunc(GetActivityNameFromDelegate(activity), activity); + } + + /// + /// Registers an activity factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The activity output type. + /// The activity implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddActivityFunc(Func> activity) + { + Check.NotNull(activity); + return this.AddActivityFunc(GetActivityNameFromDelegate(activity), activity); + } + + /// + /// Registers an activity factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The activity implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddActivityFunc(Func activity) + { + Check.NotNull(activity); + return this.AddActivityFunc(GetActivityNameFromDelegate(activity), activity); + } + + /// + /// Registers an activity factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The activity output type. + /// The activity implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddActivityFunc(Func activity) + { + Check.NotNull(activity); + return this.AddActivityFunc(GetActivityNameFromDelegate(activity), activity); + } + + /// + /// Registers an activity factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The activity input type. + /// The activity implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddActivityFunc(Action activity) + { + Check.NotNull(activity); + return this.AddActivityFunc(GetActivityNameFromDelegate(activity), activity); + } + + /// + /// Registers an activity factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The activity implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddActivityFunc(Action activity) + { + Check.NotNull(activity); + return this.AddActivityFunc(GetActivityNameFromDelegate(activity), activity); + } + + /// + /// Gets the task name from a delegate by checking for a + /// or falling back to the method name. + /// + /// The delegate to extract the name from. + /// The task name. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + static TaskName GetActivityNameFromDelegate(Delegate @delegate) + { + MethodInfo method = @delegate.Method; + + // Check for DurableTaskAttribute on the method + DurableTaskAttribute? attribute = method.GetCustomAttribute(); + if (attribute?.Name.Name is not null and not "") + { + return attribute.Name; + } + + // Fall back to method name + string? methodName = method.Name; + if (string.IsNullOrEmpty(methodName) || methodName.StartsWith("<", StringComparison.Ordinal)) + { + throw new ArgumentException( + "Cannot infer activity name from the delegate. The delegate must either have a " + + "[DurableTask] attribute with a name, or be a named method (not a lambda or anonymous delegate).", + nameof(@delegate)); + } + + return new TaskName(methodName); + } } diff --git a/src/Abstractions/DurableTaskRegistry.Orchestrators.cs b/src/Abstractions/DurableTaskRegistry.Orchestrators.cs index 7ad7583f0..1686fcf54 100644 --- a/src/Abstractions/DurableTaskRegistry.Orchestrators.cs +++ b/src/Abstractions/DurableTaskRegistry.Orchestrators.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Reflection; + namespace Microsoft.DurableTask; /// @@ -19,7 +21,17 @@ TaskName and TOrchestrator generic parameter ITaskOrchestrator singleton TaskName and ITaskOrchestrator singleton - by func/action: + by func/action (with explicit name): + Func{Context, Input, Task{Output}} + Func{Context, Input, Task} + Func{Context, Input, Output} + Func{Context, Task{Output}} + Func{Context, Task} + Func{Context, Output} + Action{Context, TInput} + Action{Context} + + by func/action (name inferred from method or [DurableTask] attribute): Func{Context, Input, Task{Output}} Func{Context, Input, Task} Func{Context, Input, Output} @@ -220,4 +232,171 @@ public DurableTaskRegistry AddOrchestratorFunc(TaskName name, Action + /// Registers an orchestrator factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The orchestrator input type. + /// The orchestrator output type. + /// The orchestrator implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddOrchestratorFunc( + Func> orchestrator) + { + Check.NotNull(orchestrator); + return this.AddOrchestratorFunc(GetTaskNameFromDelegate(orchestrator), orchestrator); + } + + /// + /// Registers an orchestrator factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The orchestrator input type. + /// The orchestrator output type. + /// The orchestrator implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddOrchestratorFunc( + Func orchestrator) + { + Check.NotNull(orchestrator); + return this.AddOrchestratorFunc(GetTaskNameFromDelegate(orchestrator), orchestrator); + } + + /// + /// Registers an orchestrator factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The orchestrator input type. + /// The orchestrator implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddOrchestratorFunc( + Func orchestrator) + { + Check.NotNull(orchestrator); + return this.AddOrchestratorFunc(GetTaskNameFromDelegate(orchestrator), orchestrator); + } + + /// + /// Registers an orchestrator factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The orchestrator output type. + /// The orchestrator implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddOrchestratorFunc( + Func> orchestrator) + { + Check.NotNull(orchestrator); + return this.AddOrchestratorFunc(GetTaskNameFromDelegate(orchestrator), orchestrator); + } + + /// + /// Registers an orchestrator factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The orchestrator implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddOrchestratorFunc(Func orchestrator) + { + Check.NotNull(orchestrator); + return this.AddOrchestratorFunc(GetTaskNameFromDelegate(orchestrator), orchestrator); + } + + /// + /// Registers an orchestrator factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The orchestrator output type. + /// The orchestrator implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddOrchestratorFunc( + Func orchestrator) + { + Check.NotNull(orchestrator); + return this.AddOrchestratorFunc(GetTaskNameFromDelegate(orchestrator), orchestrator); + } + + /// + /// Registers an orchestrator factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The orchestrator input type. + /// The orchestrator implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddOrchestratorFunc( + Action orchestrator) + { + Check.NotNull(orchestrator); + return this.AddOrchestratorFunc(GetTaskNameFromDelegate(orchestrator), orchestrator); + } + + /// + /// Registers an orchestrator factory, where the implementation is . + /// The name is inferred from a on the method, or the method name. + /// + /// The orchestrator implementation. + /// The same registry, for call chaining. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + public DurableTaskRegistry AddOrchestratorFunc(Action orchestrator) + { + Check.NotNull(orchestrator); + return this.AddOrchestratorFunc(GetTaskNameFromDelegate(orchestrator), orchestrator); + } + + /// + /// Gets the task name from a delegate by checking for a + /// or falling back to the method name. + /// + /// The delegate to extract the name from. + /// The task name. + /// + /// Thrown if the name cannot be inferred from the delegate. + /// + static TaskName GetTaskNameFromDelegate(Delegate @delegate) + { + MethodInfo method = @delegate.Method; + + // Check for DurableTaskAttribute on the method + DurableTaskAttribute? attribute = method.GetCustomAttribute(); + if (attribute?.Name.Name is not null and not "") + { + return attribute.Name; + } + + // Fall back to method name + string? methodName = method.Name; + if (string.IsNullOrEmpty(methodName) || methodName.StartsWith("<", StringComparison.Ordinal)) + { + throw new ArgumentException( + "Cannot infer orchestrator name from the delegate. The delegate must either have a " + + "[DurableTask] attribute with a name, or be a named method (not a lambda or anonymous delegate).", + nameof(@delegate)); + } + + return new TaskName(methodName); + } } diff --git a/src/Abstractions/Entities/TaskEntity.cs b/src/Abstractions/Entities/TaskEntity.cs index d8fee1ee2..27624a5fd 100644 --- a/src/Abstractions/Entities/TaskEntity.cs +++ b/src/Abstractions/Entities/TaskEntity.cs @@ -148,6 +148,12 @@ public abstract class TaskEntity : ITaskEntity /// The default implementation uses . protected virtual TState InitializeState(TaskEntityOperation entityOperation) { + // Throw if TState is a string, since strings are immutable and don't support dynamic creation of new instances. + if (typeof(TState) == typeof(string)) + { + throw new InvalidOperationException("Entity state cannot be a string. Use a class or struct instead."); + } + if (Nullable.GetUnderlyingType(typeof(TState)) is Type t) { // Activator.CreateInstance>() returns null. To avoid this, we will instantiate via underlying diff --git a/test/Worker/Core.Tests/DurableTaskRegistryTests.Activities.cs b/test/Worker/Core.Tests/DurableTaskRegistryTests.Activities.cs index bae370d84..f5182602d 100644 --- a/test/Worker/Core.Tests/DurableTaskRegistryTests.Activities.cs +++ b/test/Worker/Core.Tests/DurableTaskRegistryTests.Activities.cs @@ -163,6 +163,96 @@ public void AddActivity_Action2_Success() => RunAddActivityTest( r => r.AddActivityFunc(nameof(TestActivity), (TaskActivityContext ctx) => { })); + [Fact] + public void AddActivity_FuncNoName1_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(NamedActivityFunc1), + nameof(NamedActivityFunc1)); + + [Fact] + public void AddActivity_FuncNoName2_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(NamedActivityFunc2), + nameof(NamedActivityFunc2)); + + [Fact] + public void AddActivity_FuncNoName3_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(NamedActivityFunc3), + nameof(NamedActivityFunc3)); + + [Fact] + public void AddActivity_FuncNoName4_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(NamedActivityFunc4), + nameof(NamedActivityFunc4)); + + [Fact] + public void AddActivity_FuncNoName5_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(NamedActivityFunc5), + nameof(NamedActivityFunc5)); + + [Fact] + public void AddActivity_FuncNoName6_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(NamedActivityFunc6), + nameof(NamedActivityFunc6)); + + [Fact] + public void AddActivity_ActionNoName1_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(NamedActivityAction1), + nameof(NamedActivityAction1)); + + [Fact] + public void AddActivity_ActionNoName2_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(NamedActivityAction2), + nameof(NamedActivityAction2)); + + [Fact] + public void AddActivity_FuncNoName_WithAttribute_Success() + => RunAddActivityTest( + r => r.AddActivityFunc(AttributedActivityFunc), + "CustomActivityName"); + + [Fact] + public void AddActivity_FuncNoName_Lambda_Throws() + { + DurableTaskRegistry registry = new(); + Action act = () => registry.AddActivityFunc((TaskActivityContext ctx) => Task.CompletedTask); + act.Should().ThrowExactly(); + } + + [Fact] + public void AddActivity_FuncNoName_LambdaWithInput_Throws() + { + DurableTaskRegistry registry = new(); + Action act = () => registry.AddActivityFunc( + (TaskActivityContext ctx, string input) => Task.FromResult(input)); + act.Should().ThrowExactly(); + } + + static Task NamedActivityFunc1(TaskActivityContext ctx, string input) => Task.FromResult(input); + + static string NamedActivityFunc2(TaskActivityContext ctx, string input) => input; + + static Task NamedActivityFunc3(TaskActivityContext ctx, string input) => Task.CompletedTask; + + static Task NamedActivityFunc4(TaskActivityContext ctx) => Task.FromResult(string.Empty); + + static Task NamedActivityFunc5(TaskActivityContext ctx) => Task.CompletedTask; + + static string NamedActivityFunc6(TaskActivityContext ctx) => string.Empty; + + static void NamedActivityAction1(TaskActivityContext ctx, string input) { } + + static void NamedActivityAction2(TaskActivityContext ctx) { } + + [DurableTask("CustomActivityName")] + static Task AttributedActivityFunc(TaskActivityContext ctx) => Task.CompletedTask; + static ITaskActivity RunAddActivityTest(Action callback) { DurableTaskRegistry registry = new(); @@ -176,6 +266,19 @@ static ITaskActivity RunAddActivityTest(Action callback) return actual!; } + static ITaskActivity RunAddActivityTest(Action callback, string expectedName) + { + DurableTaskRegistry registry = new(); + callback(registry); + IDurableTaskFactory factory = registry.BuildFactory(); + + bool found = factory.TryCreateActivity( + expectedName, Mock.Of(), out ITaskActivity? actual); + found.Should().BeTrue(); + actual.Should().NotBeNull(); + return actual!; + } + abstract class InvalidActivity : TaskActivity { } diff --git a/test/Worker/Core.Tests/DurableTaskRegistryTests.Orchestrators.cs b/test/Worker/Core.Tests/DurableTaskRegistryTests.Orchestrators.cs index ebddb3175..149b4df21 100644 --- a/test/Worker/Core.Tests/DurableTaskRegistryTests.Orchestrators.cs +++ b/test/Worker/Core.Tests/DurableTaskRegistryTests.Orchestrators.cs @@ -168,6 +168,96 @@ public void AddOrchestrator_Action2_Success() => RunAddOrchestratorTest( r => r.AddOrchestratorFunc(nameof(TestOrchestrator), (TaskOrchestrationContext ctx) => { })); + [Fact] + public void AddOrchestrator_FuncNoName1_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(NamedOrchestratorFunc1), + nameof(NamedOrchestratorFunc1)); + + [Fact] + public void AddOrchestrator_FuncNoName2_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(NamedOrchestratorFunc2), + nameof(NamedOrchestratorFunc2)); + + [Fact] + public void AddOrchestrator_FuncNoName3_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(NamedOrchestratorFunc3), + nameof(NamedOrchestratorFunc3)); + + [Fact] + public void AddOrchestrator_FuncNoName4_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(NamedOrchestratorFunc4), + nameof(NamedOrchestratorFunc4)); + + [Fact] + public void AddOrchestrator_FuncNoName5_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(NamedOrchestratorFunc5), + nameof(NamedOrchestratorFunc5)); + + [Fact] + public void AddOrchestrator_FuncNoName6_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(NamedOrchestratorFunc6), + nameof(NamedOrchestratorFunc6)); + + [Fact] + public void AddOrchestrator_ActionNoName1_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(NamedOrchestratorAction1), + nameof(NamedOrchestratorAction1)); + + [Fact] + public void AddOrchestrator_ActionNoName2_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(NamedOrchestratorAction2), + nameof(NamedOrchestratorAction2)); + + [Fact] + public void AddOrchestrator_FuncNoName_WithAttribute_Success() + => RunAddOrchestratorTest( + r => r.AddOrchestratorFunc(AttributedOrchestratorFunc), + "CustomOrchestratorName"); + + [Fact] + public void AddOrchestrator_FuncNoName_Lambda_Throws() + { + DurableTaskRegistry registry = new(); + Action act = () => registry.AddOrchestratorFunc((TaskOrchestrationContext ctx) => Task.CompletedTask); + act.Should().ThrowExactly(); + } + + [Fact] + public void AddOrchestrator_FuncNoName_LambdaWithInput_Throws() + { + DurableTaskRegistry registry = new(); + Action act = () => registry.AddOrchestratorFunc( + (TaskOrchestrationContext ctx, string input) => Task.FromResult(input)); + act.Should().ThrowExactly(); + } + + static Task NamedOrchestratorFunc1(TaskOrchestrationContext ctx, string input) => Task.FromResult(input); + + static string NamedOrchestratorFunc2(TaskOrchestrationContext ctx, string input) => input; + + static Task NamedOrchestratorFunc3(TaskOrchestrationContext ctx, string input) => Task.CompletedTask; + + static Task NamedOrchestratorFunc4(TaskOrchestrationContext ctx) => Task.FromResult(string.Empty); + + static Task NamedOrchestratorFunc5(TaskOrchestrationContext ctx) => Task.CompletedTask; + + static string NamedOrchestratorFunc6(TaskOrchestrationContext ctx) => string.Empty; + + static void NamedOrchestratorAction1(TaskOrchestrationContext ctx, string input) { } + + static void NamedOrchestratorAction2(TaskOrchestrationContext ctx) { } + + [DurableTask("CustomOrchestratorName")] + static Task AttributedOrchestratorFunc(TaskOrchestrationContext ctx) => Task.CompletedTask; + static ITaskOrchestrator RunAddOrchestratorTest(Action callback) { DurableTaskRegistry registry = new(); @@ -181,6 +271,19 @@ static ITaskOrchestrator RunAddOrchestratorTest(Action call return actual!; } + static ITaskOrchestrator RunAddOrchestratorTest(Action callback, string expectedName) + { + DurableTaskRegistry registry = new(); + callback(registry); + IDurableTaskFactory factory = registry.BuildFactory(); + + bool found = factory.TryCreateOrchestrator( + expectedName, Mock.Of(), out ITaskOrchestrator? actual); + found.Should().BeTrue(); + actual.Should().NotBeNull(); + return actual!; + } + abstract class InvalidOrchestrator: TaskOrchestrator { }