diff --git a/src/Worker/Core/Shims/TaskActivityShim.cs b/src/Worker/Core/Shims/TaskActivityShim.cs index ae1f3bb06..b40386173 100644 --- a/src/Worker/Core/Shims/TaskActivityShim.cs +++ b/src/Worker/Core/Shims/TaskActivityShim.cs @@ -44,6 +44,10 @@ public TaskActivityShim( TaskActivityContextWrapper contextWrapper = new(coreContext, this.name); string instanceId = coreContext.OrchestrationInstance.InstanceId; + using IDisposable? scope = this.logger.BeginScope(new Dictionary + { + ["InstanceId"] = instanceId, + }); this.logger.ActivityStarted(instanceId, this.name); try diff --git a/src/Worker/Core/Shims/TaskOrchestrationShim.cs b/src/Worker/Core/Shims/TaskOrchestrationShim.cs index eb7a179be..2c68148d4 100644 --- a/src/Worker/Core/Shims/TaskOrchestrationShim.cs +++ b/src/Worker/Core/Shims/TaskOrchestrationShim.cs @@ -57,55 +57,59 @@ public TaskOrchestrationShim( /// public override async Task Execute(OrchestrationContext innerContext, string rawInput) - { - Check.NotNull(innerContext); - JsonDataConverterShim converterShim = new(this.invocationContext.Options.DataConverter); - innerContext.MessageDataConverter = converterShim; - innerContext.ErrorDataConverter = converterShim; - - object? input = this.DataConverter.Deserialize(rawInput, this.implementation.InputType); - this.wrapperContext = new(innerContext, this.invocationContext, input, this.properties); - - string instanceId = innerContext.OrchestrationInstance.InstanceId; - if (!innerContext.IsReplaying) { - this.logger.OrchestrationStarted(instanceId, this.invocationContext.Name); - } + Check.NotNull(innerContext); + JsonDataConverterShim converterShim = new(this.invocationContext.Options.DataConverter); + innerContext.MessageDataConverter = converterShim; + innerContext.ErrorDataConverter = converterShim; - try - { - object? output = await this.implementation.RunAsync(this.wrapperContext, input); + object? input = this.DataConverter.Deserialize(rawInput, this.implementation.InputType); + this.wrapperContext = new(innerContext, this.invocationContext, input, this.properties); + string instanceId = innerContext.OrchestrationInstance.InstanceId; + using IDisposable? scope = this.logger.BeginScope(new Dictionary + { + ["InstanceId"] = instanceId, + }); if (!innerContext.IsReplaying) { - this.logger.OrchestrationCompleted(instanceId, this.invocationContext.Name); + this.logger.OrchestrationStarted(instanceId, this.invocationContext.Name); } - // Return the output (if any) as a serialized string. - return this.DataConverter.Serialize(output); - } - catch (TaskFailedException e) - { - if (!innerContext.IsReplaying) + try { - this.logger.OrchestrationFailed(e, instanceId, this.invocationContext.Name); - } + object? output = await this.implementation.RunAsync(this.wrapperContext, input); + + if (!innerContext.IsReplaying) + { + this.logger.OrchestrationCompleted(instanceId, this.invocationContext.Name); + } - // Convert back to something the Durable Task Framework natively understands so that - // failure details are correctly propagated. - throw new CoreTaskFailedException(e.Message, e.InnerException) + // Return the output (if any) as a serialized string. + return this.DataConverter.Serialize(output); + } + catch (TaskFailedException e) { - FailureDetails = new FailureDetails(e, - e.FailureDetails.ToCoreFailureDetails(), - properties: e.FailureDetails.Properties), - }; - } - finally - { - // if user code crashed inside a critical section, or did not exit it, do that now - this.wrapperContext.ExitCriticalSectionIfNeeded(); + if (!innerContext.IsReplaying) + { + this.logger.OrchestrationFailed(e, instanceId, this.invocationContext.Name); + } + + // Convert back to something the Durable Task Framework natively understands so that + // failure details are correctly propagated. + throw new CoreTaskFailedException(e.Message, e.InnerException) + { + FailureDetails = new FailureDetails(e, + e.FailureDetails.ToCoreFailureDetails(), + properties: e.FailureDetails.Properties), + }; + } + finally + { + // if user code crashed inside a critical section, or did not exit it, do that now + this.wrapperContext.ExitCriticalSectionIfNeeded(); + } } - } /// public override string? GetStatus() diff --git a/test/Worker/Core.Tests/Shims/TaskShimLoggingScopeTests.cs b/test/Worker/Core.Tests/Shims/TaskShimLoggingScopeTests.cs new file mode 100644 index 000000000..ae3f4e789 --- /dev/null +++ b/test/Worker/Core.Tests/Shims/TaskShimLoggingScopeTests.cs @@ -0,0 +1,136 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using DurableTask.Core; +using Microsoft.DurableTask; +using Microsoft.DurableTask.Converters; +using Microsoft.DurableTask.Worker; +using Microsoft.DurableTask.Worker.Shims; +using Microsoft.Extensions.Logging; + +namespace Microsoft.DurableTask.Worker.Shims; + +public class TaskShimLoggingScopeTests +{ + [Fact] + public async Task TaskActivityShim_RunAsync_UsesInstanceIdScope() + { + // Arrange + string instanceId = Guid.NewGuid().ToString("N"); + IDictionary? scopeState = null; + Mock loggerMock = new(); + loggerMock.Setup(l => l.BeginScope(It.IsAny>())) + .Callback((IDictionary state) => scopeState = state) + .Returns(Mock.Of()); + loggerMock.Setup(l => l.IsEnabled(It.IsAny())).Returns(true); + Mock loggerFactoryMock = new(); + loggerFactoryMock.Setup(f => f.CreateLogger(It.IsAny())).Returns(loggerMock.Object); + TaskActivityShim shim = new(loggerFactoryMock.Object, JsonDataConverter.Default, new TaskName("TestActivity"), new TestActivity()); + TaskContext coreContext = new(new OrchestrationInstance { InstanceId = instanceId }); + + // Act + await shim.RunAsync(coreContext, "\"input\""); + + // Assert + scopeState.Should().NotBeNull(); + scopeState!.Should().ContainKey("InstanceId").WhoseValue.Should().Be(instanceId); + } + + [Fact] + public async Task TaskOrchestrationShim_Execute_UsesInstanceIdScope() + { + // Arrange + string instanceId = Guid.NewGuid().ToString("N"); + IDictionary? scopeState = null; + Mock loggerMock = new(); + loggerMock.Setup(l => l.BeginScope(It.IsAny>())) + .Callback((IDictionary state) => scopeState = state) + .Returns(Mock.Of()); + loggerMock.Setup(l => l.IsEnabled(It.IsAny())).Returns(true); + Mock loggerFactoryMock = new(); + loggerFactoryMock.Setup(f => f.CreateLogger(It.IsAny())).Returns(loggerMock.Object); + OrchestrationInvocationContext invocationContext = new(new TaskName("TestOrchestrator"), new DurableTaskWorkerOptions(), loggerFactoryMock.Object); + TaskOrchestrationShim shim = new(invocationContext, new TestOrchestrator()); + TestOrchestrationContext innerContext = new(instanceId); + + // Act + await shim.Execute(innerContext, "\"input\""); + + // Assert + scopeState.Should().NotBeNull(); + scopeState!.Should().ContainKey("InstanceId").WhoseValue.Should().Be(instanceId); + } + + class TestActivity : TaskActivity + { + public override Task RunAsync(TaskActivityContext context, string input) + { + return Task.FromResult("ok"); + } + } + + class TestOrchestrator : TaskOrchestrator + { + public override Task RunAsync(TaskOrchestrationContext context, string input) + { + return Task.FromResult("ok"); + } + } + + class TestOrchestrationContext : OrchestrationContext + { + public TestOrchestrationContext(string instanceId) + { + this.OrchestrationInstance = new OrchestrationInstance + { + InstanceId = instanceId, + ExecutionId = Guid.NewGuid().ToString("N"), + }; + } + + public override Task ScheduleTask(string name, string version, params object[] parameters) + { + throw new NotImplementedException(); + } + + public override Task CreateTimer(DateTime fireAt, T state) + { + throw new NotImplementedException(); + } + + public override Task CreateTimer(DateTime fireAt, T state, CancellationToken cancelToken) + { + throw new NotImplementedException(); + } + + public override Task CreateSubOrchestrationInstance(string name, string version, object input) + { + throw new NotImplementedException(); + } + + public override Task CreateSubOrchestrationInstance(string name, string version, string instanceId, object input) + { + throw new NotImplementedException(); + } + + public override Task CreateSubOrchestrationInstance(string name, string version, string instanceId, object input, IDictionary tags) + { + throw new NotImplementedException(); + } + + public override void SendEvent(OrchestrationInstance orchestrationInstance, string eventName, object eventData) + { + throw new NotImplementedException(); + } + + public override void ContinueAsNew(object input) + { + throw new NotImplementedException(); + } + + public override void ContinueAsNew(string newVersion, object input) + { + throw new NotImplementedException(); + } + } +}