Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/Worker/Core/Shims/TaskActivityShim.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ public TaskActivityShim(
TaskActivityContextWrapper contextWrapper = new(coreContext, this.name);

string instanceId = coreContext.OrchestrationInstance.InstanceId;
using IDisposable? scope = this.logger.BeginScope(new Dictionary<string, object?>
{
["InstanceId"] = instanceId,
});
this.logger.ActivityStarted(instanceId, this.name);

try
Expand Down
80 changes: 42 additions & 38 deletions src/Worker/Core/Shims/TaskOrchestrationShim.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,55 +57,59 @@ public TaskOrchestrationShim(

/// <inheritdoc/>
public override async Task<string?> Execute(OrchestrationContext innerContext, string rawInput)
{
Check.NotNull(innerContext);
JsonDataConverterShim converterShim = new(this.invocationContext.Options.DataConverter);
innerContext.MessageDataConverter = converterShim;
innerContext.ErrorDataConverter = converterShim;

object? input = this.DataConverter.Deserialize(rawInput, this.implementation.InputType);
this.wrapperContext = new(innerContext, this.invocationContext, input, this.properties);

string instanceId = innerContext.OrchestrationInstance.InstanceId;
if (!innerContext.IsReplaying)
{
this.logger.OrchestrationStarted(instanceId, this.invocationContext.Name);
}
Check.NotNull(innerContext);
JsonDataConverterShim converterShim = new(this.invocationContext.Options.DataConverter);
innerContext.MessageDataConverter = converterShim;
innerContext.ErrorDataConverter = converterShim;

try
{
object? output = await this.implementation.RunAsync(this.wrapperContext, input);
object? input = this.DataConverter.Deserialize(rawInput, this.implementation.InputType);
this.wrapperContext = new(innerContext, this.invocationContext, input, this.properties);

string instanceId = innerContext.OrchestrationInstance.InstanceId;
using IDisposable? scope = this.logger.BeginScope(new Dictionary<string, object?>
{
["InstanceId"] = instanceId,
});
if (!innerContext.IsReplaying)
{
this.logger.OrchestrationCompleted(instanceId, this.invocationContext.Name);
this.logger.OrchestrationStarted(instanceId, this.invocationContext.Name);
}

// Return the output (if any) as a serialized string.
return this.DataConverter.Serialize(output);
}
catch (TaskFailedException e)
{
if (!innerContext.IsReplaying)
try
{
this.logger.OrchestrationFailed(e, instanceId, this.invocationContext.Name);
}
object? output = await this.implementation.RunAsync(this.wrapperContext, input);

if (!innerContext.IsReplaying)
{
this.logger.OrchestrationCompleted(instanceId, this.invocationContext.Name);
}

// Convert back to something the Durable Task Framework natively understands so that
// failure details are correctly propagated.
throw new CoreTaskFailedException(e.Message, e.InnerException)
// Return the output (if any) as a serialized string.
return this.DataConverter.Serialize(output);
}
catch (TaskFailedException e)
{
FailureDetails = new FailureDetails(e,
e.FailureDetails.ToCoreFailureDetails(),
properties: e.FailureDetails.Properties),
};
}
finally
{
// if user code crashed inside a critical section, or did not exit it, do that now
this.wrapperContext.ExitCriticalSectionIfNeeded();
if (!innerContext.IsReplaying)
{
this.logger.OrchestrationFailed(e, instanceId, this.invocationContext.Name);
}

// Convert back to something the Durable Task Framework natively understands so that
// failure details are correctly propagated.
throw new CoreTaskFailedException(e.Message, e.InnerException)
{
FailureDetails = new FailureDetails(e,
e.FailureDetails.ToCoreFailureDetails(),
properties: e.FailureDetails.Properties),
};
}
finally
{
// if user code crashed inside a critical section, or did not exit it, do that now
this.wrapperContext.ExitCriticalSectionIfNeeded();
}
}
}

/// <inheritdoc/>
public override string? GetStatus()
Expand Down
136 changes: 136 additions & 0 deletions test/Worker/Core.Tests/Shims/TaskShimLoggingScopeTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using DurableTask.Core;
using Microsoft.DurableTask;
using Microsoft.DurableTask.Converters;
using Microsoft.DurableTask.Worker;
using Microsoft.DurableTask.Worker.Shims;
using Microsoft.Extensions.Logging;

namespace Microsoft.DurableTask.Worker.Shims;

