diff --git a/src/components/shared/ReactFlow/FlowCanvas/IONode/IONode.tsx b/src/components/shared/ReactFlow/FlowCanvas/IONode/IONode.tsx index f12843b50..d904274be 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/IONode/IONode.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/IONode/IONode.tsx @@ -175,6 +175,7 @@ const IONode = ({ type, data, selected = false }: IONodeProps) => { { - const { nodeId, state, name } = useTaskNode(); + const { getTaskInputNodeId } = useNodeManager(); + const { taskId, nodeId, state, name } = useTaskNode(); const fromHandle = useConnection((connection) => connection.fromHandle?.id); const toHandle = useConnection((connection) => connection.toHandle?.id); @@ -45,7 +47,7 @@ export const InputHandle = ({ const [selected, setSelected] = useState(false); const [active, setActive] = useState(false); - const handleId = getInputHandleId(input.name); + const handleId = getTaskInputNodeId(taskId, input.name); const missing = invalid ? "bg-red-700!" : "bg-gray-500!"; const hasValue = value !== undefined && value !== null; @@ -228,7 +230,8 @@ export const OutputHandle = ({ onLabelClick, onHandleSelectionChange, }: OutputHandleProps) => { - const { nodeId, state, name } = useTaskNode(); + const { getTaskOutputNodeId } = useNodeManager(); + const { taskId, nodeId, state, name } = useTaskNode(); const fromHandle = useConnection((connection) => connection.fromHandle?.id); const toHandle = useConnection((connection) => connection.toHandle?.id); @@ -240,7 +243,7 @@ export const OutputHandle = ({ const [selected, setSelected] = useState(false); const [active, setActive] = useState(false); - const handleId = getOutputHandleId(output.name); + const handleId = getTaskOutputNodeId(taskId, output.name); const hasValue = value !== undefined && value !== "" && value !== null; const handleHandleClick = useCallback( @@ -374,14 +377,6 @@ export const OutputHandle = ({ ); }; -const getOutputHandleId = (outputName: string) => { - return `output_${outputName}`; -}; - -const getInputHandleId = (inputName: string) => { - return `input_${inputName}`; -}; - const skipHandleDeselect = (e: MouseEvent) => { let el = e.target as HTMLElement | null; while (el) { diff --git a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.tsx b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.tsx index ce8415c99..8677588ab 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeInputs.tsx @@ -11,7 +11,6 @@ import { useTaskNode } from "@/providers/TaskNodeProvider"; import { inputsWithInvalidArguments } from "@/services/componentService"; import type { InputSpec } from "@/utils/componentSpec"; import { ComponentSearchFilter } from "@/utils/constants"; -import { inputNameToInputId } from "@/utils/nodes/conversions"; import { checkArtifactMatchesSearchFilters } from "@/utils/searchUtils"; import { InputHandle } from "./Handles"; @@ -28,8 +27,8 @@ export function TaskNodeInputs({ expanded, onBackgroundClick, }: TaskNodeInputsProps) { - const { getInputNodeId } = useNodeManager(); - const { inputs, taskSpec, state, select } = useTaskNode(); + const { getTaskInputNodeId } = useNodeManager(); + const { taskId, inputs, taskSpec, state, select } = useTaskNode(); const { graphSpec } = useComponentSpec(); const { highlightSearchFilter, @@ -147,7 +146,7 @@ export function TaskNodeInputs({ } const input = inputs.find( - (i) => getInputNodeId(inputNameToInputId(i.name)) === fromHandle?.id, + (i) => getTaskInputNodeId(taskId, i.name) === fromHandle?.id, ); if (!input) return; diff --git a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeOutputs.tsx b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeOutputs.tsx index 45e1abd33..63b7c2580 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeOutputs.tsx +++ b/src/components/shared/ReactFlow/FlowCanvas/TaskNode/TaskNodeCard/TaskNodeOutputs.tsx @@ -8,7 +8,6 @@ import { isValidFilterRequest } from "@/providers/ComponentLibraryProvider/types import { useTaskNode } from "@/providers/TaskNodeProvider"; import type { OutputSpec } from "@/utils/componentSpec"; import { ComponentSearchFilter } from "@/utils/constants"; -import { outputNameToOutputId } from "@/utils/nodes/conversions"; import { checkArtifactMatchesSearchFilters } from "@/utils/searchUtils"; import { OutputHandle } from "./Handles"; @@ -24,8 +23,8 @@ export function TaskNodeOutputs({ expanded, onBackgroundClick, }: TaskNodeOutputsProps) { - const { getOutputNodeId } = useNodeManager(); - const { nodeId, outputs, state, select } = useTaskNode(); + const { getTaskOutputNodeId } = useNodeManager(); + const { taskId, nodeId, outputs, state, select } = useTaskNode(); const { highlightSearchFilter, resetSearchFilter, @@ -42,8 +41,7 @@ export function TaskNodeOutputs({ edges.some( (edge) => edge.source === nodeId && - edge.sourceHandle === - getOutputNodeId(outputNameToOutputId(output.name)), + edge.sourceHandle === getTaskOutputNodeId(taskId, output.name), ), ); @@ -141,7 +139,7 @@ export function TaskNodeOutputs({ } const output = outputs.find( - (o) => getOutputNodeId(outputNameToOutputId(o.name)) === fromHandle?.id, + (o) => getTaskOutputNodeId(taskId, o.name) === fromHandle?.id, ); if (!output) return; diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/addAndConnectNode.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/addAndConnectNode.ts index 8e2655a98..16ed78a85 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/addAndConnectNode.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/addAndConnectNode.ts @@ -10,8 +10,8 @@ import { } from "@/utils/componentSpec"; import { DEFAULT_NODE_DIMENSIONS } from "@/utils/constants"; import { - inputNameToInputId, - outputNameToOutputId, + inputIdToInputName, + outputIdToOutputName, } from "@/utils/nodes/conversions"; import addTask from "./addTask"; @@ -44,8 +44,58 @@ export function addAndConnectNode({ const oldGraphSpec = componentSpec.implementation.graph; - const fromHandleId = fromHandle?.id; - const fromHandleType = fromHandleId?.startsWith("input") ? "input" : "output"; + if (!fromHandle?.id) { + return componentSpec; + } + + const isTaskHandle = nodeManager.isManaged(fromHandle.id); + let fromHandleType: "input" | "output"; + let fromHandleName: string | undefined; + let fromTaskId: string | undefined; + + if (isTaskHandle) { + // Handle is managed by NodeManager (task handle) + const fromHandleInfo = nodeManager.getHandleInfo(fromHandle.id); + const fromNodeType = nodeManager.getNodeType(fromHandle.id); + + if (!fromHandleInfo || !fromNodeType) { + return componentSpec; + } + + fromHandleType = fromNodeType === "taskInput" ? "input" : "output"; + fromHandleName = fromHandleInfo.handleName; + fromTaskId = fromHandleInfo.taskId; + } else { + // Simple IO node handle - get info from the source node, not the handle + const fromNodeId = fromHandle.nodeId; + const fromNodeType = nodeManager.getNodeType(fromNodeId); + + if (!fromNodeType) { + return componentSpec; + } + + if (fromNodeType === "input") { + fromHandleType = "output"; + const inputId = nodeManager.getTaskId(fromNodeId); + if (inputId) { + fromHandleName = inputIdToInputName(inputId); + fromTaskId = inputId; + } + } else if (fromNodeType === "output") { + fromHandleType = "input"; + const outputId = nodeManager.getTaskId(fromNodeId); + if (outputId) { + fromHandleName = outputIdToOutputName(outputId); + fromTaskId = outputId; + } + } else { + return componentSpec; + } + } + + if (!fromTaskId || !fromHandleName) { + return componentSpec; + } const adjustedPosition = fromHandleType === "input" @@ -77,22 +127,17 @@ export function addAndConnectNode({ const newNodeId = nodeManager.getNodeId(newTaskId, "task"); // 3. Determine the connection data type and find the first matching handle on the new node - if (!fromHandle) { - return newComponentSpec; + let fromComponentSpec: ComponentSpec | undefined; + + if (isTaskHandle) { + // Get spec from task + const fromTaskSpec = graphSpec.tasks[fromTaskId]; + fromComponentSpec = fromTaskSpec?.componentRef.spec; + } else { + // For IO nodes, get spec from component spec + fromComponentSpec = componentSpec; } - const fromTaskId = nodeManager.getTaskId(fromHandle.nodeId); - if (!fromTaskId) { - return newComponentSpec; - } - - const fromTaskSpec = graphSpec.tasks[fromTaskId]; - const fromComponentSpec = fromTaskSpec?.componentRef.spec; - - const fromNodeId = fromHandle.nodeId; - - const fromHandleName = fromHandleId?.replace(`${fromHandleType}_`, ""); - let connectionType: TypeSpecType | undefined; if (fromHandleType === "input") { connectionType = fromComponentSpec?.inputs?.find( @@ -106,7 +151,6 @@ export function addAndConnectNode({ // Find the first matching handle on the new node const toHandleType = fromHandleType === "input" ? "output" : "input"; - let targetHandleId: string | undefined; if (toHandleType === "input") { @@ -117,8 +161,11 @@ export function addAndConnectNode({ return newComponentSpec; } - const inputId = inputNameToInputId(handleName); - targetHandleId = nodeManager.getNodeId(inputId, "input"); + targetHandleId = nodeManager.getTaskHandleNodeId( + newTaskId, + handleName, + "taskInput", + ); } else if (toHandleType === "output") { const handleName = componentRef.spec?.outputs?.find( (io) => io.type === connectionType, @@ -127,14 +174,21 @@ export function addAndConnectNode({ return newComponentSpec; } - const outputId = outputNameToOutputId(handleName); - targetHandleId = nodeManager.getNodeId(outputId, "output"); + targetHandleId = nodeManager.getTaskHandleNodeId( + newTaskId, + handleName, + "taskOutput", + ); } // 4. Build a Connection object and use handleConnection to add the edge - if (fromNodeId && fromHandleId && targetHandleId) { + if (targetHandleId) { + const fromNodeId = fromHandle.nodeId; + const fromHandleId = fromHandle.id; + const isReversedConnection = fromHandleType === "input" && toHandleType === "output"; + const connection: Connection = isReversedConnection ? // Drawing from an input handle to a new output handle { diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/handleConnection.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/handleConnection.ts index 773587f01..e415d51a2 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/handleConnection.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/handleConnection.ts @@ -1,107 +1,178 @@ import type { Connection } from "@xyflow/react"; -import type { NodeManager } from "@/nodeManager"; +import type { NodeManager, NodeType } from "@/nodeManager"; import type { GraphInputArgument, GraphSpec, TaskOutputArgument, } from "@/utils/componentSpec"; -import { inputIdToInputName } from "@/utils/nodes/conversions"; +import { + inputIdToInputName, + outputIdToOutputName, +} from "@/utils/nodes/conversions"; import { setGraphOutputValue } from "./setGraphOutputValue"; import { setTaskArgument } from "./setTaskArgument"; +type NodeInfo = { + id: string; + type?: NodeType; + handle?: { + taskId: string; + handleName: string; + }; +}; + export const handleConnection = ( graphSpec: GraphSpec, connection: Connection, nodeManager: NodeManager, ) => { - const targetTaskInputName = connection.targetHandle?.replace(/^input_/, ""); - const sourceTaskOutputName = connection.sourceHandle?.replace(/^output_/, ""); - - const targetId = nodeManager.getTaskId(connection.target); - const sourceId = nodeManager.getTaskId(connection.source); - - if (sourceTaskOutputName !== undefined) { - if (!sourceId) { - console.error( - "addConnection: Could not find task ID for source node: ", - connection.source, - ); - return graphSpec; - } - - const taskOutputArgument: TaskOutputArgument = { - taskOutput: { - taskId: sourceId, - outputName: sourceTaskOutputName, - }, - }; - - if (targetTaskInputName !== undefined) { - if (!targetId) { - console.error( - "addConnection: Could not find Input ID for target node: ", - connection.target, - ); - return graphSpec; - } - - return setTaskArgument( - graphSpec, - targetId, - targetTaskInputName, - taskOutputArgument, - ); - } else { - if (!targetId) { - console.error( - "addConnection: Could not find Output ID for target node: ", - connection.target, - ); - return graphSpec; - } - - return setGraphOutputValue(graphSpec, targetId, taskOutputArgument); - // TODO: Perhaps propagate type information - } - } else { - if (!sourceId) { - console.error( - "addConnection: Could not find task ID for source node: ", - connection.source, - ); + const sourceId = nodeManager.getTaskId(connection.source!); + const sourceType = nodeManager.getNodeType(connection.source!); + + const targetId = nodeManager.getTaskId(connection.target!); + const targetType = nodeManager.getNodeType(connection.target!); + + if (!sourceId || !targetId || !sourceType || !targetType) { + console.error("Could not resolve node information:", { + sourceId, + sourceType, + targetId, + targetType, + }); + return graphSpec; + } + + if (sourceId === targetId) { + console.warn("Cannot connect node to itself"); + return graphSpec; + } + + let sourceHandleInfo: { taskId: string; handleName: string } | undefined; + let targetHandleInfo: { taskId: string; handleName: string } | undefined; + + if (connection.sourceHandle) { + sourceHandleInfo = nodeManager.getHandleInfo(connection.sourceHandle); + } + + if (connection.targetHandle) { + targetHandleInfo = nodeManager.getHandleInfo(connection.targetHandle); + } + + const source: NodeInfo = { + id: sourceId, + type: sourceType, + handle: sourceHandleInfo, + }; + + const target: NodeInfo = { + id: targetId, + type: targetType, + handle: targetHandleInfo, + }; + + const connectionType = `${source.type}_to_${target.type}` as const; + + switch (connectionType) { + case "input_to_task": + return handleGraphInputToTask(graphSpec, source, target); + + case "task_to_task": + return handleTaskToTask(graphSpec, source, target); + + case "task_to_output": + return handleTaskToGraphOutput(graphSpec, source, target); + + default: + console.error("Unsupported connection pattern:", connectionType); return graphSpec; - } - const inputName = inputIdToInputName(sourceId); - const graphInputArgument: GraphInputArgument = { - graphInput: { - inputName: inputName, - }, - }; - if (targetTaskInputName !== undefined) { - if (!targetId) { - console.error( - "addConnection: Could not find Output ID for target node: ", - connection.target, - ); - return graphSpec; - } - - return setTaskArgument( - graphSpec, - targetId, - targetTaskInputName, - graphInputArgument, - ); - } else { - console.error( - "addConnection: Cannot directly connect graph input to graph output: ", - connection, - ); - } } +}; + +const handleGraphInputToTask = ( + graphSpec: GraphSpec, + source: NodeInfo, + target: NodeInfo, +): GraphSpec => { + if (!target.handle?.handleName) { + console.error( + "Target handle name missing for graph input to task connection", + ); + return graphSpec; + } + + const inputId = source.id; + const inputName = inputIdToInputName(inputId); + const targetInputName = target.handle.handleName; + + const graphInputArgument: GraphInputArgument = { + graphInput: { inputName }, + }; + + return setTaskArgument( + graphSpec, + target.id, + targetInputName, + graphInputArgument, + ); +}; + +const handleTaskToTask = ( + graphSpec: GraphSpec, + source: NodeInfo, + target: NodeInfo, +): GraphSpec => { + if (!source.handle?.handleName) { + console.error("Source handle name missing for task to task connection"); + return graphSpec; + } + + if (!target.handle?.handleName) { + console.error("Target handle name missing for task to task connection"); + return graphSpec; + } + + const sourceOutputName = source.handle.handleName; + const targetInputName = target.handle.handleName; + + const taskOutputArgument: TaskOutputArgument = { + taskOutput: { + taskId: source.id, + outputName: sourceOutputName, + }, + }; + + return setTaskArgument( + graphSpec, + target.id, + targetInputName, + taskOutputArgument, + ); +}; + +const handleTaskToGraphOutput = ( + graphSpec: GraphSpec, + source: NodeInfo, + target: NodeInfo, +): GraphSpec => { + if (!source.handle?.handleName) { + console.error( + "Source handle name missing for task to graph output connection", + ); + return graphSpec; + } + + const sourceOutputName = source.handle.handleName; + const outputId = target.id; + const outputName = outputIdToOutputName(outputId); + + const taskOutputArgument: TaskOutputArgument = { + taskOutput: { + taskId: source.id, + outputName: sourceOutputName, + }, + }; - // GraphSpec was not updated (due to an error or other reason) - return graphSpec; + return setGraphOutputValue(graphSpec, outputName, taskOutputArgument); }; diff --git a/src/components/shared/ReactFlow/FlowCanvas/utils/removeEdge.ts b/src/components/shared/ReactFlow/FlowCanvas/utils/removeEdge.ts index 3e02034a7..4a6f3fecd 100644 --- a/src/components/shared/ReactFlow/FlowCanvas/utils/removeEdge.ts +++ b/src/components/shared/ReactFlow/FlowCanvas/utils/removeEdge.ts @@ -1,7 +1,10 @@ import type { Edge } from "@xyflow/react"; import type { NodeManager } from "@/nodeManager"; -import type { ComponentSpec, GraphImplementation } from "@/utils/componentSpec"; +import { + type ComponentSpec, + isGraphImplementation, +} from "@/utils/componentSpec"; import { outputIdToOutputName } from "@/utils/nodes/conversions"; import { setGraphOutputValue } from "./setGraphOutputValue"; @@ -12,34 +15,67 @@ export const removeEdge = ( componentSpec: ComponentSpec, nodeManager: NodeManager, ) => { - const graphSpec = (componentSpec.implementation as GraphImplementation) - ?.graph; - - const inputName = edge.targetHandle?.replace(/^input_/, ""); - - const updatedComponentSpec = { - ...componentSpec, - }; - - const taskId = nodeManager.getTaskId(edge.target); - if (!taskId) return componentSpec; - - if (inputName !== undefined && graphSpec) { - const newGraphSpec = setTaskArgument(graphSpec, taskId, inputName); - updatedComponentSpec.implementation = { - ...updatedComponentSpec.implementation, - graph: newGraphSpec, - }; - - return updatedComponentSpec; - } else { - const outputName = outputIdToOutputName(taskId); - const newGraphSpec = setGraphOutputValue(graphSpec, outputName); - updatedComponentSpec.implementation = { - ...updatedComponentSpec.implementation, - graph: newGraphSpec, - }; - - return updatedComponentSpec; + if (!isGraphImplementation(componentSpec.implementation)) { + return componentSpec; } + + const graphSpec = componentSpec.implementation.graph; + const updatedComponentSpec = { ...componentSpec }; + + const targetNodeId = edge.target; + const targetTaskId = nodeManager.getTaskId(targetNodeId); + const targetNodeType = nodeManager.getNodeType(targetNodeId); + + if (!targetTaskId || !targetNodeType) { + console.error("Could not resolve target node information:", { + targetNodeId, + targetTaskId, + targetNodeType, + }); + return componentSpec; + } + + switch (targetNodeType) { + case "task": { + if (!edge.targetHandle) { + console.error("No target handle found for task connection"); + return componentSpec; + } + + const targetHandleInfo = nodeManager.getHandleInfo(edge.targetHandle); + if (!targetHandleInfo) { + console.error("Could not resolve target handle info"); + return componentSpec; + } + + const inputName = targetHandleInfo.handleName; + const newGraphSpec = setTaskArgument(graphSpec, targetTaskId, inputName); + + updatedComponentSpec.implementation = { + ...updatedComponentSpec.implementation, + graph: newGraphSpec, + }; + break; + } + + case "output": { + const outputName = outputIdToOutputName(targetTaskId); + const newGraphSpec = setGraphOutputValue(graphSpec, outputName); + + updatedComponentSpec.implementation = { + ...updatedComponentSpec.implementation, + graph: newGraphSpec, + }; + break; + } + + default: + console.error( + "Unsupported target node type for edge removal:", + targetNodeType, + ); + return componentSpec; + } + + return updatedComponentSpec; }; diff --git a/src/hooks/useComponentSpecToEdges.ts b/src/hooks/useComponentSpecToEdges.ts index 49e122f12..c2d163279 100644 --- a/src/hooks/useComponentSpecToEdges.ts +++ b/src/hooks/useComponentSpecToEdges.ts @@ -110,14 +110,22 @@ const createTaskOutputEdge = ( nodeManager: NodeManager, ): Edge => { const sourceNodeId = nodeManager.getNodeId(taskOutput.taskId, "task"); - const sourceOutputId = outputNameToOutputId(taskOutput.outputName); - const sourceHandleNodeId = nodeManager.getNodeId(sourceOutputId, "output"); const targetNodeId = nodeManager.getNodeId(taskId, "task"); - const targetInputId = inputNameToInputId(inputName); - const targetHandleNodeId = nodeManager.getNodeId(targetInputId, "input"); + + const sourceHandleNodeId = nodeManager.getTaskHandleNodeId( + taskOutput.taskId, + taskOutput.outputName, + "taskOutput", + ); + + const targetHandleNodeId = nodeManager.getTaskHandleNodeId( + taskId, + inputName, + "taskInput", + ); return { - id: `${taskOutput.taskId}_${sourceOutputId}-${taskId}_${targetInputId}`, + id: `${taskOutput.taskId}_${taskOutput.outputName}-${taskId}_${inputName}`, source: sourceNodeId, sourceHandle: sourceHandleNodeId, target: targetNodeId, @@ -136,11 +144,15 @@ const createGraphInputEdge = ( const inputId = inputNameToInputId(graphInput.inputName); const sourceNodeId = nodeManager.getNodeId(inputId, "input"); const targetNodeId = nodeManager.getNodeId(taskId, "task"); - const targetInputId = inputNameToInputId(inputName); - const targetHandleNodeId = nodeManager.getNodeId(targetInputId, "input"); + + const targetHandleNodeId = nodeManager.getTaskHandleNodeId( + taskId, + inputName, + "taskInput", + ); return { - id: `Input_${inputId}-${taskId}_${targetInputId}`, + id: `Input_${inputId}-${taskId}_${inputName}`, source: sourceNodeId, sourceHandle: null, target: targetNodeId, @@ -159,16 +171,17 @@ const createOutputEdgesFromGraphSpec = ( const taskOutput = argument.taskOutput; const sourceNodeId = nodeManager.getNodeId(taskOutput.taskId, "task"); - const sourceOutputId = outputNameToOutputId(taskOutput.outputName); - const sourceHandleNodeId = nodeManager.getNodeId( - sourceOutputId, - "output", - ); const targetOutputId = outputNameToOutputId(outputName); const targetNodeId = nodeManager.getNodeId(targetOutputId, "output"); + const sourceHandleNodeId = nodeManager.getTaskHandleNodeId( + taskOutput.taskId, + taskOutput.outputName, + "taskOutput", + ); + const edge: Edge = { - id: `${taskOutput.taskId}_${sourceOutputId}-Output_${targetOutputId}`, + id: `${taskOutput.taskId}_${taskOutput.outputName}-Output_${outputName}`, source: sourceNodeId, sourceHandle: sourceHandleNodeId, target: targetNodeId,