diff --git a/src/components/Editor/IOEditor/InputValueEditor/InputValueEditor.tsx b/src/components/Editor/IOEditor/InputValueEditor/InputValueEditor.tsx index 066e3ed74..1f89ab7b3 100644 --- a/src/components/Editor/IOEditor/InputValueEditor/InputValueEditor.tsx +++ b/src/components/Editor/IOEditor/InputValueEditor/InputValueEditor.tsx @@ -9,13 +9,13 @@ import { Icon } from "@/components/ui/icon"; import { BlockStack } from "@/components/ui/layout"; import { Heading, Paragraph } from "@/components/ui/typography"; import useConfirmationDialog from "@/hooks/useConfirmationDialog"; +import { useNodeManager } from "@/hooks/useNodeManager"; import { useNodeSelectionTransfer } from "@/hooks/useNodeSelectionTransfer"; import useToastNotification from "@/hooks/useToastNotification"; import { useComponentSpec } from "@/providers/ComponentSpecProvider"; import { useContextPanel } from "@/providers/ContextPanelProvider"; import { type InputSpec } from "@/utils/componentSpec"; import { checkInputConnectionToRequiredFields } from "@/utils/inputConnectionUtils"; -import { inputNameToNodeId } from "@/utils/nodes/nodeIdUtils"; import { updateSubgraphSpec } from "@/utils/subgraphUtils"; import { NameField, TextField, TypeField } from "./FormFields/FormFields"; @@ -31,8 +31,10 @@ export const InputValueEditor = ({ input, disabled = false, }: InputValueEditorProps) => { + const { getInputNodeId } = useNodeManager(); + const notify = useToastNotification(); - const { transferSelection } = useNodeSelectionTransfer(inputNameToNodeId); + const { transferSelection } = useNodeSelectionTransfer(getInputNodeId); const { componentSpec, setComponentSpec, diff --git a/src/components/Editor/IOEditor/OutputNameEditor/OutputNameEditor.tsx b/src/components/Editor/IOEditor/OutputNameEditor/OutputNameEditor.tsx index da4258339..f6e0f2495 100644 --- a/src/components/Editor/IOEditor/OutputNameEditor/OutputNameEditor.tsx +++ b/src/components/Editor/IOEditor/OutputNameEditor/OutputNameEditor.tsx @@ -7,11 +7,11 @@ import { Icon } from "@/components/ui/icon"; import { BlockStack, InlineStack } from "@/components/ui/layout"; import { Heading, Paragraph } from "@/components/ui/typography"; import useConfirmationDialog from "@/hooks/useConfirmationDialog"; +import { useNodeManager } from "@/hooks/useNodeManager"; import { useNodeSelectionTransfer } from "@/hooks/useNodeSelectionTransfer"; import { useComponentSpec } from "@/providers/ComponentSpecProvider"; import { useContextPanel } from "@/providers/ContextPanelProvider"; import { type OutputSpec } from "@/utils/componentSpec"; -import { outputNameToNodeId } from "@/utils/nodes/nodeIdUtils"; import { updateSubgraphSpec } from "@/utils/subgraphUtils"; import { type OutputConnectedDetails } from "../../utils/getOutputConnectedDetails"; @@ -30,7 +30,8 @@ export const OutputNameEditor = ({ disabled, connectedDetails, }: OutputNameEditorProps) => { - const { transferSelection } = useNodeSelectionTransfer(outputNameToNodeId); + const { getOutputNodeId } = useNodeManager(); + const { transferSelection } = useNodeSelectionTransfer(getOutputNodeId); const { setComponentSpec, componentSpec, diff --git a/src/components/shared/ReactFlow/FlowCanvas/FlowCanvas.tsx b/src/components/shared/ReactFlow/FlowCanvas/FlowCanvas.tsx index 59567b72f..ab4c75e54 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/FlowCanvas.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/FlowCanvas.tsx @@ -119,6 +119,7 @@ const FlowCanvas = ({ currentSubgraphSpec, updateGraphSpec, currentSubgraphPath, + nodeManager, } = useComponentSpec(); const { preserveIOSelectionOnSpecChange, resetPrevSpec } = useIOSelectionPersistence(); @@ -285,10 +286,18 @@ const FlowCanvas = ({ let updatedSubgraphSpec = { ...currentSubgraphSpec }; for (const edge of params.edges) { - updatedSubgraphSpec = removeEdge(edge, updatedSubgraphSpec); + updatedSubgraphSpec = removeEdge( + edge, + updatedSubgraphSpec, + nodeManager, + ); } for (const node of params.nodes) { - updatedSubgraphSpec = removeNode(node, updatedSubgraphSpec); + updatedSubgraphSpec = removeNode( + node, + updatedSubgraphSpec, + nodeManager, + ); } const updatedRootSpec = updateSubgraphSpec( @@ -299,7 +308,13 @@ const FlowCanvas = ({ setComponentSpec(updatedRootSpec); }, - [componentSpec, currentSubgraphSpec, currentSubgraphPath, setComponentSpec], + [ + componentSpec, + currentSubgraphSpec, + currentSubgraphPath, + nodeManager, + setComponentSpec, + ], ); const nodeCallbacks = useNodeCallbacks({ @@ -313,18 +328,23 @@ const FlowCanvas = ({ connectable: !readOnly && !!nodesConnectable, readOnly, callbacks: nodeCallbacks, + nodeManager, }), - [readOnly, nodesConnectable, nodeCallbacks], + [readOnly, nodesConnectable, nodeCallbacks, nodeManager], ); const onConnect = useCallback( (connection: Connection) => { if (connection.source === connection.target) return; - const updatedGraphSpec = handleConnection(currentGraphSpec, connection); + const updatedGraphSpec = handleConnection( + currentGraphSpec, + connection, + nodeManager, + ); updateGraphSpec(updatedGraphSpec); }, - [currentGraphSpec, handleConnection, updateGraphSpec], + [currentGraphSpec, nodeManager, handleConnection, updateGraphSpec], ); const onConnectEnd = useCallback( @@ -361,7 +381,11 @@ const FlowCanvas = ({ ); if (existingInputEdge) { - newComponentSpec = removeEdge(existingInputEdge, newComponentSpec); + newComponentSpec = removeEdge( + existingInputEdge, + newComponentSpec, + nodeManager, + ); } const updatedComponentSpec = addAndConnectNode({ @@ -369,11 +393,18 @@ const FlowCanvas = ({ fromHandle, position, componentSpec: newComponentSpec, + nodeManager, }); setComponentSpec(updatedComponentSpec); }, - [reactFlowInstance, componentSpec, setComponentSpec, updateOrAddNodes], + [ + reactFlowInstance, + componentSpec, + nodeManager, + setComponentSpec, + updateOrAddNodes, + ], ); useEffect(() => { @@ -627,6 +658,7 @@ const FlowCanvas = ({ const updatedSubgraphSpec = updateNodePositions( updatedNodes, currentSubgraphSpec, + nodeManager, ); const updatedRootSpec = updateSubgraphSpec( @@ -646,6 +678,7 @@ const FlowCanvas = ({ componentSpec, currentSubgraphSpec, currentSubgraphPath, + nodeManager, setComponentSpec, onNodesChange, ], @@ -677,7 +710,9 @@ const FlowCanvas = ({ updatedComponentSpec: updatedSubgraphSpec, newNodes, updatedNodes, - } = duplicateNodes(currentSubgraphSpec, selectedNodes, { selected: true }); + } = duplicateNodes(currentSubgraphSpec, selectedNodes, nodeManager, { + selected: true, + }); const updatedRootSpec = updateSubgraphSpec( componentSpec, @@ -696,6 +731,7 @@ const FlowCanvas = ({ currentSubgraphSpec, currentSubgraphPath, selectedNodes, + nodeManager, setComponentSpec, setNodes, ]); @@ -873,6 +909,7 @@ const FlowCanvas = ({ const { newNodes, updatedComponentSpec } = duplicateNodes( componentSpec, nodesToPaste, + nodeManager, { position: reactFlowCenter, connection: "internal" }, ); @@ -898,6 +935,7 @@ const FlowCanvas = ({ nodes, reactFlowInstance, store, + nodeManager, updateOrAddNodes, setComponentSpec, readOnly, diff --git a/src/components/shared/ReactFlow/FlowCanvas/IONode/IONode.tsx b/src/components/shared/ReactFlow/FlowCanvas/IONode/IONode.tsx index db14e37ae..b62092fb2 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/IONode/IONode.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/IONode/IONode.tsx @@ -7,6 +7,7 @@ import { getOutputConnectedDetails } from "@/components/Editor/utils/getOutputCo import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { BlockStack, InlineStack } from "@/components/ui/layout"; import { Paragraph } from "@/components/ui/typography"; +import { useNodeManager } from "@/hooks/useNodeManager"; import { cn } from "@/lib/utils"; import { useComponentSpec } from "@/providers/ComponentSpecProvider"; import { useContextPanel } from "@/providers/ContextPanelProvider"; @@ -23,6 +24,7 @@ interface IONodeProps { const IONode = ({ type, data, selected = false }: IONodeProps) => { const { currentGraphSpec, currentSubgraphSpec } = useComponentSpec(); const { setContent, clearContent } = useContextPanel(); + const { getHandleNodeId } = useNodeManager(); const { spec, readOnly } = data; @@ -49,6 +51,9 @@ const IONode = ({ type, data, selected = false }: IONodeProps) => { [currentSubgraphSpec.outputs, spec.name], ); + const handleNodeType = isInput ? "handle-out" : "handle-in"; + const nodeHandleId = getHandleNodeId(spec.name, spec.name, handleNodeType); + useEffect(() => { if (selected) { if (input && isInput) { @@ -144,6 +149,7 @@ const IONode = ({ type, data, selected = false }: IONodeProps) => { { - const { nodeId, state } = useTaskNode(); + const { getInputHandleNodeId } = useNodeManager(); + const { taskId, nodeId, state } = useTaskNode(); const fromHandle = useConnection((connection) => connection.fromHandle?.id); const toHandle = useConnection((connection) => connection.toHandle?.id); @@ -44,7 +46,7 @@ export const InputHandle = ({ const [selected, setSelected] = useState(false); const [active, setActive] = useState(false); - const handleId = getInputHandleId(input.name); + const handleId = getInputHandleNodeId(taskId, input.name); const missing = invalid ? "bg-red-700!" : "bg-gray-500!"; const hasValue = value !== undefined && value !== null; @@ -218,7 +220,8 @@ export const OutputHandle = ({ onLabelClick, onHandleSelectionChange, }: OutputHandleProps) => { - const { nodeId, state } = useTaskNode(); + const { getOutputHandleNodeId } = useNodeManager(); + const { taskId, nodeId, state } = useTaskNode(); const fromHandle = useConnection((connection) => connection.fromHandle?.id); const toHandle = useConnection((connection) => connection.toHandle?.id); @@ -230,7 +233,7 @@ export const OutputHandle = ({ const [selected, setSelected] = useState(false); const [active, setActive] = useState(false); - const handleId = getOutputHandleId(output.name); + const handleId = getOutputHandleNodeId(taskId, output.name); const hasValue = value !== undefined && value !== "" && value !== null; const handleHandleClick = useCallback( @@ -355,14 +358,6 @@ export const OutputHandle = ({ ); }; -const getOutputHandleId = (outputName: string) => { - return `output_${outputName}`; -}; - -const getInputHandleId = (inputName: string) => { - return `input_${inputName}`; -}; - const skipHandleDeselect = (e: MouseEvent) => { let el = e.target as HTMLElement | null; while (el) { diff --git a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.test.tsx b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.test.tsx index 804e369ea..3114a757e 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.test.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.test.tsx @@ -17,6 +17,15 @@ vi.mock("@/providers/ComponentLibraryProvider", () => ({ }), })); +vi.mock("@/providers/ComponentLibraryProvider/ForcedSearchProvider", () => ({ + useForcedSearchContext: () => ({ + highlightSearchFilter: vi.fn(), + resetSearchFilter: vi.fn(), + currentSearchFilter: { searchTerm: "", filters: [] }, + highlightSearchResults: false, + }), +})); + vi.mock("@/providers/ComponentSpecProvider", () => ({ useComponentSpec: () => ({ graphSpec: { @@ -25,6 +34,18 @@ vi.mock("@/providers/ComponentSpecProvider", () => ({ }), })); +vi.mock("@/hooks/useNodeManager", () => ({ + useNodeManager: () => ({ + getInputHandleNodeId: vi.fn( + (_refId: string, inputName: string) => `input-handle-${inputName}`, + ), + getOutputHandleNodeId: vi.fn(), + getNodeId: vi.fn(), + getHandleNodeId: vi.fn(), + nodeManager: {}, + }), +})); + vi.mock("@/providers/TaskNodeProvider"); const TestWrapper = ReactFlowProvider; @@ -39,6 +60,7 @@ describe("", () => { state: { readOnly: false }, select: vi.fn(), nodeId: "test-node", + taskId: "test-task", } as any); }; diff --git a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.tsx b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.tsx index 05f5a0e51..95b6e4d0e 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.tsx @@ -2,6 +2,7 @@ import { useConnection } from "@xyflow/react"; import { AlertCircle } from "lucide-react"; import { type MouseEvent, useCallback, useEffect, useState } from "react"; +import { useNodeManager } from "@/hooks/useNodeManager"; import { cn } from "@/lib/utils"; import { useForcedSearchContext } from "@/providers/ComponentLibraryProvider/ForcedSearchProvider"; import { isValidFilterRequest } from "@/providers/ComponentLibraryProvider/types"; @@ -10,7 +11,6 @@ import { useTaskNode } from "@/providers/TaskNodeProvider"; import { inputsWithInvalidArguments } from "@/services/componentService"; import type { InputSpec } from "@/utils/componentSpec"; import { ComponentSearchFilter } from "@/utils/constants"; -import { inputNameToNodeId } from "@/utils/nodes/nodeIdUtils"; import { checkArtifactMatchesSearchFilters } from "@/utils/searchUtils"; import { InputHandle } from "./Handles"; @@ -27,7 +27,8 @@ export function TaskNodeInputs({ expanded, onBackgroundClick, }: TaskNodeInputsProps) { - const { inputs, taskSpec, state, select } = useTaskNode(); + const { getInputHandleNodeId } = useNodeManager(); + const { taskId, inputs, taskSpec, state, select } = useTaskNode(); const { graphSpec } = useComponentSpec(); const { highlightSearchFilter, @@ -147,7 +148,7 @@ export function TaskNodeInputs({ } const input = inputs.find( - (i) => inputNameToNodeId(i.name) === fromHandle?.id, + (i) => getInputHandleNodeId(taskId, i.name) === fromHandle?.id, ); if (!input) return; diff --git a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeOutputs.tsx b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeOutputs.tsx index f20907c0e..0d5f39f8b 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeOutputs.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeOutputs.tsx @@ -1,13 +1,13 @@ import { useConnection, useEdges } from "@xyflow/react"; import { type MouseEvent, useCallback, useEffect, useState } from "react"; +import { useNodeManager } from "@/hooks/useNodeManager"; import { cn } from "@/lib/utils"; import { useForcedSearchContext } from "@/providers/ComponentLibraryProvider/ForcedSearchProvider"; import { isValidFilterRequest } from "@/providers/ComponentLibraryProvider/types"; import { useTaskNode } from "@/providers/TaskNodeProvider"; import type { OutputSpec } from "@/utils/componentSpec"; import { ComponentSearchFilter } from "@/utils/constants"; -import { outputNameToNodeId } from "@/utils/nodes/nodeIdUtils"; import { checkArtifactMatchesSearchFilters } from "@/utils/searchUtils"; import { OutputHandle } from "./Handles"; @@ -23,7 +23,8 @@ export function TaskNodeOutputs({ expanded, onBackgroundClick, }: TaskNodeOutputsProps) { - const { nodeId, outputs, state, select } = useTaskNode(); + const { getOutputHandleNodeId } = useNodeManager(); + const { taskId, nodeId, outputs, state, select } = useTaskNode(); const { highlightSearchFilter, resetSearchFilter, @@ -40,7 +41,7 @@ export function TaskNodeOutputs({ edges.some( (edge) => edge.source === nodeId && - edge.sourceHandle === outputNameToNodeId(output.name), + edge.sourceHandle === getOutputHandleNodeId(taskId, output.name), ), ); @@ -138,7 +139,7 @@ export function TaskNodeOutputs({ } const output = outputs.find( - (o) => outputNameToNodeId(o.name) === fromHandle?.id, + (o) => getOutputHandleNodeId(taskId, o.name) === fromHandle?.id, ); if (!output) return; diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/addAndConnectNode.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/addAndConnectNode.ts index 7d76f002b..b73772b50 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/addAndConnectNode.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/addAndConnectNode.ts @@ -1,18 +1,14 @@ import type { Connection, Handle } from "@xyflow/react"; -import type { - ComponentReference, - ComponentSpec, - TaskSpec, - TypeSpecType, +import type { NodeManager, NodeType } from "@/nodeManager"; +import { + type ComponentReference, + type ComponentSpec, + isGraphImplementation, + type TaskSpec, + type TypeSpecType, } from "@/utils/componentSpec"; import { DEFAULT_NODE_DIMENSIONS } from "@/utils/constants"; -import { - inputNameToNodeId, - nodeIdToTaskId, - outputNameToNodeId, - taskIdToNodeId, -} from "@/utils/nodes/nodeIdUtils"; import addTask from "./addTask"; import { handleConnection } from "./handleConnection"; @@ -22,6 +18,7 @@ type AddAndConnectNodeParams = { fromHandle: Handle | null; position: { x: number; y: number }; componentSpec: ComponentSpec; + nodeManager: NodeManager; }; export function addAndConnectNode({ @@ -29,6 +26,7 @@ export function addAndConnectNode({ fromHandle, position, componentSpec, + nodeManager, }: AddAndConnectNodeParams): ComponentSpec { // 1. Add the new node const taskSpec: TaskSpec = { @@ -36,17 +34,54 @@ export function addAndConnectNode({ componentRef: { ...componentRef }, }; - if (!("graph" in componentSpec.implementation)) { + if (!isGraphImplementation(componentSpec.implementation)) { return componentSpec; } const oldGraphSpec = componentSpec.implementation.graph; - const fromHandleId = fromHandle?.id; - const fromHandleType = fromHandleId?.startsWith("input") ? "input" : "output"; + if (!fromHandle?.id) { + return componentSpec; + } + + const fromNodeId = fromHandle.nodeId; + const fromNodeType = nodeManager.getNodeType(fromNodeId); + const fromTaskId = nodeManager.getRefId(fromNodeId); + + if (!fromTaskId) { + return componentSpec; + } + + let fromHandleType: NodeType | undefined; + let fromHandleName: string | undefined; + + if (fromNodeType === "task") { + const fromHandleInfo = nodeManager.getHandleInfo(fromHandle.id); + fromHandleName = fromHandleInfo?.handleName; + fromHandleType = nodeManager.getNodeType(fromHandle.id); + } else if (fromNodeType === "input") { + fromHandleType = "handle-out"; + fromHandleName = fromTaskId; + } else if (fromNodeType === "output") { + fromHandleType = "handle-in"; + fromHandleName = fromTaskId; + } else { + return componentSpec; + } + + if (!fromHandleName) { + return componentSpec; + } + + if ( + !fromHandleType || + (fromHandleType !== "handle-in" && fromHandleType !== "handle-out") + ) { + return componentSpec; + } const adjustedPosition = - fromHandleType === "input" + fromHandleType === "handle-in" ? { ...position, x: position.x - DEFAULT_NODE_DIMENSIONS.w } : position; @@ -58,7 +93,7 @@ export function addAndConnectNode({ ); // 2. Find the new node - if (!("graph" in newComponentSpec.implementation)) { + if (!isGraphImplementation(newComponentSpec.implementation)) { return newComponentSpec; } @@ -72,62 +107,64 @@ export function addAndConnectNode({ return newComponentSpec; } - const newNodeId = taskIdToNodeId(newTaskId); + const newNodeId = nodeManager.getNodeId(newTaskId, "task"); // 3. Determine the connection data type and find the first matching handle on the new node - if (!fromHandle) { - return newComponentSpec; + let fromComponentSpec: ComponentSpec | undefined; + + if (fromNodeType === "task") { + // Get spec from task + const fromTaskSpec = graphSpec.tasks[fromTaskId]; + fromComponentSpec = fromTaskSpec?.componentRef.spec; + } else { + // For IO nodes, get spec from component spec + fromComponentSpec = componentSpec; } - const fromTaskId = nodeIdToTaskId(fromHandle.nodeId); - - const fromTaskSpec = graphSpec.tasks[fromTaskId]; - const fromComponentSpec = fromTaskSpec?.componentRef.spec; - - const fromNodeId = fromHandle.nodeId; - - const fromHandleName = fromHandleId?.replace(`${fromHandleType}_`, ""); - let connectionType: TypeSpecType | undefined; - if (fromHandleType === "input") { + if (fromHandleType === "handle-in") { connectionType = fromComponentSpec?.inputs?.find( (io) => io.name === fromHandleName, )?.type; - } else if (fromHandleType === "output") { + } else if (fromHandleType === "handle-out") { connectionType = fromComponentSpec?.outputs?.find( (io) => io.name === fromHandleName, )?.type; } // Find the first matching handle on the new node - const toHandleType = fromHandleType === "input" ? "output" : "input"; - - let targetHandleId: string | undefined; - - if (toHandleType === "input") { - const handleName = componentRef.spec?.inputs?.find( - (io) => io.type === connectionType, - )?.name; - if (!handleName) { - return newComponentSpec; - } - - targetHandleId = inputNameToNodeId(handleName); - } else if (toHandleType === "output") { - const handleName = componentRef.spec?.outputs?.find( - (io) => io.type === connectionType, - )?.name; - if (!handleName) { - return newComponentSpec; - } - - targetHandleId = outputNameToNodeId(handleName); + const toHandleType = + fromHandleType === "handle-in" ? "handle-out" : "handle-in"; + + const inputHandleName = componentRef.spec?.inputs?.find( + (io) => io.type === connectionType, + )?.name; + + const outputHandleName = componentRef.spec?.outputs?.find( + (io) => io.type === connectionType, + )?.name; + + const toHandleName = + toHandleType === "handle-in" ? inputHandleName : outputHandleName; + + if (!toHandleName) { + return newComponentSpec; } + const targetHandleId = nodeManager.getHandleNodeId( + newTaskId, + toHandleName, + toHandleType, + ); + // 4. Build a Connection object and use handleConnection to add the edge - if (fromNodeId && fromHandleId && targetHandleId) { + if (targetHandleId) { + const fromNodeId = fromHandle.nodeId; + const fromHandleId = fromHandle.id; + const isReversedConnection = - fromHandleType === "input" && toHandleType === "output"; + fromHandleType === "handle-in" && toHandleType === "handle-out"; + const connection: Connection = isReversedConnection ? // Drawing from an input handle to a new output handle { @@ -144,7 +181,11 @@ export function addAndConnectNode({ targetHandle: targetHandleId, }; - const updatedGraphSpec = handleConnection(graphSpec, connection); + const updatedGraphSpec = handleConnection( + graphSpec, + connection, + nodeManager, + ); return { ...newComponentSpec, diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/duplicateNodes.test.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/duplicateNodes.test.ts index ee7a29026..a38b7fcdc 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/duplicateNodes.test.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/duplicateNodes.test.ts @@ -1,6 +1,7 @@ import type { Node } from "@xyflow/react"; import { describe, expect, it, vi } from "vitest"; +import { NodeManager } from "@/nodeManager"; import type { TaskNodeData } from "@/types/nodes"; import type { ComponentSpec, @@ -9,14 +10,11 @@ import type { TaskOutputArgument, TaskSpec, } from "@/utils/componentSpec"; -import { - inputNameToNodeId, - outputNameToNodeId, - taskIdToNodeId, -} from "@/utils/nodes/nodeIdUtils"; import { duplicateNodes } from "./duplicateNodes"; +const createMockNodeManager = () => new NodeManager(); + // Mock utility functions const mockTaskSpec: TaskSpec = { componentRef: { name: "test-component" }, @@ -89,57 +87,71 @@ const createMockTaskNodeCallbacks = () => ({ const createMockTaskNode = ( taskId: string, taskSpec: TaskSpec, + nodeManager: NodeManager, position = { x: 100, y: 100 }, -): Node => ({ - id: taskIdToNodeId(taskId), - type: "task", - position, - data: { - taskSpec, - taskId, - label: "Test Task", - highlighted: false, - readOnly: false, - isGhost: false, - connectable: true, - callbacks: createMockTaskNodeCallbacks(), - }, - selected: false, - dragging: false, - measured: { width: 200, height: 100 }, -}); +): Node => { + const nodeId = nodeManager.getNodeId(taskId, "task"); + return { + id: nodeId, + type: "task", + position, + data: { + taskSpec, + taskId, + label: "Test Task", + highlighted: false, + readOnly: false, + isGhost: false, + connectable: true, + callbacks: createMockTaskNodeCallbacks(), + }, + selected: false, + dragging: false, + measured: { width: 200, height: 100 }, + }; +}; const createMockInputNode = ( inputName: string, + nodeManager: NodeManager, position = { x: 50, y: 50 }, -): Node => ({ - id: inputNameToNodeId(inputName), - type: "input", - position, - data: { - label: inputName, - inputSpec: { ...mockInputSpec, name: inputName }, - }, - selected: false, - dragging: false, - measured: { width: 150, height: 80 }, -}); +): Node => { + const nodeId = nodeManager.getNodeId(inputName, "input"); + + return { + id: nodeId, + type: "input", + position, + data: { + label: inputName, + inputSpec: { ...mockInputSpec, name: inputName }, + }, + selected: false, + dragging: false, + measured: { width: 150, height: 80 }, + }; +}; const createMockOutputNode = ( outputName: string, + nodeManager: NodeManager, position = { x: 300, y: 300 }, -): Node => ({ - id: outputNameToNodeId(outputName), - type: "output", - position, - data: { - label: outputName, - outputSpec: { ...mockOutputSpec, name: outputName }, - }, - selected: false, - dragging: false, - measured: { width: 150, height: 80 }, -}); +): Node => { + const nodeId = nodeManager.getNodeId(outputName, "output"); + + return { + id: nodeId, + type: "output", + position, + data: { + label: outputName, + outputSpec: { ...mockOutputSpec, name: outputName }, + }, + selected: false, + dragging: false, + measured: { width: 150, height: 80 }, + }; +}; describe("duplicateNodes", () => { describe("error handling", () => { @@ -154,8 +166,9 @@ describe("duplicateNodes", () => { }; const nodes: Node[] = []; + const nodeManager = createMockNodeManager(); - expect(() => duplicateNodes(componentSpec, nodes)).toThrow( + expect(() => duplicateNodes(componentSpec, nodes, nodeManager)).toThrow( "ComponentSpec does not contain a graph implementation.", ); }); @@ -174,16 +187,22 @@ describe("duplicateNodes", () => { "original-task": originalTaskSpec, }); - const taskNode = createMockTaskNode("original-task", originalTaskSpec, { - x: 100, - y: 100, - }); + const nodeManager = createMockNodeManager(); + + const taskNode = createMockTaskNode( + "original-task", + originalTaskSpec, + nodeManager, + { + x: 100, + y: 100, + }, + ); - const result = duplicateNodes(componentSpec, [taskNode]); + const result = duplicateNodes(componentSpec, [taskNode], nodeManager); expect(result.newNodes).toHaveLength(1); expect(result.newNodes[0].type).toBe("task"); - expect(result.newNodes[0].id).toBe(taskIdToNodeId("original-task 2")); expect(result.newNodes[0].position).toEqual({ x: 110, y: 110 }); expect(result.newNodes[0].selected).toBe(true); @@ -207,13 +226,19 @@ describe("duplicateNodes", () => { const componentSpec = createMockComponentSpec({}, [inputSpec]); - const inputNode = createMockInputNode("original-input", { x: 50, y: 50 }); + const nodeManager = createMockNodeManager(); + const inputNode = createMockInputNode("original-input", nodeManager, { + x: 50, + y: 50, + }); - const result = duplicateNodes(componentSpec, [inputNode]); + const result = duplicateNodes(componentSpec, [inputNode], nodeManager); expect(result.newNodes).toHaveLength(1); expect(result.newNodes[0].type).toBe("input"); - expect(result.newNodes[0].id).toBe(inputNameToNodeId("original-input 2")); + expect(result.newNodes[0].id).toBe( + nodeManager.getNodeId("original-input 2", "input"), + ); expect(result.newNodes[0].position).toEqual({ x: 60, y: 60 }); expect(result.updatedComponentSpec.inputs).toHaveLength(2); @@ -239,17 +264,18 @@ describe("duplicateNodes", () => { [outputSpec], ); - const outputNode = createMockOutputNode("original-output", { + const nodeManager = createMockNodeManager(); + const outputNode = createMockOutputNode("original-output", nodeManager, { x: 300, y: 300, }); - const result = duplicateNodes(componentSpec, [outputNode]); + const result = duplicateNodes(componentSpec, [outputNode], nodeManager); expect(result.newNodes).toHaveLength(1); expect(result.newNodes[0].type).toBe("output"); expect(result.newNodes[0].id).toBe( - outputNameToNodeId("original-output 2"), + nodeManager.getNodeId("original-output 2", "output"), ); expect(result.newNodes[0].position).toEqual({ x: 310, y: 310 }); @@ -270,18 +296,23 @@ describe("duplicateNodes", () => { task2: taskSpec2, }); + const nodeManager = createMockNodeManager(); const nodes = [ - createMockTaskNode("task1", taskSpec1, { x: 100, y: 100 }), - createMockTaskNode("task2", taskSpec2, { x: 200, y: 200 }), + createMockTaskNode("task1", taskSpec1, nodeManager, { x: 100, y: 100 }), + createMockTaskNode("task2", taskSpec2, nodeManager, { x: 200, y: 200 }), ]; - const result = duplicateNodes(componentSpec, nodes); + const result = duplicateNodes(componentSpec, nodes, nodeManager); expect(result.newNodes).toHaveLength(2); - expect(result.newNodes.map((n) => n.id)).toEqual([ - taskIdToNodeId("task1 2"), - taskIdToNodeId("task2 2"), - ]); + if ("graph" in result.updatedComponentSpec.implementation!) { + expect( + result.updatedComponentSpec.implementation.graph.tasks, + ).toHaveProperty("task1 2"); + expect( + result.updatedComponentSpec.implementation.graph.tasks, + ).toHaveProperty("task2 2"); + } }); }); @@ -291,10 +322,15 @@ describe("duplicateNodes", () => { "original-task": mockTaskSpec, }); - const taskNode = createMockTaskNode("original-task", mockTaskSpec); + const nodeManager = createMockNodeManager(); + const taskNode = createMockTaskNode( + "original-task", + mockTaskSpec, + nodeManager, + ); taskNode.selected = true; - const result = duplicateNodes(componentSpec, [taskNode], { + const result = duplicateNodes(componentSpec, [taskNode], nodeManager, { selected: false, }); @@ -308,12 +344,19 @@ describe("duplicateNodes", () => { task2: mockTaskSpec, }); + const nodeManager = createMockNodeManager(); const nodes = [ - createMockTaskNode("task1", mockTaskSpec, { x: 100, y: 100 }), - createMockTaskNode("task2", mockTaskSpec, { x: 200, y: 200 }), + createMockTaskNode("task1", mockTaskSpec, nodeManager, { + x: 100, + y: 100, + }), + createMockTaskNode("task2", mockTaskSpec, nodeManager, { + x: 200, + y: 200, + }), ]; - const result = duplicateNodes(componentSpec, nodes, { + const result = duplicateNodes(componentSpec, nodes, nodeManager, { position: { x: 500, y: 500 }, }); @@ -364,12 +407,13 @@ describe("duplicateNodes", () => { task2, }); + const nodeManager = createMockNodeManager(); const nodes = [ - createMockTaskNode("task1", task1), - createMockTaskNode("task2", task2), + createMockTaskNode("task1", task1, nodeManager), + createMockTaskNode("task2", task2, nodeManager), ]; - const result = duplicateNodes(componentSpec, nodes, { + const result = duplicateNodes(componentSpec, nodes, nodeManager, { connection: "none", }); @@ -387,12 +431,13 @@ describe("duplicateNodes", () => { task2, }); + const nodeManager = createMockNodeManager(); const nodes = [ - createMockTaskNode("task1", task1), - createMockTaskNode("task2", task2), + createMockTaskNode("task1", task1, nodeManager), + createMockTaskNode("task2", task2, nodeManager), ]; - const result = duplicateNodes(componentSpec, nodes, { + const result = duplicateNodes(componentSpec, nodes, nodeManager, { connection: "internal", }); @@ -438,13 +483,14 @@ describe("duplicateNodes", () => { task3, }); + const nodeManager = createMockNodeManager(); // Duplicate task1 and task2, but NOT task3 const nodes = [ - createMockTaskNode("task1", task1), - createMockTaskNode("task2", task2WithConnections), + createMockTaskNode("task1", task1, nodeManager), + createMockTaskNode("task2", task2WithConnections, nodeManager), ]; - const result = duplicateNodes(componentSpec, nodes, { + const result = duplicateNodes(componentSpec, nodes, nodeManager, { connection: "external", }); @@ -472,12 +518,13 @@ describe("duplicateNodes", () => { task2, }); + const nodeManager = createMockNodeManager(); const nodes = [ - createMockTaskNode("task1", task1), - createMockTaskNode("task2", task2), + createMockTaskNode("task1", task1, nodeManager), + createMockTaskNode("task2", task2, nodeManager), ]; - const result = duplicateNodes(componentSpec, nodes, { + const result = duplicateNodes(componentSpec, nodes, nodeManager, { connection: "all", }); @@ -514,12 +561,13 @@ describe("duplicateNodes", () => { inputSpec, ]); + const nodeManager = createMockNodeManager(); const nodes = [ - createMockInputNode("graph-input"), - createMockTaskNode("task1", taskSpec), + createMockInputNode("graph-input", nodeManager), + createMockTaskNode("task1", taskSpec, nodeManager), ]; - const result = duplicateNodes(componentSpec, nodes, { + const result = duplicateNodes(componentSpec, nodes, nodeManager, { connection: "all", }); @@ -551,12 +599,13 @@ describe("duplicateNodes", () => { [outputSpec], ); + const nodeManager = createMockNodeManager(); const nodes = [ - createMockTaskNode("task1", taskSpec), - createMockOutputNode("graph-output"), + createMockTaskNode("task1", taskSpec, nodeManager), + createMockOutputNode("graph-output", nodeManager), ]; - const result = duplicateNodes(componentSpec, nodes, { + const result = duplicateNodes(componentSpec, nodes, nodeManager, { connection: "all", }); @@ -579,7 +628,8 @@ describe("duplicateNodes", () => { describe("edge cases", () => { it("should handle empty node array", () => { const componentSpec = createMockComponentSpec(); - const result = duplicateNodes(componentSpec, []); + const nodeManager = createMockNodeManager(); + const result = duplicateNodes(componentSpec, [], nodeManager); expect(result.newNodes).toHaveLength(0); expect(result.nodeIdMap).toEqual({}); @@ -590,10 +640,15 @@ describe("duplicateNodes", () => { "original-task": mockTaskSpec, }); - const taskNode = createMockTaskNode("original-task", mockTaskSpec); + const nodeManager = createMockNodeManager(); + const taskNode = createMockTaskNode( + "original-task", + mockTaskSpec, + nodeManager, + ); taskNode.measured = { width: 300, height: 200 }; - const result = duplicateNodes(componentSpec, [taskNode]); + const result = duplicateNodes(componentSpec, [taskNode], nodeManager); expect(result.newNodes[0].measured).toEqual({ width: 300, height: 200 }); }); @@ -608,12 +663,14 @@ describe("duplicateNodes", () => { "original-task": taskSpecWithoutPosition, }); + const nodeManager = createMockNodeManager(); const taskNode = createMockTaskNode( "original-task", taskSpecWithoutPosition, + nodeManager, ); - const result = duplicateNodes(componentSpec, [taskNode]); + const result = duplicateNodes(componentSpec, [taskNode], nodeManager); expect(result.newNodes).toHaveLength(1); expect(result.newNodes[0].position).toEqual({ x: 110, y: 110 }); @@ -626,18 +683,22 @@ describe("duplicateNodes", () => { "original-task": mockTaskSpec, }); - const taskNode = createMockTaskNode("original-task", mockTaskSpec); + const nodeManager = createMockNodeManager(); + const taskNode = createMockTaskNode( + "original-task", + mockTaskSpec, + nodeManager, + ); + const originalNodeId = taskNode.id; - const result = duplicateNodes(componentSpec, [taskNode]); + const result = duplicateNodes(componentSpec, [taskNode], nodeManager); expect(result).toHaveProperty("updatedComponentSpec"); expect(result).toHaveProperty("nodeIdMap"); expect(result).toHaveProperty("newNodes"); expect(result).toHaveProperty("updatedNodes"); - expect(result.nodeIdMap).toEqual({ - [taskIdToNodeId("original-task")]: taskIdToNodeId("original-task 2"), - }); + expect(result.nodeIdMap).toHaveProperty(originalNodeId); expect(result.updatedNodes).toHaveLength(1); expect(result.updatedNodes[0]).toBe(taskNode); diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/duplicateNodes.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/duplicateNodes.ts index 948bdacad..914b9bdb6 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/duplicateNodes.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/duplicateNodes.ts @@ -1,5 +1,6 @@ import { type Node, type XYPosition } from "@xyflow/react"; +import type { NodeManager } from "@/nodeManager"; import { isInputNode, isOutputNode, @@ -19,14 +20,6 @@ import { createInputNode } from "@/utils/nodes/createInputNode"; import { createOutputNode } from "@/utils/nodes/createOutputNode"; import { createTaskNode } from "@/utils/nodes/createTaskNode"; import { getNodesBounds } from "@/utils/nodes/getNodesBounds"; -import { - inputNameToNodeId, - nodeIdToInputName, - nodeIdToOutputName, - nodeIdToTaskId, - outputNameToNodeId, - taskIdToNodeId, -} from "@/utils/nodes/nodeIdUtils"; import { setPositionInAnnotations } from "@/utils/nodes/setPositionInAnnotations"; import { convertTaskCallbacksToNodeCallbacks } from "@/utils/nodes/taskCallbackUtils"; import { @@ -49,6 +42,7 @@ type ConnectionMode = "none" | "internal" | "external" | "all"; export const duplicateNodes = ( componentSpec: ComponentSpec, nodesToDuplicate: Node[], + nodeManager: NodeManager, config?: { selected?: boolean; position?: XYPosition; @@ -63,7 +57,6 @@ export const duplicateNodes = ( const graphSpec = componentSpec.implementation.graph; const nodeIdMap: Record = {}; - const newTasks: Record = {}; const newInputs: Record = {}; const newOutputs: Record = {}; @@ -77,9 +70,14 @@ export const duplicateNodes = ( const oldNodeId = node.id; if (isTaskNode(node)) { - const oldTaskId = nodeIdToTaskId(oldNodeId); + const oldTaskId = nodeManager.getRefId(oldNodeId); + if (!oldTaskId) { + console.warn("Could not find taskId for node:", node); + return; + } + const newTaskId = getUniqueTaskId(graphSpec, oldTaskId); - const newNodeId = taskIdToNodeId(newTaskId); + const newNodeId = nodeManager.getNodeId(newTaskId, "task"); nodeIdMap[oldNodeId] = newNodeId; @@ -102,7 +100,7 @@ export const duplicateNodes = ( ); const newInputName = getUniqueInputName(componentSpec, inputSpec?.name); - const newNodeId = inputNameToNodeId(newInputName); + const newNodeId = nodeManager.getNodeId(newInputName, "input"); nodeIdMap[oldNodeId] = newNodeId; @@ -129,7 +127,7 @@ export const duplicateNodes = ( componentSpec, outputSpec?.name, ); - const newNodeId = outputNameToNodeId(newOutputName); + const newNodeId = nodeManager.getNodeId(newOutputName, "output"); nodeIdMap[oldNodeId] = newNodeId; @@ -172,6 +170,7 @@ export const duplicateNodes = ( nodesToDuplicate, componentSpec, connection, + nodeManager, ); } else { // If the Argument is not a TaskOutput or GraphInput, copy it over @@ -193,7 +192,7 @@ export const duplicateNodes = ( /* Reconfigure Outputs */ Object.entries(newOutputs).forEach((output) => { const [outputName] = output; - const newNodeId = outputNameToNodeId(outputName); + const newNodeId = nodeManager.getNodeId(outputName, "output"); const oldNodeId = Object.keys(nodeIdMap).find( (key) => nodeIdMap[key] === newNodeId, ); @@ -202,9 +201,9 @@ export const duplicateNodes = ( return; } - const oldOutputName = nodeIdToOutputName(oldNodeId); + const oldOutputName = nodeManager.getRefId(oldNodeId); - if (!graphSpec.outputValues) { + if (!graphSpec.outputValues || !oldOutputName) { return; } @@ -225,9 +224,13 @@ export const duplicateNodes = ( ) { if ("taskOutput" in updatedOutputValue) { const oldTaskId = updatedOutputValue.taskOutput.taskId; - const oldTaskNodeId = taskIdToNodeId(oldTaskId); + const oldTaskNodeId = nodeManager.getNodeId(oldTaskId, "task"); if (oldTaskNodeId in nodeIdMap) { - const newTaskId = nodeIdToTaskId(nodeIdMap[oldTaskNodeId]); + const newTaskId = nodeManager.getRefId(nodeIdMap[oldTaskNodeId]); + if (!newTaskId) { + return; + } + updatedOutputValue.taskOutput = { ...updatedOutputValue.taskOutput, taskId: newTaskId, @@ -278,10 +281,14 @@ export const duplicateNodes = ( return null; } - if (isTaskNode(originalNode)) { - const newTaskId = nodeIdToTaskId(newNodeId); + const newId = nodeManager.getRefId(newNodeId); - const newTaskSpec = updatedGraphSpec.tasks[newTaskId]; + if (!newId) { + return null; + } + + if (isTaskNode(originalNode)) { + const newTaskSpec = updatedGraphSpec.tasks[newId]; const nodeData: NodeData = { readOnly: originalNode.data.readOnly, @@ -289,9 +296,10 @@ export const duplicateNodes = ( callbacks: convertTaskCallbacksToNodeCallbacks( originalNode.data.callbacks, ), + nodeManager, }; - const newNode = createTaskNode([newTaskId, newTaskSpec], nodeData); + const newNode = createTaskNode([newId, newTaskSpec], nodeData); newNode.id = newNodeId; newNode.selected = false; @@ -308,9 +316,8 @@ export const duplicateNodes = ( return newNode; } else if (isInputNode(originalNode)) { - const newInputName = nodeIdToInputName(newNodeId); const newInputSpec = updatedInputs.find( - (input) => input.name === newInputName, + (input) => input.name === newId, ); if (!newInputSpec) { @@ -319,6 +326,7 @@ export const duplicateNodes = ( const nodeData: NodeData = { readOnly: originalNode.data.readOnly, + nodeManager, }; const newNode = createInputNode(newInputSpec, nodeData); @@ -338,9 +346,8 @@ export const duplicateNodes = ( return newNode; } else if (isOutputNode(originalNode)) { - const newOutputName = nodeIdToOutputName(newNodeId); const newOutputSpec = updatedOutputs.find( - (output) => output.name === newOutputName, + (output) => output.name === newId, ); if (!newOutputSpec) { @@ -349,6 +356,7 @@ export const duplicateNodes = ( const nodeData: NodeData = { readOnly: originalNode.data.readOnly, + nodeManager, }; const newNode = createOutputNode(newOutputSpec, nodeData); @@ -389,10 +397,14 @@ export const duplicateNodes = ( y: node.position.y + offset.y, }; - if (isTaskNode(node)) { - const taskId = nodeIdToTaskId(node.id); + const newId = nodeManager.getRefId(node.id); + + if (!newId) { + return null; + } - const taskSpec = node.data.taskSpec as TaskSpec; + if (isTaskNode(node)) { + const taskSpec = node.data.taskSpec; const annotations = taskSpec.annotations || {}; const updatedAnnotations = setPositionInAnnotations( @@ -405,13 +417,9 @@ export const duplicateNodes = ( annotations: updatedAnnotations, }; - updatedGraphSpec.tasks[taskId] = newTaskSpec; + updatedGraphSpec.tasks[newId] = newTaskSpec; } else if (isInputNode(node)) { - const newInputName = nodeIdToInputName(node.id); - - const inputSpec = updatedInputs.find( - (input) => input.name === newInputName, - ); + const inputSpec = updatedInputs.find((input) => input.name === newId); if (!inputSpec) { return; @@ -430,17 +438,15 @@ export const duplicateNodes = ( }; const updatedInputIndex = updatedInputs.findIndex( - (input) => input.name === newInputName, + (input) => input.name === newId, ); if (updatedInputIndex !== -1) { updatedInputs[updatedInputIndex] = newInputSpec; } } else if (isOutputNode(node)) { - const newOutputName = nodeIdToOutputName(node.id); - const outputSpec = updatedOutputs.find( - (output) => output.name === newOutputName, + (output) => output.name === newId, ); if (!outputSpec) { @@ -460,7 +466,7 @@ export const duplicateNodes = ( }; const updatedOutputIndex = updatedOutputs.findIndex( - (output) => output.name === newOutputName, + (output) => output.name === newId, ); if (updatedOutputIndex !== -1) { @@ -493,14 +499,15 @@ function reconfigureConnections( nodes: Node[], componentSpec: ComponentSpec, mode: ConnectionMode, + nodeManager: NodeManager, ) { - let oldNodeId = undefined; - let newArgId = undefined; + let oldNodeId: string | undefined = undefined; + let newArgId: string | undefined = undefined; let isExternal = false; if ("taskOutput" in argument) { const oldTaskId = argument.taskOutput.taskId; - oldNodeId = taskIdToNodeId(oldTaskId); + oldNodeId = nodeManager.getNodeId(oldTaskId, "task"); if (!isGraphImplementation(componentSpec.implementation)) { throw new Error("ComponentSpec does not contain a graph implementation."); @@ -515,12 +522,12 @@ function reconfigureConnections( return reconfigureExternalConnection(taskSpec, argKey, mode); } - const newTaskId = nodeIdToTaskId(newNodeId); + const newTaskId = nodeManager.getRefId(newNodeId); newArgId = newTaskId; } else if ("graphInput" in argument) { const oldInputName = argument.graphInput.inputName; - oldNodeId = inputNameToNodeId(oldInputName); + oldNodeId = nodeManager.getNodeId(oldInputName, "input"); if (!("inputs" in componentSpec)) { throw new Error("ComponentSpec does not contain inputs."); @@ -535,7 +542,7 @@ function reconfigureConnections( return reconfigureExternalConnection(taskSpec, argKey, mode); } - const newInputName = nodeIdToInputName(newNodeId); + const newInputName = nodeManager.getRefId(newNodeId); newArgId = newInputName; } diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/handleConnection.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/handleConnection.ts index fad0e0c97..e643ea778 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/handleConnection.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/handleConnection.ts @@ -1,15 +1,11 @@ import type { Connection } from "@xyflow/react"; +import type { NodeManager } from "@/nodeManager"; import type { GraphInputArgument, GraphSpec, TaskOutputArgument, } from "@/utils/componentSpec"; -import { - nodeIdToInputName, - nodeIdToOutputName, - nodeIdToTaskId, -} from "@/utils/nodes/nodeIdUtils"; import { setGraphOutputValue } from "./setGraphOutputValue"; import { setTaskArgument } from "./setTaskArgument"; @@ -17,44 +13,55 @@ import { setTaskArgument } from "./setTaskArgument"; export const handleConnection = ( graphSpec: GraphSpec, connection: Connection, + nodeManager: NodeManager, ) => { - const targetTaskInputName = connection.targetHandle?.replace(/^input_/, ""); - const sourceTaskOutputName = connection.sourceHandle?.replace(/^output_/, ""); + const sourceId = nodeManager.getRefId(connection.source); + const targetId = nodeManager.getRefId(connection.target); - if (sourceTaskOutputName !== undefined) { + const sourceHandleName = connection.sourceHandle + ? nodeManager.getHandleInfo(connection.sourceHandle)?.handleName + : undefined; + const targetHandleName = connection.targetHandle + ? nodeManager.getHandleInfo(connection.targetHandle)?.handleName + : undefined; + + // Previously sourceHandle & targetHandle were `undefined` for IO Nodes, but in the new NodeManager system the handles now have an id & name. + // Thus, if the handle name is the same as the input/output name, treat it as undefined. + const sourceTaskOutputName = + sourceId && sourceHandleName === sourceId ? undefined : sourceHandleName; + const targetTaskInputName = + targetId && targetHandleName === targetId ? undefined : targetHandleName; + + if (sourceTaskOutputName !== undefined && sourceId) { const taskOutputArgument: TaskOutputArgument = { taskOutput: { - taskId: nodeIdToTaskId(connection.source), + taskId: sourceId, outputName: sourceTaskOutputName, }, }; - if (targetTaskInputName !== undefined) { + if (targetTaskInputName !== undefined && targetId) { return setTaskArgument( graphSpec, - nodeIdToTaskId(connection.target), + targetId, targetTaskInputName, taskOutputArgument, ); - } else { - return setGraphOutputValue( - graphSpec, - nodeIdToOutputName(connection.target), - taskOutputArgument, - ); + } else if (targetId) { + return setGraphOutputValue(graphSpec, targetId, taskOutputArgument); // TODO: Perhaps propagate type information } - } else { - const graphInputName = nodeIdToInputName(connection.source); + } else if (sourceId) { + const graphInputName = sourceId; const graphInputArgument: GraphInputArgument = { graphInput: { inputName: graphInputName, }, }; - if (targetTaskInputName !== undefined) { + if (targetTaskInputName !== undefined && targetId) { return setTaskArgument( graphSpec, - nodeIdToTaskId(connection.target), + targetId, targetTaskInputName, graphInputArgument, ); diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/removeEdge.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/removeEdge.ts index 2aacc0289..d572f3163 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/removeEdge.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/removeEdge.ts @@ -1,46 +1,79 @@ import type { Edge } from "@xyflow/react"; -import type { ComponentSpec, GraphImplementation } from "@/utils/componentSpec"; +import type { NodeManager } from "@/nodeManager"; import { - nodeIdToInputName, - nodeIdToOutputName, - nodeIdToTaskId, -} from "@/utils/nodes/nodeIdUtils"; + type ComponentSpec, + isGraphImplementation, +} from "@/utils/componentSpec"; import { setGraphOutputValue } from "./setGraphOutputValue"; import { setTaskArgument } from "./setTaskArgument"; -export const removeEdge = (edge: Edge, componentSpec: ComponentSpec) => { - const graphSpec = (componentSpec.implementation as GraphImplementation) - ?.graph; +export const removeEdge = ( + edge: Edge, + componentSpec: ComponentSpec, + nodeManager: NodeManager, +) => { + if (!isGraphImplementation(componentSpec.implementation)) { + return componentSpec; + } + + const graphSpec = componentSpec.implementation.graph; + const updatedComponentSpec = { ...componentSpec }; - if (!edge.targetHandle) { + const targetNodeId = edge.target; + const targetTaskId = nodeManager.getRefId(targetNodeId); + const targetNodeType = nodeManager.getNodeType(targetNodeId); + + if (!targetTaskId || !targetNodeType) { + console.error("Could not resolve target node information:", { + targetNodeId, + targetTaskId, + targetNodeType, + }); return componentSpec; } - const inputName = nodeIdToInputName(edge.targetHandle); - - const updatedComponentSpec = { - ...componentSpec, - }; - - if (inputName !== undefined && graphSpec) { - const taskId = nodeIdToTaskId(edge.target); - const newGraphSpec = setTaskArgument(graphSpec, taskId, inputName); - updatedComponentSpec.implementation = { - ...updatedComponentSpec.implementation, - graph: newGraphSpec, - }; - - return updatedComponentSpec; - } else { - const outputName = nodeIdToOutputName(edge.target); - const newGraphSpec = setGraphOutputValue(graphSpec, outputName); - updatedComponentSpec.implementation = { - ...updatedComponentSpec.implementation, - graph: newGraphSpec, - }; - - return updatedComponentSpec; + switch (targetNodeType) { + case "task": { + if (!edge.targetHandle) { + console.error("No target handle found for task connection"); + return componentSpec; + } + + const targetHandleInfo = nodeManager.getHandleInfo(edge.targetHandle); + if (!targetHandleInfo) { + console.error("Could not resolve target handle info"); + return componentSpec; + } + + const inputName = targetHandleInfo.handleName; + const newGraphSpec = setTaskArgument(graphSpec, targetTaskId, inputName); + + updatedComponentSpec.implementation = { + ...updatedComponentSpec.implementation, + graph: newGraphSpec, + }; + break; + } + + case "output": { + const newGraphSpec = setGraphOutputValue(graphSpec, targetTaskId); + + updatedComponentSpec.implementation = { + ...updatedComponentSpec.implementation, + graph: newGraphSpec, + }; + break; + } + + default: + console.error( + "Unsupported target node type for edge removal:", + targetNodeType, + ); + return componentSpec; } + + return updatedComponentSpec; }; diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/removeNode.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/removeNode.ts index 923ea8744..481819a81 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/removeNode.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/removeNode.ts @@ -1,32 +1,33 @@ import { type Node } from "@xyflow/react"; +import type { NodeManager } from "@/nodeManager"; import { type ComponentSpec, isGraphImplementation, } from "@/utils/componentSpec"; -import { - nodeIdToInputName, - nodeIdToOutputName, - nodeIdToTaskId, -} from "@/utils/nodes/nodeIdUtils"; import { setGraphOutputValue } from "./setGraphOutputValue"; import { setTaskArgument } from "./setTaskArgument"; -export const removeNode = (node: Node, componentSpec: ComponentSpec) => { +export const removeNode = ( + node: Node, + componentSpec: ComponentSpec, + nodeManager: NodeManager, +) => { + const id = nodeManager.getRefId(node.id); + + if (!id) return componentSpec; + if (node.type === "task") { - const taskId = nodeIdToTaskId(node.id); - return removeTask(taskId, componentSpec); + return removeTask(id, componentSpec); } if (node.type === "input") { - const inputName = nodeIdToInputName(node.id); - return removeGraphInput(inputName, componentSpec); + return removeGraphInput(id, componentSpec); } if (node.type === "output") { - const outputName = nodeIdToOutputName(node.id); - return removeGraphOutput(outputName, componentSpec); + return removeGraphOutput(id, componentSpec); } return componentSpec; diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/updateNodePosition.test.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/updateNodePosition.test.ts index 8b193e848..cfaea11d7 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/updateNodePosition.test.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/updateNodePosition.test.ts @@ -1,11 +1,30 @@ import type { Node } from "@xyflow/react"; -import { describe, expect, test } from "vitest"; +import { beforeEach, describe, expect, test, vi } from "vitest"; -import type { ComponentSpec } from "@/utils/componentSpec"; +import type { NodeManager } from "@/nodeManager"; +import { + type ComponentSpec, + isGraphImplementation, +} from "@/utils/componentSpec"; import { updateNodePositions } from "./updateNodePosition"; describe("updateNodePositions", () => { + const mockNodeManager = { + getRefId: vi.fn(), + } as unknown as NodeManager; + + beforeEach(() => { + vi.clearAllMocks(); + + vi.mocked(mockNodeManager.getRefId).mockImplementation((nodeId: string) => { + if (nodeId === "task_123") return "123"; + if (nodeId === "input_test_input") return "test_input"; + if (nodeId === "output_test_output") return "test_output"; + return undefined; + }); + }); + test("should throw an error if implementation is not graph", () => { const nodes: Node[] = []; const componentSpec: ComponentSpec = { @@ -19,9 +38,9 @@ describe("updateNodePositions", () => { outputs: [], }; - expect(() => updateNodePositions(nodes, componentSpec)).toThrow( - "Component spec is not a graph", - ); + expect(() => + updateNodePositions(nodes, componentSpec, mockNodeManager), + ).toThrow("Component spec is not a graph"); }); test("should update task node positions", () => { @@ -48,21 +67,15 @@ describe("updateNodePositions", () => { }, }; - const result = updateNodePositions(nodes, componentSpec); + const result = updateNodePositions(nodes, componentSpec, mockNodeManager); + const graph = isGraphImplementation(result.implementation) + ? result.implementation.graph + : null; + expect(graph).not.toBeNull(); + if (!graph) return; - expect(result).toMatchObject({ - implementation: { - graph: { - tasks: { - "123": { - componentRef: {}, - annotations: { - "editor.position": JSON.stringify({ x: 100, y: 200 }), - }, - }, - }, - }, - }, + expect(graph.tasks["123"].annotations).toEqual({ + "editor.position": JSON.stringify({ x: 100, y: 200 }), }); }); @@ -88,18 +101,10 @@ describe("updateNodePositions", () => { ], }; - const result = updateNodePositions(nodes, componentSpec); + const result = updateNodePositions(nodes, componentSpec, mockNodeManager); - expect(result).toMatchObject({ - inputs: [ - { - name: "test_input", - type: "string", - annotations: { - "editor.position": JSON.stringify({ x: 50, y: 100 }), - }, - }, - ], + expect(result.inputs![0].annotations).toEqual({ + "editor.position": JSON.stringify({ x: 50, y: 100 }), }); }); @@ -125,18 +130,10 @@ describe("updateNodePositions", () => { ], }; - const result = updateNodePositions(nodes, componentSpec); + const result = updateNodePositions(nodes, componentSpec, mockNodeManager); - expect(result).toMatchObject({ - outputs: [ - { - name: "test_output", - type: "string", - annotations: { - "editor.position": JSON.stringify({ x: 300, y: 150 }), - }, - }, - ], + expect(result.outputs![0].annotations).toEqual({ + "editor.position": JSON.stringify({ x: 300, y: 150 }), }); }); @@ -190,35 +187,23 @@ describe("updateNodePositions", () => { ], }; - const result = updateNodePositions(nodes, componentSpec); + const result = updateNodePositions(nodes, componentSpec, mockNodeManager); + const graph = isGraphImplementation(result.implementation) + ? result.implementation.graph + : null; + expect(graph).not.toBeNull(); + if (!graph) return; - expect(result).toMatchObject({ - implementation: { - graph: { - tasks: { - "123": expect.objectContaining({ - componentRef: {}, - annotations: { - "editor.position": JSON.stringify({ x: 100, y: 200 }), - }, - }), - }, - }, - }, - inputs: [ - expect.objectContaining({ - annotations: { - "editor.position": JSON.stringify({ x: 50, y: 100 }), - }, - }), - ], - outputs: [ - expect.objectContaining({ - annotations: { - "editor.position": JSON.stringify({ x: 300, y: 150 }), - }, - }), - ], + expect(graph.tasks["123"].annotations).toEqual({ + "editor.position": JSON.stringify({ x: 100, y: 200 }), + }); + + expect(result.inputs![0].annotations).toEqual({ + "editor.position": JSON.stringify({ x: 50, y: 100 }), + }); + + expect(result.outputs![0].annotations).toEqual({ + "editor.position": JSON.stringify({ x: 300, y: 150 }), }); }); @@ -248,22 +233,43 @@ describe("updateNodePositions", () => { }, }; - const result = updateNodePositions(nodes, componentSpec); + const result = updateNodePositions(nodes, componentSpec, mockNodeManager); + + const graph = isGraphImplementation(result.implementation) + ? result.implementation.graph + : null; + expect(graph).not.toBeNull(); + if (!graph) return; + + expect(graph.tasks["123"].annotations).toEqual({ + "existing.annotation": "value", + "editor.position": JSON.stringify({ x: 100, y: 200 }), + }); + }); + + test("should skip nodes with no ref ID", () => { + vi.mocked(mockNodeManager.getRefId).mockReturnValue(undefined); + + const nodes: Node[] = [ + { + id: "unknown_node", + type: "task", + position: { x: 100, y: 200 }, + data: {}, + }, + ]; - expect(result).toMatchObject({ + const componentSpec: ComponentSpec = { + name: "test", implementation: { graph: { - tasks: { - "123": { - componentRef: {}, - annotations: { - "existing.annotation": "value", - "editor.position": JSON.stringify({ x: 100, y: 200 }), - }, - }, - }, + tasks: {}, }, }, - }); + }; + + const result = updateNodePositions(nodes, componentSpec, mockNodeManager); + + expect(result).toEqual(componentSpec); }); }); diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/updateNodePosition.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/updateNodePosition.ts index 9d1829176..dfcce4606 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/updateNodePosition.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/updateNodePosition.ts @@ -1,19 +1,16 @@ import type { Node } from "@xyflow/react"; +import type { NodeManager } from "@/nodeManager"; import { type ComponentSpec, isGraphImplementation, } from "@/utils/componentSpec"; -import { - nodeIdToInputName, - nodeIdToOutputName, - nodeIdToTaskId, -} from "@/utils/nodes/nodeIdUtils"; import { setPositionInAnnotations } from "@/utils/nodes/setPositionInAnnotations"; export const updateNodePositions = ( updatedNodes: Node[], componentSpec: ComponentSpec, + nodeManager: NodeManager, ) => { const newComponentSpec = { ...componentSpec }; @@ -31,10 +28,13 @@ export const updateNodePositions = ( y: node.position.y, }; + const refId = nodeManager.getRefId(node.id); + + if (!refId) continue; + if (node.type === "task") { - const taskId = nodeIdToTaskId(node.id); - if (updatedGraphSpec.tasks[taskId]) { - const taskSpec = { ...updatedGraphSpec.tasks[taskId] }; + if (updatedGraphSpec.tasks[refId]) { + const taskSpec = { ...updatedGraphSpec.tasks[refId] }; const annotations = taskSpec.annotations || {}; @@ -48,14 +48,13 @@ export const updateNodePositions = ( annotations: updatedAnnotations, }; - updatedGraphSpec.tasks[taskId] = newTaskSpec; + updatedGraphSpec.tasks[refId] = newTaskSpec; newComponentSpec.implementation.graph = updatedGraphSpec; } } else if (node.type === "input") { - const inputName = nodeIdToInputName(node.id); const inputs = [...(newComponentSpec.inputs || [])]; - const inputIndex = inputs.findIndex((input) => input.name === inputName); + const inputIndex = inputs.findIndex((input) => input.name === refId); if (inputIndex >= 0) { const annotations = inputs[inputIndex].annotations || {}; @@ -73,11 +72,8 @@ export const updateNodePositions = ( newComponentSpec.inputs = inputs; } } else if (node.type === "output") { - const outputName = nodeIdToOutputName(node.id); const outputs = [...(newComponentSpec.outputs || [])]; - const outputIndex = outputs.findIndex( - (output) => output.name === outputName, - ); + const outputIndex = outputs.findIndex((output) => output.name === refId); if (outputIndex >= 0) { const annotations = outputs[outputIndex].annotations || {}; diff --git a/src/hooks/useComponentSpecToEdges.test.ts b/src/hooks/useComponentSpecToEdges.test.ts index 14be3ac3a..9a6d18bc1 100644 --- a/src/hooks/useComponentSpecToEdges.test.ts +++ b/src/hooks/useComponentSpecToEdges.test.ts @@ -1,11 +1,19 @@ -import { renderHook } from "@testing-library/react"; import { MarkerType } from "@xyflow/react"; -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; + +import type { NodeManager } from "@/nodeManager"; import type { ComponentSpec } from "../utils/componentSpec"; -import useComponentSpecToEdges from "./useComponentSpecToEdges"; +import { getEdges } from "./useComponentSpecToEdges"; + +describe("getEdges", () => { + const mockNodeManager: NodeManager = { + getNodeId: vi.fn((refId: string, type: string) => `${type}_${refId}`), + getHandleNodeId: vi.fn( + (_refId: string, handleName: string) => `${handleName}`, + ), + } as any; -describe("useComponentSpecToEdges", () => { const createBasicComponentSpec = (implementation: any): ComponentSpec => ({ name: "Test Component", implementation, @@ -18,9 +26,9 @@ describe("useComponentSpecToEdges", () => { container: { image: "test" }, }); - const { result } = renderHook(() => useComponentSpecToEdges(componentSpec)); + const result = getEdges(componentSpec, mockNodeManager); - expect(result.current.edges).toEqual([]); + expect(result).toEqual([]); }); it("creates task edges correctly", () => { @@ -40,14 +48,14 @@ describe("useComponentSpecToEdges", () => { }, }); - const { result } = renderHook(() => useComponentSpecToEdges(componentSpec)); + const result = getEdges(componentSpec, mockNodeManager); - expect(result.current.edges).toContainEqual({ + expect(result).toContainEqual({ id: "task2_output1-task1_input1", source: "task_task2", - sourceHandle: "output_output1", + sourceHandle: "output1", target: "task_task1", - targetHandle: "input_input1", + targetHandle: "input1", markerEnd: { type: MarkerType.Arrow }, type: "customEdge", }); @@ -68,14 +76,14 @@ describe("useComponentSpecToEdges", () => { }, }); - const { result } = renderHook(() => useComponentSpecToEdges(componentSpec)); + const result = getEdges(componentSpec, mockNodeManager); - expect(result.current.edges).toContainEqual({ + expect(result).toContainEqual({ id: "Input_graphInput1-task1_input1", source: "input_graphInput1", sourceHandle: null, target: "task_task1", - targetHandle: "input_input1", + targetHandle: "input1", markerEnd: { type: MarkerType.Arrow }, type: "customEdge", }); @@ -98,104 +106,12 @@ describe("useComponentSpecToEdges", () => { outputs: [], }; - const { result } = renderHook(() => useComponentSpecToEdges(componentSpec)); + const result = getEdges(componentSpec, mockNodeManager); - expect(result.current.edges).toContainEqual({ + expect(result).toContainEqual({ id: "task1_output1-Output_graphOutput1", source: "task_task1", - sourceHandle: "output_output1", - target: "output_graphOutput1", - targetHandle: null, - markerEnd: { type: MarkerType.Arrow }, - type: "customEdge", - }); - }); - - it("handles string arguments by returning no edges", () => { - const componentSpec: ComponentSpec = { - name: "Test Component", - implementation: { - graph: { - tasks: { - task1: { - componentRef: {}, - arguments: { - input1: "string value", - }, - }, - }, - outputValues: {}, - }, - }, - inputs: [], - outputs: [], - }; - - const { result } = renderHook(() => useComponentSpecToEdges(componentSpec)); - expect(result.current.edges).toEqual([]); - }); - - it("handles complex component specs with multiple edge types", () => { - const componentSpec: ComponentSpec = { - name: "Test Component", - implementation: { - graph: { - tasks: { - task1: { - componentRef: {}, - arguments: { - input1: { graphInput: { inputName: "graphInput1" } }, - input2: "static value", - }, - }, - task2: { - componentRef: {}, - arguments: { - input1: { - taskOutput: { taskId: "task1", outputName: "output1" }, - }, - }, - }, - }, - outputValues: { - graphOutput1: { - taskOutput: { taskId: "task2", outputName: "output1" }, - }, - }, - }, - }, - inputs: [], - outputs: [], - }; - - const { result } = renderHook(() => useComponentSpecToEdges(componentSpec)); - - expect(result.current.edges).toHaveLength(3); - - expect(result.current.edges).toContainEqual({ - id: "Input_graphInput1-task1_input1", - source: "input_graphInput1", - sourceHandle: null, - target: "task_task1", - targetHandle: "input_input1", - markerEnd: { type: MarkerType.Arrow }, - type: "customEdge", - }); - - expect(result.current.edges).toContainEqual({ - id: "task1_output1-task2_input1", - source: "task_task1", - sourceHandle: "output_output1", - target: "task_task2", - targetHandle: "input_input1", - markerEnd: { type: MarkerType.Arrow }, - type: "customEdge", - }); - - expect(result.current.edges).toContainEqual({ - id: "task2_output1-Output_graphOutput1", - source: "task_task2", - sourceHandle: "output_output1", + sourceHandle: "output1", target: "output_graphOutput1", targetHandle: null, markerEnd: { type: MarkerType.Arrow }, diff --git a/src/hooks/useComponentSpecToEdges.ts b/src/hooks/useComponentSpecToEdges.ts index e3bad6e17..c385d3091 100644 --- a/src/hooks/useComponentSpecToEdges.ts +++ b/src/hooks/useComponentSpecToEdges.ts @@ -6,6 +6,7 @@ import { } from "@xyflow/react"; import { useEffect } from "react"; +import { NodeManager } from "@/nodeManager"; import { type ArgumentType, type ComponentSpec, @@ -15,11 +16,8 @@ import { type TaskOutputArgument, type TaskSpec, } from "@/utils/componentSpec"; -import { - inputNameToNodeId, - outputNameToNodeId, - taskIdToNodeId, -} from "@/utils/nodes/nodeIdUtils"; + +import { useNodeManager } from "./useNodeManager"; const useComponentSpecToEdges = ( componentSpec: ComponentSpec, @@ -27,14 +25,15 @@ const useComponentSpecToEdges = ( edges: Edge[]; onEdgesChange: (changes: EdgeChange[]) => void; } => { + const { nodeManager } = useNodeManager(); const [flowEdges, setFlowEdges, onFlowEdgesChange] = useEdgesState( - getEdges(componentSpec), + getEdges(componentSpec, nodeManager), ); useEffect(() => { - const newEdges = getEdges(componentSpec); + const newEdges = getEdges(componentSpec, nodeManager); setFlowEdges(newEdges); - }, [componentSpec]); + }, [componentSpec, nodeManager]); return { edges: flowEdges, @@ -42,28 +41,38 @@ const useComponentSpecToEdges = ( }; }; -const getEdges = (componentSpec: ComponentSpec) => { +export const getEdges = ( + componentSpec: ComponentSpec, + nodeManager: NodeManager, +) => { if (!isGraphImplementation(componentSpec.implementation)) { return []; } const graphSpec = componentSpec.implementation.graph; - const taskEdges = createEdgesFromTaskSpec(graphSpec); - const outputEdges = createOutputEdgesFromGraphSpec(graphSpec); + const taskEdges = createEdgesFromTaskSpec(graphSpec, nodeManager); + const outputEdges = createOutputEdgesFromGraphSpec(graphSpec, nodeManager); return [...taskEdges, ...outputEdges]; }; -const createEdgesFromTaskSpec = (graphSpec: GraphSpec) => { +const createEdgesFromTaskSpec = ( + graphSpec: GraphSpec, + nodeManager: NodeManager, +) => { const edges: Edge[] = Object.entries(graphSpec.tasks).flatMap( - ([taskId, taskSpec]) => createEdgesForTask(taskId, taskSpec), + ([taskId, taskSpec]) => createEdgesForTask(taskId, taskSpec, nodeManager), ); return edges; }; -const createEdgesForTask = (taskId: string, taskSpec: TaskSpec): Edge[] => { +const createEdgesForTask = ( + taskId: string, + taskSpec: TaskSpec, + nodeManager: NodeManager, +): Edge[] => { return Object.entries(taskSpec.arguments ?? {}).flatMap( ([inputName, argument]) => - createEdgeForArgument(taskId, inputName, argument), + createEdgeForArgument(taskId, inputName, argument, nodeManager), ); }; @@ -71,17 +80,22 @@ const createEdgeForArgument = ( taskId: string, inputName: string, argument: ArgumentType, + nodeManager: NodeManager, ): Edge[] => { if (typeof argument === "string") { return []; } if ("taskOutput" in argument) { - return [createTaskOutputEdge(taskId, inputName, argument.taskOutput)]; + return [ + createTaskOutputEdge(taskId, inputName, argument.taskOutput, nodeManager), + ]; } if ("graphInput" in argument) { - return [createGraphInputEdge(taskId, inputName, argument.graphInput)]; + return [ + createGraphInputEdge(taskId, inputName, argument.graphInput, nodeManager), + ]; } console.error("Impossible task input argument kind: ", argument); @@ -92,13 +106,29 @@ const createTaskOutputEdge = ( taskId: string, inputName: string, taskOutput: TaskOutputArgument["taskOutput"], + nodeManager: NodeManager, ): Edge => { + const sourceNodeId = nodeManager.getNodeId(taskOutput.taskId, "task"); + const targetNodeId = nodeManager.getNodeId(taskId, "task"); + + const sourceHandleNodeId = nodeManager.getHandleNodeId( + taskOutput.taskId, + taskOutput.outputName, + "handle-out", + ); + + const targetHandleNodeId = nodeManager.getHandleNodeId( + taskId, + inputName, + "handle-in", + ); + return { id: `${taskOutput.taskId}_${taskOutput.outputName}-${taskId}_${inputName}`, - source: taskIdToNodeId(taskOutput.taskId), - sourceHandle: outputNameToNodeId(taskOutput.outputName), - target: taskIdToNodeId(taskId), - targetHandle: inputNameToNodeId(inputName), + source: sourceNodeId, + sourceHandle: sourceHandleNodeId, + target: targetNodeId, + targetHandle: targetHandleNodeId, markerEnd: { type: MarkerType.Arrow }, type: "customEdge", }; @@ -108,27 +138,50 @@ const createGraphInputEdge = ( taskId: string, inputName: string, graphInput: GraphInputArgument["graphInput"], + nodeManager: NodeManager, ): Edge => { + const sourceNodeId = nodeManager.getNodeId(graphInput.inputName, "input"); + const targetNodeId = nodeManager.getNodeId(taskId, "task"); + + const targetHandleNodeId = nodeManager.getHandleNodeId( + taskId, + inputName, + "handle-in", + ); + return { id: `Input_${graphInput.inputName}-${taskId}_${inputName}`, - source: inputNameToNodeId(graphInput.inputName), + source: sourceNodeId, sourceHandle: null, - target: taskIdToNodeId(taskId), - targetHandle: inputNameToNodeId(inputName), + target: targetNodeId, + targetHandle: targetHandleNodeId, type: "customEdge", markerEnd: { type: MarkerType.Arrow }, }; }; -const createOutputEdgesFromGraphSpec = (graphSpec: GraphSpec) => { +const createOutputEdgesFromGraphSpec = ( + graphSpec: GraphSpec, + nodeManager: NodeManager, +) => { const outputEdges: Edge[] = Object.entries(graphSpec.outputValues ?? {}).map( ([outputName, argument]) => { const taskOutput = argument.taskOutput; + + const sourceNodeId = nodeManager.getNodeId(taskOutput.taskId, "task"); + const targetNodeId = nodeManager.getNodeId(outputName, "output"); + + const sourceHandleNodeId = nodeManager.getHandleNodeId( + taskOutput.taskId, + taskOutput.outputName, + "handle-out", + ); + const edge: Edge = { id: `${taskOutput.taskId}_${taskOutput.outputName}-Output_${outputName}`, - source: taskIdToNodeId(taskOutput.taskId), - sourceHandle: outputNameToNodeId(taskOutput.outputName), - target: outputNameToNodeId(outputName), + source: sourceNodeId, + sourceHandle: sourceHandleNodeId, + target: targetNodeId, targetHandle: null, type: "customEdge", markerEnd: { type: MarkerType.Arrow }, diff --git a/src/hooks/useNodeCallbacks.ts b/src/hooks/useNodeCallbacks.ts index 9fc473795..bf8ca899c 100644 --- a/src/hooks/useNodeCallbacks.ts +++ b/src/hooks/useNodeCallbacks.ts @@ -42,6 +42,7 @@ export const useNodeCallbacks = ({ updateGraphSpec, componentSpec, setComponentSpec, + nodeManager, } = useComponentSpec(); // Workaround for nodes state being stale in task node callbacks @@ -182,7 +183,9 @@ export const useNodeCallbacks = ({ updatedComponentSpec: updatedSubgraphSpec, newNodes, updatedNodes, - } = duplicateNodes(currentSubgraphSpec, [node], { selected }); + } = duplicateNodes(currentSubgraphSpec, [node], nodeManager, { + selected, + }); const updatedRootSpec = updateSubgraphSpec( componentSpec, @@ -201,6 +204,7 @@ export const useNodeCallbacks = ({ componentSpec, currentSubgraphSpec, currentSubgraphPath, + nodeManager, getNodeById, setComponentSpec, updateOrAddNodes, diff --git a/src/providers/TaskNodeProvider.tsx b/src/providers/TaskNodeProvider.tsx index 114720082..45f8c0de2 100644 --- a/src/providers/TaskNodeProvider.tsx +++ b/src/providers/TaskNodeProvider.tsx @@ -3,6 +3,7 @@ import { type ReactNode, useCallback, useMemo } from "react"; import type { ContainerExecutionStatus } from "@/api/types.gen"; import useComponentFromUrl from "@/hooks/useComponentFromUrl"; +import { useNodeManager } from "@/hooks/useNodeManager"; import { useTaskNodeDimensions } from "@/hooks/useTaskNodeDimensions"; import useToastNotification from "@/hooks/useToastNotification"; import type { Annotations } from "@/types/annotations"; @@ -15,7 +16,6 @@ import type { TaskSpec, } from "@/utils/componentSpec"; import { getComponentName } from "@/utils/getComponentName"; -import { taskIdToNodeId } from "@/utils/nodes/nodeIdUtils"; import { createRequiredContext, @@ -72,10 +72,11 @@ export const TaskNodeProvider = ({ }: TaskNodeProviderProps) => { const notify = useToastNotification(); const reactFlowInstance = useReactFlow(); + const { getTaskNodeId } = useNodeManager(); const taskSpec = data.taskSpec; const taskId = data.taskId; - const nodeId = taskId ? taskIdToNodeId(taskId) : ""; + const nodeId = getTaskNodeId(taskId); const componentRef = taskSpec?.componentRef || {}; const inputs = componentRef.spec?.inputs || []; diff --git a/src/types/nodes.ts b/src/types/nodes.ts index f0c23858f..844543dc2 100644 --- a/src/types/nodes.ts +++ b/src/types/nodes.ts @@ -1,5 +1,6 @@ import type { Node } from "@xyflow/react"; +import type { NodeManager } from "@/nodeManager"; import type { ArgumentType, ComponentReference, @@ -28,6 +29,7 @@ export interface NodeData extends Record { readOnly: boolean; connectable?: boolean; callbacks?: NodeCallbacks; + nodeManager: NodeManager; } export interface TaskNodeData extends Record { diff --git a/src/utils/nodes/createInputNode.ts b/src/utils/nodes/createInputNode.ts index 6193a664c..433c70ae7 100644 --- a/src/utils/nodes/createInputNode.ts +++ b/src/utils/nodes/createInputNode.ts @@ -4,14 +4,14 @@ import type { IONodeData, NodeData } from "@/types/nodes"; import type { InputSpec } from "../componentSpec"; import { extractPositionFromAnnotations } from "./extractPositionFromAnnotations"; -import { inputNameToNodeId } from "./nodeIdUtils"; export const createInputNode = (input: InputSpec, nodeData: NodeData) => { const { name, annotations } = input; - const { readOnly } = nodeData; + const { nodeManager, readOnly } = nodeData; + + const nodeId = nodeManager.getNodeId(name, "input"); const position = extractPositionFromAnnotations(annotations); - const nodeId = inputNameToNodeId(name); const inputNodeData: IONodeData = { spec: input, diff --git a/src/utils/nodes/createNodesFromComponentSpec.test.ts b/src/utils/nodes/createNodesFromComponentSpec.test.ts index c1f75f2e3..000887b23 100644 --- a/src/utils/nodes/createNodesFromComponentSpec.test.ts +++ b/src/utils/nodes/createNodesFromComponentSpec.test.ts @@ -1,5 +1,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { NodeManager } from "@/nodeManager"; + import type { ComponentSpec } from "../componentSpec"; import { isGraphImplementation } from "../componentSpec"; import createNodesFromComponentSpec from "./createNodesFromComponentSpec"; @@ -21,15 +23,26 @@ describe("createNodesFromComponentSpec", () => { onUpgrade: vi.fn(), }; + const mockNodeManager = { + getNodeId: vi.fn((refId: string, type: string) => `${type}_${refId}`), + getHandleNodeId: vi.fn(), + getHandleInfo: vi.fn(), + getNodeType: vi.fn(), + getRefId: vi.fn(), + updateRefId: vi.fn(), + } as unknown as NodeManager; + const readOnly = false; const mockNodeData = { readOnly, - nodeCallbacks: mockNodeCallbacks, + connectable: true, + callbacks: mockNodeCallbacks, + nodeManager: mockNodeManager, }; beforeEach(() => { - mockNodeCallbacks.setArguments.mockClear(); + vi.clearAllMocks(); }); it("returns empty array for non-graph implementations", () => { diff --git a/src/utils/nodes/createOutputNode.ts b/src/utils/nodes/createOutputNode.ts index 6fee2ee03..aff2e13d1 100644 --- a/src/utils/nodes/createOutputNode.ts +++ b/src/utils/nodes/createOutputNode.ts @@ -4,14 +4,14 @@ import type { IONodeData, NodeData } from "@/types/nodes"; import type { OutputSpec } from "../componentSpec"; import { extractPositionFromAnnotations } from "./extractPositionFromAnnotations"; -import { outputNameToNodeId } from "./nodeIdUtils"; export const createOutputNode = (output: OutputSpec, nodeData: NodeData) => { const { name, annotations } = output; - const { readOnly } = nodeData; + const { nodeManager, readOnly } = nodeData; + + const nodeId = nodeManager.getNodeId(name, "output"); const position = extractPositionFromAnnotations(annotations); - const nodeId = outputNameToNodeId(name); const outputNodeData: IONodeData = { spec: output, diff --git a/src/utils/nodes/createTaskNode.ts b/src/utils/nodes/createTaskNode.ts index 69246709f..27f410f88 100644 --- a/src/utils/nodes/createTaskNode.ts +++ b/src/utils/nodes/createTaskNode.ts @@ -4,7 +4,6 @@ import type { NodeData, TaskNodeData } from "@/types/nodes"; import type { TaskSpec } from "../componentSpec"; import { extractPositionFromAnnotations } from "./extractPositionFromAnnotations"; -import { taskIdToNodeId } from "./nodeIdUtils"; import { convertNodeCallbacksToTaskCallbacks } from "./taskCallbackUtils"; export const createTaskNode = ( @@ -12,10 +11,11 @@ export const createTaskNode = ( nodeData: NodeData, ) => { const [taskId, taskSpec] = task; - const { callbacks, connectable = true, ...data } = nodeData; + const { nodeManager, callbacks, connectable = true, ...data } = nodeData; + + const nodeId = nodeManager.getNodeId(taskId, "task"); const position = extractPositionFromAnnotations(taskSpec.annotations); - const nodeId = taskIdToNodeId(taskId); // Inject the taskId and nodeId into the callbacks const taskCallbacks = convertNodeCallbacksToTaskCallbacks( diff --git a/src/utils/nodes/nodeIdUtils.test.ts b/src/utils/nodes/nodeIdUtils.test.ts deleted file mode 100644 index dda2e8e06..000000000 --- a/src/utils/nodes/nodeIdUtils.test.ts +++ /dev/null @@ -1,72 +0,0 @@ -import { describe, expect, it } from "vitest"; - -import { - inputNameToNodeId, - nodeIdToInputName, - nodeIdToOutputName, - nodeIdToTaskId, - outputNameToNodeId, - taskIdToNodeId, -} from "./nodeIdUtils"; - -describe("nodeIdUtils", () => { - describe("nodeIdToTaskId", () => { - it('should extract task ID by removing the "task_" prefix', () => { - expect(nodeIdToTaskId("task_123")).toBe("123"); - expect(nodeIdToTaskId("task_abc")).toBe("abc"); - expect(nodeIdToTaskId("task_")).toBe(""); - }); - - it("should return the original string if no prefix exists", () => { - expect(nodeIdToTaskId("123")).toBe("123"); - }); - }); - - describe("nodeIdToInputName", () => { - it('should extract input name by removing the "input_" prefix', () => { - expect(nodeIdToInputName("input_name")).toBe("name"); - expect(nodeIdToInputName("input_data")).toBe("data"); - expect(nodeIdToInputName("input_")).toBe(""); - }); - - it("should return the original string if no prefix exists", () => { - expect(nodeIdToInputName("name")).toBe("name"); - }); - }); - - describe("nodeIdToOutputName", () => { - it('should extract output name by removing the "output_" prefix', () => { - expect(nodeIdToOutputName("output_result")).toBe("result"); - expect(nodeIdToOutputName("output_data")).toBe("data"); - expect(nodeIdToOutputName("output_")).toBe(""); - }); - - it("should return the original string if no prefix exists", () => { - expect(nodeIdToOutputName("result")).toBe("result"); - }); - }); - - describe("taskIdToNodeId", () => { - it('should create a task node ID by adding the "task_" prefix', () => { - expect(taskIdToNodeId("123")).toBe("task_123"); - expect(taskIdToNodeId("abc")).toBe("task_abc"); - expect(taskIdToNodeId("")).toBe("task_"); - }); - }); - - describe("inputNameToNodeId", () => { - it('should create an input node ID by adding the "input_" prefix', () => { - expect(inputNameToNodeId("name")).toBe("input_name"); - expect(inputNameToNodeId("data")).toBe("input_data"); - expect(inputNameToNodeId("")).toBe("input_"); - }); - }); - - describe("outputNameToNodeId", () => { - it('should create an output node ID by adding the "output_" prefix', () => { - expect(outputNameToNodeId("result")).toBe("output_result"); - expect(outputNameToNodeId("data")).toBe("output_data"); - expect(outputNameToNodeId("")).toBe("output_"); - }); - }); -}); diff --git a/src/utils/nodes/nodeIdUtils.ts b/src/utils/nodes/nodeIdUtils.ts deleted file mode 100644 index 2b2b025cb..000000000 --- a/src/utils/nodes/nodeIdUtils.ts +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Utility functions for converting between node IDs and their corresponding names/identifiers - */ - -/** - * Extracts the task ID from a task node ID by removing the "task_" prefix - */ -export const nodeIdToTaskId = (id: string) => id.replace(/^task_/, ""); - -/** - * Extracts the input name from an input node ID by removing the "input_" prefix - */ -export const nodeIdToInputName = (id: string) => id.replace(/^input_/, ""); - -/** - * Extracts the output name from an output node ID by removing the "output_" prefix - */ -export const nodeIdToOutputName = (id: string) => id.replace(/^output_/, ""); - -/** - * Creates a task node ID by adding the "task_" prefix to a task ID - */ -export const taskIdToNodeId = (taskId: string) => `task_${taskId}`; - -/** - * Creates an input node ID by adding the "input_" prefix to an input name - */ -export const inputNameToNodeId = (inputName: string) => `input_${inputName}`; - -/** - * Creates an output node ID by adding the "output_" prefix to an output name - */ -export const outputNameToNodeId = (outputName: string) => - `output_${outputName}`;