2 changes: 2 additions & 0 deletions packages/task-graph/src/common.ts
@@ -18,6 +18,8 @@ export * from "./task/TaskRegistry";
export * from "./task/JobQueueTask";
export * from "./task/TaskQueueRegistry";
export * from "./task/ArrayTask";
export * from "./task/StreamingUtils";
export * from "./task/StreamingTypes";

export * from "./task-graph/DataflowEvents";
export * from "./task-graph/Dataflow";
64 changes: 64 additions & 0 deletions packages/task-graph/src/task-graph/Dataflow.ts
@@ -53,12 +53,18 @@ export class Dataflow {
public provenance: Provenance = {};
public status: TaskStatus = TaskStatus.PENDING;
public error: TaskError | undefined;
/** Streaming value for incremental updates */
public streamingValue: AsyncIterable<any> | null = null;
/** Whether this dataflow is currently streaming */
public isStreaming: boolean = false;

public reset() {
this.status = TaskStatus.PENDING;
this.error = undefined;
this.value = undefined;
this.provenance = {};
this.streamingValue = null;
this.isStreaming = false;
this.emit("reset");
this.emit("status", this.status);
}
@@ -98,6 +104,64 @@ export class Dataflow {
this.value = entireDataBlock[this.sourceTaskPortId];
}
if (nodeProvenance) this.provenance = nodeProvenance;
// Clear streaming state when complete data is set
this.isStreaming = false;
this.streamingValue = null;
}

/**
* Sets streaming port data for incremental updates
* @param chunk Partial data chunk from streaming task
* @param nodeProvenance Optional provenance information
*/
setStreamingPortData(chunk: any, nodeProvenance?: Provenance) {
if (!this.isStreaming) {
this.isStreaming = true;
this.status = TaskStatus.PROCESSING;
this.emit("status", this.status);
}

if (this.sourceTaskPortId === DATAFLOW_ALL_PORTS) {
// Merge chunk into existing value or set as new
if (this.value === undefined) {
this.value = chunk;
} else if (typeof chunk === "object" && chunk !== null && !Array.isArray(chunk)) {
this.value = { ...this.value, ...chunk };
} else {
// Non-object chunks replace the existing value
this.value = chunk;
}
} else if (this.sourceTaskPortId === DATAFLOW_ERROR_PORT) {
this.error = chunk;
} else {
// Merge chunk property into existing value
const chunkValue = chunk[this.sourceTaskPortId];
if (chunkValue !== undefined) {
if (this.value === undefined) {
this.value = chunkValue;
} else if (typeof chunkValue === "object" && chunkValue !== null && !Array.isArray(chunkValue)) {
this.value = { ...this.value, ...chunkValue };
} else {
this.value = chunkValue;
}
}
}

if (nodeProvenance) {
this.provenance = { ...this.provenance, ...nodeProvenance };
}

this.emit("stream_chunk", chunk);
}

/**
* Gets streaming port data as an async iterable
* @returns AsyncIterable of streaming chunks
*/
async *getStreamingPortData(): AsyncIterable<any> {
if (this.streamingValue) {
yield* this.streamingValue;
}
}

getPortData(): TaskOutput {
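
A minimal usage sketch of the new chunk-merging behaviour, assuming a Dataflow wired to a single named source port (the "text" port, the sample values, and the relative import path are made up for illustration):

```ts
import { Dataflow } from "./Dataflow";

// `flow` is assumed to be connected to a single named source port ("text").
function feedChunks(flow: Dataflow) {
  // For a named port, the chunk's matching property is extracted; scalar values
  // replace the previous value, so each chunk should carry the text so far.
  flow.setStreamingPortData({ text: "Hel" });       // flow.value === "Hel", flow.isStreaming === true
  flow.setStreamingPortData({ text: "Hello wor" }); // flow.value === "Hello wor"
}
```

Because scalar chunks replace rather than append, a producer streaming text should emit the accumulated string in each chunk; object-valued chunks (and DATAFLOW_ALL_PORTS dataflows) are shallow-merged instead, so partial objects accumulate across chunks.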
3 changes: 3 additions & 0 deletions packages/task-graph/src/task-graph/DataflowEvents.ts
@@ -34,6 +34,9 @@ export type DataflowEventListeners = {

/** Fired when a dataflow status changes */
status: (status: TaskStatus) => void;

/** Fired when a streaming chunk arrives */
stream_chunk: (chunk: unknown) => void;
};
/** Union type of all possible dataflow event names */

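
A short listener sketch for the new event, assuming Dataflow exposes the on(event, listener) registration that DataflowEventListeners describes:

```ts
import { Dataflow } from "./Dataflow";
import { TaskStatus } from "../task/TaskTypes";

// Assumed subscription API; the listener signatures mirror DataflowEventListeners.
function logStreaming(flow: Dataflow) {
  // Fires once per setStreamingPortData() call with the raw chunk.
  flow.on("stream_chunk", (chunk: unknown) => {
    console.log("partial output:", chunk);
  });
  // The first chunk also flips the dataflow status to PROCESSING.
  flow.on("status", (status: TaskStatus) => {
    console.log("dataflow status:", status);
  });
}
```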
170 changes: 159 additions & 11 deletions packages/task-graph/src/task-graph/TaskGraphRunner.ts
@@ -19,6 +19,7 @@ import { Provenance, TaskInput, TaskOutput, TaskStatus } from "../task/TaskTypes";
import { DATAFLOW_ALL_PORTS } from "./Dataflow";
import { TaskGraph, TaskGraphRunConfig } from "./TaskGraph";
import { DependencyBasedScheduler, TopologicalScheduler } from "./TaskGraphScheduler";
import { ensureTask } from "./Conversions";

export type GraphSingleTaskResult<T> = {
id: unknown;
@@ -85,6 +86,8 @@ export class TaskGraphRunner {
protected inProgressTasks: Map<unknown, Promise<TaskOutput>> = new Map();
protected inProgressFunctions: Map<unknown, Promise<any>> = new Map();
protected failedTaskErrors: Map<unknown, TaskError> = new Map();
/** Map of task IDs to their active streaming iterators */
protected streamingTasks: Map<unknown, AsyncIterableIterator<any>> = new Map();

/**
* Constructor for TaskGraphRunner
@@ -137,24 +140,36 @@
// Only filter input for non-root tasks; root tasks get the full input
const taskInput = isRootTask ? input : this.filterInputForTask(task, input);

const taskPromise = this.runTaskWithProvenance(
task,
taskInput,
config?.parentProvenance || {}
);
this.inProgressTasks!.set(task.config.id, taskPromise);
const taskResult = await taskPromise;

if (this.graph.getTargetDataflows(task.config.id).length === 0) {
// we save the results of all the leaves
results.push(taskResult as GraphSingleTaskResult<ExecuteOutput>);
// Check if task is streamable
if (task.isStreamable() && task.executeStream) {
await this.handleStreamingTask(
task,
taskInput,
config?.parentProvenance || {},
results as GraphResultArray<ExecuteOutput>
);
} else {
const taskPromise = this.runTaskWithProvenance(
task,
taskInput,
config?.parentProvenance || {}
);
this.inProgressTasks!.set(task.config.id, taskPromise);
const taskResult = await taskPromise;

if (this.graph.getTargetDataflows(task.config.id).length === 0) {
// we save the results of all the leaves
results.push(taskResult as GraphSingleTaskResult<ExecuteOutput>);
}
}
} catch (error) {
this.failedTaskErrors.set(task.config.id, error as TaskError);
} finally {
this.processScheduler.onTaskCompleted(task.config.id);
this.pushStatusFromNodeToEdges(this.graph, task);
this.pushErrorFromNodeToEdges(this.graph, task);
// Clean up streaming iterator
this.streamingTasks.delete(task.config.id);
}
};

@@ -171,6 +186,18 @@
await Promise.allSettled(Array.from(this.inProgressTasks.values()));
// Clean up stragglers to avoid unhandled promise rejections
await Promise.allSettled(Array.from(this.inProgressFunctions.values()));
// Clean up streaming tasks
for (const [taskId, iterator] of this.streamingTasks.entries()) {
// Try to clean up the iterator if it has a return method
if (iterator.return) {
try {
await iterator.return();
} catch {
// Ignore errors during cleanup
}
}
}
this.streamingTasks.clear();

if (this.failedTaskErrors.size > 0) {
const latestError = this.failedTaskErrors.values().next().value!;
@@ -405,6 +432,126 @@
}
}

/**
* Pushes streaming output chunks from a task to its target dataflows
* @param node The task that produced the streaming chunk
* @param chunk The partial output chunk
* @param nodeProvenance The provenance input for the task
*/
protected async pushStreamingOutputFromNodeToEdges(
node: ITask,
chunk: Partial<TaskOutput>,
nodeProvenance?: Provenance
) {
const dataflows = this.graph.getTargetDataflows(node.config.id);
for (const dataflow of dataflows) {
const compatibility = dataflow.semanticallyCompatible(this.graph, dataflow);
if (compatibility === "static" || compatibility === "runtime") {
dataflow.setStreamingPortData(chunk, nodeProvenance);
}
}
}

/**
* Handles execution of a streaming task
* @param task The task to execute
* @param input The input to the task
* @param parentProvenance The provenance input for the task
* @param results Array to collect results
*/
protected async handleStreamingTask<T>(
task: ITask,
input: TaskInput,
parentProvenance: Provenance,
results: GraphResultArray<T>
): Promise<void> {
// Update provenance for the current task
const nodeProvenance = {
...parentProvenance,
...this.getInputProvenance(task),
...task.getProvenance(),
};
this.provenanceInput.set(task.config.id, nodeProvenance);
this.copyInputFromEdgesToNode(task);

// Notify scheduler that streaming has started
if (this.processScheduler instanceof DependencyBasedScheduler) {
this.processScheduler.onTaskStreamingStart(task.config.id, task);
}

// Create execution context
const context = {
signal: this.abortController!.signal,
nodeProvenance,
updateProgress: async (progress: number, message?: string, ...args: any[]) =>
await this.handleProgress(task, progress, message, ...args),
own: (i: any) => {
const ownedTask = ensureTask(i, { isOwned: true });
this.graph.addTask(ownedTask);
return i;
},
onStreamChunk: async (chunk: Partial<TaskOutput>) => {
await this.pushStreamingOutputFromNodeToEdges(task, chunk, nodeProvenance);
// Notify scheduler of chunk
if (this.processScheduler instanceof DependencyBasedScheduler) {
this.processScheduler.onTaskStreamingChunk(task.config.id);
}
},
};

// Execute streaming task
let finalOutput: TaskOutput | undefined;
const streamIterator = task.executeStream!(input, context);
this.streamingTasks.set(task.config.id, streamIterator);

try {
// Iterate over stream chunks
for await (const chunk of streamIterator) {
if (this.abortController?.signal.aborted) {
break;
}
// Push chunk to dataflows
await this.pushStreamingOutputFromNodeToEdges(task, chunk, nodeProvenance);
// Notify scheduler of chunk
if (this.processScheduler instanceof DependencyBasedScheduler) {
this.processScheduler.onTaskStreamingChunk(task.config.id);
}
// Accumulate final output
if (finalOutput === undefined) {
finalOutput = chunk as TaskOutput;
} else {
// Merge chunks
finalOutput = { ...finalOutput, ...chunk };
}
}

// Set final output
if (finalOutput !== undefined) {
task.runOutputData = finalOutput;
await this.pushOutputFromNodeToEdges(task, finalOutput, nodeProvenance);
}

// Add to results if leaf node
if (this.graph.getTargetDataflows(task.config.id).length === 0 && finalOutput !== undefined) {
results.push({
id: task.config.id,
type: (task.constructor as any).runtype || (task.constructor as any).type,
data: finalOutput as T,
});
}
} catch (error) {
// Clean up iterator on error
if (streamIterator.return) {
try {
await streamIterator.return();
} catch {
// Ignore cleanup errors
}
}
throw error;
}
}

/**
* Pushes the status of a task to its target edges
* @param node The task that produced the status
@@ -547,6 +694,7 @@
this.inProgressTasks.clear();
this.inProgressFunctions.clear();
this.failedTaskErrors.clear();
this.streamingTasks.clear();
this.graph.emit("start");
}

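
A hedged sketch of the producer side that handleStreamingTask consumes; the ITask base class and execution-context type are not shown in this diff, so the shapes below (the class name, the signal-only context, and the `text` output port) are assumptions inferred from how the runner calls isStreamable() and executeStream():

```ts
import { TaskInput, TaskOutput } from "../task/TaskTypes";

// Hypothetical streaming task; only the members the runner actually touches are sketched.
class StreamingEchoTask /* implements the relevant parts of ITask */ {
  isStreamable(): boolean {
    return true;
  }

  // The runner iterates these chunks, pushes each one downstream via
  // pushStreamingOutputFromNodeToEdges, and shallow-merges them into the final
  // output stored on runOutputData, so each chunk carries the text so far.
  async *executeStream(
    _input: TaskInput,
    context: { signal: AbortSignal }
  ): AsyncIterableIterator<Partial<TaskOutput>> {
    const words = ["Hello", "streaming", "world"];
    let text = "";
    for (const word of words) {
      if (context.signal.aborted) return;
      text += (text ? " " : "") + word;
      yield { text };
    }
  }
}
```

Yielding the accumulated text keeps the shallow merge in handleStreamingTask correct without any append logic on the consumer side.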