public class TaskShimLoggingScopeTests
{
[Fact]
public async Task TaskActivityShim_RunAsync_UsesInstanceIdScope()
{
// Arrange
string instanceId = Guid.NewGuid().ToString("N");
IDictionary<string, object?>? scopeState = null;
Mock<ILogger> loggerMock = new();
loggerMock.Setup(l => l.BeginScope(It.IsAny<IDictionary<string, object?>>()))
.Callback((IDictionary<string, object?> state) => scopeState = state)
.Returns(Mock.Of<IDisposable>());
loggerMock.Setup(l => l.IsEnabled(It.IsAny<LogLevel>())).Returns(true);
Mock<ILoggerFactory> loggerFactoryMock = new();
loggerFactoryMock.Setup(f => f.CreateLogger(It.IsAny<string>())).Returns(loggerMock.Object);
TaskActivityShim shim = new(loggerFactoryMock.Object, JsonDataConverter.Default, new TaskName("TestActivity"), new TestActivity());
TaskContext coreContext = new(new OrchestrationInstance { InstanceId = instanceId });

// Act
await shim.RunAsync(coreContext, "\"input\"");

// Assert
scopeState.Should().NotBeNull();
scopeState!.Should().ContainKey("InstanceId").WhoseValue.Should().Be(instanceId);
}

[Fact]
public async Task TaskOrchestrationShim_Execute_UsesInstanceIdScope()
{
// Arrange
string instanceId = Guid.NewGuid().ToString("N");
IDictionary<string, object?>? scopeState = null;
Mock<ILogger> loggerMock = new();
loggerMock.Setup(l => l.BeginScope(It.IsAny<IDictionary<string, object?>>()))
.Callback((IDictionary<string, object?> state) => scopeState = state)
.Returns(Mock.Of<IDisposable>());
loggerMock.Setup(l => l.IsEnabled(It.IsAny<LogLevel>())).Returns(true);
Mock<ILoggerFactory> loggerFactoryMock = new();
loggerFactoryMock.Setup(f => f.CreateLogger(It.IsAny<string>())).Returns(loggerMock.Object);
OrchestrationInvocationContext invocationContext = new(new TaskName("TestOrchestrator"), new DurableTaskWorkerOptions(), loggerFactoryMock.Object);
TaskOrchestrationShim shim = new(invocationContext, new TestOrchestrator());
TestOrchestrationContext innerContext = new(instanceId);

// Act
await shim.Execute(innerContext, "\"input\"");

// Assert
scopeState.Should().NotBeNull();
scopeState!.Should().ContainKey("InstanceId").WhoseValue.Should().Be(instanceId);
}

class TestActivity : TaskActivity<string, string>
{
public override Task<string> RunAsync(TaskActivityContext context, string input)
{
return Task.FromResult("ok");
}
}

class TestOrchestrator : TaskOrchestrator<string, string>
{
public override Task<string> RunAsync(TaskOrchestrationContext context, string input)
{
return Task.FromResult("ok");
}
}

class TestOrchestrationContext : OrchestrationContext
{
public TestOrchestrationContext(string instanceId)
{
this.OrchestrationInstance = new OrchestrationInstance
{
InstanceId = instanceId,
ExecutionId = Guid.NewGuid().ToString("N"),
};
}

public override Task<TResult> ScheduleTask<TResult>(string name, string version, params object[] parameters)
{
throw new NotImplementedException();
}

public override Task<T> CreateTimer<T>(DateTime fireAt, T state)
{
throw new NotImplementedException();
}

public override Task<T> CreateTimer<T>(DateTime fireAt, T state, CancellationToken cancelToken)
{
throw new NotImplementedException();
}

public override Task<T> CreateSubOrchestrationInstance<T>(string name, string version, object input)
{
throw new NotImplementedException();
}

public override Task<T> CreateSubOrchestrationInstance<T>(string name, string version, string instanceId, object input)
{
throw new NotImplementedException();
}

public override Task<T> CreateSubOrchestrationInstance<T>(string name, string version, string instanceId, object input, IDictionary<string, string> tags)
{
throw new NotImplementedException();
}

public override void SendEvent(OrchestrationInstance orchestrationInstance, string eventName, object eventData)
{
throw new NotImplementedException();
}

public override void ContinueAsNew(object input)
{
throw new NotImplementedException();
}

public override void ContinueAsNew(string newVersion, object input)
{
throw new NotImplementedException();
}
}
}
Loading