diff --git a/src/components/shared/ReactFlow/FlowCanvas/FlowCanvas.tsx b/src/components/shared/ReactFlow/FlowCanvas/FlowCanvas.tsx index 701730b9f..ceedf1b97 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/FlowCanvas.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/FlowCanvas.tsx @@ -373,13 +373,7 @@ const FlowCanvas = ({ setComponentSpec(updatedComponentSpec); }, - [ - reactFlowInstance, - componentSpec, - nodeData, - setComponentSpec, - updateOrAddNodes, - ], + [reactFlowInstance, componentSpec, setComponentSpec, updateOrAddNodes], ); useEffect(() => { diff --git a/src/types/taskNode.ts b/src/types/taskNode.ts index 469e449a8..eea6a2978 100644 --- a/src/types/taskNode.ts +++ b/src/types/taskNode.ts @@ -6,6 +6,14 @@ import type { import type { Annotations } from "./annotations"; +export type TaskType = "task" | "input" | "output"; + +export interface NodeData extends Record { + readOnly?: boolean; + connectable?: boolean; + nodeCallbacks?: NodeCallbacks; +} + export interface TaskNodeData extends Record { taskSpec?: TaskSpec; taskId?: string; @@ -13,8 +21,7 @@ export interface TaskNodeData extends Record { isGhost?: boolean; connectable?: boolean; highlighted?: boolean; - callbacks?: TaskNodeCallbacks; - nodeCallbacks?: NodeCallbacks; + callbacks?: TaskCallbacks; } export type NodeAndTaskId = { @@ -22,10 +29,8 @@ export type NodeAndTaskId = { nodeId: string; }; -export type TaskType = "task" | "input" | "output"; - /* Note: Optional callbacks will cause TypeScript to break when applying the callbacks to the Nodes. */ -interface TaskNodeCallbacks { +export interface TaskCallbacks { setArguments: (args: Record) => void; setAnnotations: (annotations: Annotations) => void; setCacheStaleness: (cacheStaleness: string | undefined) => void; @@ -35,13 +40,14 @@ interface TaskNodeCallbacks { } // Dynamic Node Callback types - every callback has a version with the node & task id added to it as an input parameter -export type CallbackWithIds = - TaskNodeCallbacks[K] extends (...args: infer A) => infer R - ? (ids: NodeAndTaskId, ...args: A) => R - : never; +type CallbackWithIds = TaskCallbacks[K] extends ( + ...args: infer A +) => infer R + ? (ids: NodeAndTaskId, ...args: A) => R + : never; export type NodeCallbacks = { - [K in keyof TaskNodeCallbacks]: CallbackWithIds; + [K in keyof TaskCallbacks]: CallbackWithIds; }; export type TaskNodeDimensions = { w: number; h: number | undefined }; diff --git a/src/utils/nodes/createInputNode.ts b/src/utils/nodes/createInputNode.ts index fc582beea..3e859c1dd 100644 --- a/src/utils/nodes/createInputNode.ts +++ b/src/utils/nodes/createInputNode.ts @@ -1,12 +1,12 @@ import { type Node } from "@xyflow/react"; -import type { TaskNodeData } from "@/types/taskNode"; +import type { NodeData } from "@/types/taskNode"; import type { InputSpec } from "../componentSpec"; import { extractPositionFromAnnotations } from "./extractPositionFromAnnotations"; import { inputNameToNodeId } from "./nodeIdUtils"; -export const createInputNode = (input: InputSpec, nodeData: TaskNodeData) => { +export const createInputNode = (input: InputSpec, nodeData: NodeData) => { const { name, annotations, ...rest } = input; const position = extractPositionFromAnnotations(annotations); diff --git a/src/utils/nodes/createNodesFromComponentSpec.ts b/src/utils/nodes/createNodesFromComponentSpec.ts index e597cd9f4..ee39895a1 100644 --- a/src/utils/nodes/createNodesFromComponentSpec.ts +++ b/src/utils/nodes/createNodesFromComponentSpec.ts @@ -1,6 +1,6 @@ import { type Node } from "@xyflow/react"; -import type { TaskNodeData } from "@/types/taskNode"; +import type { NodeData } from "@/types/taskNode"; import { type ComponentSpec, type GraphSpec, @@ -13,7 +13,7 @@ import { createTaskNode } from "./createTaskNode"; const createNodesFromComponentSpec = ( componentSpec: ComponentSpec, - nodeData: TaskNodeData, + nodeData: NodeData, ): Node[] => { if (!isGraphImplementation(componentSpec.implementation)) { return []; @@ -27,16 +27,13 @@ const createNodesFromComponentSpec = ( return [...taskNodes, ...inputNodes, ...outputNodes]; }; -const createTaskNodes = (graphSpec: GraphSpec, nodeData: TaskNodeData) => { +const createTaskNodes = (graphSpec: GraphSpec, nodeData: NodeData) => { return Object.entries(graphSpec.tasks).map((task) => createTaskNode(task, nodeData), ); }; -const createInputNodes = ( - componentSpec: ComponentSpec, - nodeData: TaskNodeData, -) => { +const createInputNodes = (componentSpec: ComponentSpec, nodeData: NodeData) => { return (componentSpec.inputs ?? []).map((inputSpec) => createInputNode(inputSpec, nodeData), ); @@ -44,7 +41,7 @@ const createInputNodes = ( const createOutputNodes = ( componentSpec: ComponentSpec, - nodeData: TaskNodeData, + nodeData: NodeData, ) => { return (componentSpec.outputs ?? []).map((outputSpec) => createOutputNode(outputSpec, nodeData), diff --git a/src/utils/nodes/createOutputNode.ts b/src/utils/nodes/createOutputNode.ts index 8d2a5b9b4..8c8ba67b6 100644 --- a/src/utils/nodes/createOutputNode.ts +++ b/src/utils/nodes/createOutputNode.ts @@ -1,15 +1,12 @@ import { type Node } from "@xyflow/react"; -import type { TaskNodeData } from "@/types/taskNode"; +import type { NodeData } from "@/types/taskNode"; import type { OutputSpec } from "../componentSpec"; import { extractPositionFromAnnotations } from "./extractPositionFromAnnotations"; import { outputNameToNodeId } from "./nodeIdUtils"; -export const createOutputNode = ( - output: OutputSpec, - nodeData: TaskNodeData, -) => { +export const createOutputNode = (output: OutputSpec, nodeData: NodeData) => { const { name, annotations, ...rest } = output; const position = extractPositionFromAnnotations(annotations); diff --git a/src/utils/nodes/createTaskNode.ts b/src/utils/nodes/createTaskNode.ts index 24a9e936e..1a1b0166b 100644 --- a/src/utils/nodes/createTaskNode.ts +++ b/src/utils/nodes/createTaskNode.ts @@ -1,34 +1,39 @@ import { type Node } from "@xyflow/react"; -import type { TaskNodeData } from "@/types/taskNode"; +import type { NodeData, TaskNodeData } from "@/types/taskNode"; import type { TaskSpec } from "../componentSpec"; import { extractPositionFromAnnotations } from "./extractPositionFromAnnotations"; -import { generateDynamicNodeCallbacks } from "./generateDynamicNodeCallbacks"; import { taskIdToNodeId } from "./nodeIdUtils"; +import { convertNodeCallbacksToTaskCallbacks } from "./taskCallbackUtils"; export const createTaskNode = ( task: [`${string}`, TaskSpec], - nodeData: TaskNodeData, + nodeData: NodeData, ) => { const [taskId, taskSpec] = task; + const { nodeCallbacks, ...data } = nodeData; const position = extractPositionFromAnnotations(taskSpec.annotations); const nodeId = taskIdToNodeId(taskId); // Inject the taskId and nodeId into the callbacks - const nodeCallbacks = nodeData.nodeCallbacks; - const dynamicCallbacks = generateDynamicNodeCallbacks(nodeId, nodeCallbacks); + const taskCallbacks = convertNodeCallbacksToTaskCallbacks( + { taskId, nodeId }, + nodeCallbacks, + ); + + const taskNodeData: TaskNodeData = { + ...data, + taskSpec, + taskId, + highlighted: false, + callbacks: taskCallbacks, + }; return { id: nodeId, - data: { - ...nodeData, - taskSpec, - taskId, - highlighted: false, - callbacks: dynamicCallbacks, // Use these callbacks internally within the node - }, + data: taskNodeData, position: position, type: "task", } as Node; diff --git a/src/utils/nodes/generateDynamicNodeCallbacks.ts b/src/utils/nodes/generateDynamicNodeCallbacks.ts deleted file mode 100644 index 3d1446956..000000000 --- a/src/utils/nodes/generateDynamicNodeCallbacks.ts +++ /dev/null @@ -1,42 +0,0 @@ -import type { - CallbackWithIds, - NodeAndTaskId, - NodeCallbacks, -} from "@/types/taskNode"; - -import { nodeIdToTaskId } from "./nodeIdUtils"; - -type ExcludeNodeAndTaskId = T extends [NodeAndTaskId, ...infer Rest] - ? Rest - : never; - -// Utility function that adds the taskId and nodeId to the callbacks as the first argument -export const generateDynamicNodeCallbacks = ( - nodeId: string, - nodeCallbacks?: NodeCallbacks, -): NodeCallbacks => { - if (!nodeCallbacks) { - return {} as NodeCallbacks; - } - - const taskId = nodeIdToTaskId(nodeId); - return Object.fromEntries( - (Object.keys(nodeCallbacks) as (keyof NodeCallbacks)[]).map( - (callbackName) => { - const callbackFn = nodeCallbacks[callbackName] as CallbackWithIds< - typeof callbackName - >; - return [ - callbackName, - ((...args: any[]) => - callbackFn( - { taskId, nodeId }, - ...((args ?? []) as ExcludeNodeAndTaskId< - Parameters - >), - )) as NodeCallbacks[typeof callbackName], - ]; - }, - ), - ) as NodeCallbacks; -}; diff --git a/src/utils/nodes/taskCallbackUtils.ts b/src/utils/nodes/taskCallbackUtils.ts new file mode 100644 index 000000000..4b564491d --- /dev/null +++ b/src/utils/nodes/taskCallbackUtils.ts @@ -0,0 +1,35 @@ +import type { + NodeAndTaskId, + NodeCallbacks, + TaskCallbacks, +} from "@/types/taskNode"; + +// Sync TaskCallbacks with NodeCallbacks by injecting nodeId and taskId +export const convertNodeCallbacksToTaskCallbacks = ( + ids: NodeAndTaskId, + nodeCallbacks?: NodeCallbacks, +): TaskCallbacks => { + if (!nodeCallbacks) { + return createEmptyTaskCallbacks(); + } + return { + setArguments: (args) => nodeCallbacks.setArguments?.(ids, args), + setAnnotations: (annotations) => + nodeCallbacks.setAnnotations?.(ids, annotations), + setCacheStaleness: (cacheStaleness) => + nodeCallbacks.setCacheStaleness?.(ids, cacheStaleness), + onDelete: () => nodeCallbacks.onDelete?.(ids), + onDuplicate: (selected) => nodeCallbacks.onDuplicate?.(ids, selected), + onUpgrade: (newComponentRef) => + nodeCallbacks.onUpgrade?.(ids, newComponentRef), + }; +}; + +const createEmptyTaskCallbacks = (): TaskCallbacks => ({ + setArguments: () => {}, + setAnnotations: () => {}, + setCacheStaleness: () => {}, + onDelete: () => {}, + onDuplicate: () => {}, + onUpgrade: () => {}, +});