From caf1539396c088efaf7041eb88fece93c3c9f20d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:23:45 +0000 Subject: [PATCH 01/16] Initial plan From 622ae7d263e7dd789e1c38202d5a0d50a58b3486 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:32:07 +0000 Subject: [PATCH 02/16] Add IdentityFixPass to fix invalid graphs with direct input-output connections Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/passes/common/__init__.py | 4 + src/onnx_ir/passes/common/identity_fix.py | 101 ++++ .../passes/common/identity_fix_test.py | 540 ++++++++++++++++++ 3 files changed, 645 insertions(+) create mode 100644 src/onnx_ir/passes/common/identity_fix.py create mode 100644 src/onnx_ir/passes/common/identity_fix_test.py diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index 5fe199a1..cd436dec 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -9,6 +9,7 @@ "DeduplicateHashedInitializersPass", "DeduplicateInitializersPass", "IdentityEliminationPass", + "IdentityFixPass", "InlinePass", "LiftConstantsToInitializersPass", "LiftSubgraphInitializersToMainGraphPass", @@ -36,6 +37,9 @@ from onnx_ir.passes.common.identity_elimination import ( IdentityEliminationPass, ) +from onnx_ir.passes.common.identity_fix import ( + IdentityFixPass, +) from onnx_ir.passes.common.initializer_deduplication import ( DeduplicateHashedInitializersPass, DeduplicateInitializersPass, diff --git a/src/onnx_ir/passes/common/identity_fix.py b/src/onnx_ir/passes/common/identity_fix.py new file mode 100644 index 00000000..662b4af8 --- /dev/null +++ b/src/onnx_ir/passes/common/identity_fix.py @@ -0,0 +1,101 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Identity fix pass for adding Identity nodes when graph inputs are directly used as outputs.""" + +from __future__ import annotations + +__all__ = [ + "IdentityFixPass", +] + +import logging + +import onnx_ir as ir + +logger = logging.getLogger(__name__) + + +class IdentityFixPass(ir.passes.InPlacePass): + """Pass for adding Identity nodes when graph inputs are directly used as outputs. + + This pass adds Identity nodes according to the following rule: + + If a graph input is directly used as a graph output (without any intermediate nodes), + insert an Identity node between them. This turns an invalid ONNX graph into a valid one. + + Example transformation: + Before: input -> (direct connection) -> output + After: input -> Identity -> output + + This is required because ONNX specification does not allow a graph input to be + directly used as a graph output without any processing nodes in between. + """ + + def call(self, model: ir.Model) -> ir.passes.PassResult: + """Main entry point for the identity fix pass.""" + modified = False + + # Process the main graph + if self._process_graph(model.graph): + modified = True + + # Process functions + for function in model.functions.values(): + if self._process_graph(function): + modified = True + + if modified: + logger.info("Identity fix pass modified the model") + + return ir.passes.PassResult(model, modified=modified) + + def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: + """Process a single graph or function, returning True if modified.""" + modified = False + + # Check each output to see if it's directly a graph input + outputs_to_fix = [] + for output in graph_like.outputs: + if output.is_graph_input(): + outputs_to_fix.append(output) + + # Add Identity nodes for each output that needs fixing + for output in outputs_to_fix: + # Create an Identity node + identity_node = ir.Node("", "Identity", inputs=[output]) + identity_output = identity_node.outputs[0] + + # Copy metadata from the original output + identity_output.name = output.name + identity_output.shape = output.shape + identity_output.type = output.type + if output.metadata_props: + identity_output.metadata_props.update(output.metadata_props) + identity_output.doc_string = output.doc_string + + # Add the node to the graph + graph_like.append(identity_node) + + # Replace the output with the Identity node's output + # Find the index of the output in the graph outputs + output_index = graph_like.outputs.index(output) + graph_like.outputs[output_index] = identity_output + + logger.debug( + "Added Identity node for graph input '%s' used as output", output.name + ) + modified = True + + # Process subgraphs in nodes + for node in graph_like: + for attr in node.attributes.values(): + if isinstance(attr, ir.Attr): + if attr.type == ir.AttributeType.GRAPH and attr.value is not None: + if self._process_graph(attr.value): + modified = True + elif attr.type == ir.AttributeType.GRAPHS and attr.value is not None: + for subgraph in attr.value: + if subgraph is not None and self._process_graph(subgraph): + modified = True + + return modified diff --git a/src/onnx_ir/passes/common/identity_fix_test.py b/src/onnx_ir/passes/common/identity_fix_test.py new file mode 100644 index 00000000..5ea47575 --- /dev/null +++ b/src/onnx_ir/passes/common/identity_fix_test.py @@ -0,0 +1,540 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the identity fix pass.""" + +from __future__ import annotations + +import unittest + +import onnx_ir as ir +from onnx_ir.passes.common import identity_fix + + +class TestIdentityFixPass(unittest.TestCase): + """Test cases for IdentityFixPass.""" + + def test_add_identity_when_input_is_direct_output(self): + """Test: Add Identity node when graph input is directly used as output.""" + # Create a simple model: input -> (direct) -> output + input_value = ir.val( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + graph = ir.Graph( + inputs=[input_value], + outputs=[input_value], # Input is directly used as output + nodes=[], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify an Identity node was added + nodes = list(result.model.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "Identity") + + # Verify the Identity node uses the input + identity_node = nodes[0] + self.assertIs(identity_node.inputs[0], input_value) + + # Verify the output is now the Identity node's output + self.assertEqual(len(result.model.graph.outputs), 1) + self.assertIs(result.model.graph.outputs[0], identity_node.outputs[0]) + + # Verify the output name is preserved + self.assertEqual(result.model.graph.outputs[0].name, "input") + + def test_no_modification_when_identity_exists(self): + """Test: No modification when Identity node already exists between input and output.""" + # Create a model: input -> Identity -> output + input_value = ir.val( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + identity_node = ir.Node("", "Identity", inputs=[input_value]) + identity_node.outputs[0].name = "output" + identity_node.outputs[0].shape = input_value.shape + identity_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[identity_node.outputs[0]], # Output is Identity's output + nodes=[identity_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass did NOT modify the model + self.assertFalse(result.modified) + + # Verify structure is unchanged + nodes = list(result.model.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "Identity") + + def test_no_modification_when_node_exists_between_input_and_output(self): + """Test: No modification when a processing node exists between input and output.""" + # Create a model: input -> Add -> output + input_value = ir.val( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "output" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[add_node.outputs[0]], + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass did NOT modify the model + self.assertFalse(result.modified) + + # Verify structure is unchanged + nodes = list(result.model.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "Add") + + def test_multiple_inputs_one_direct_output(self): + """Test: Add Identity for one input that's directly used as output, leave others alone.""" + # Create inputs + input1 = ir.val( + "input1", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + input2 = ir.val( + "input2", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a node that uses input2 + add_node = ir.Node("", "Add", inputs=[input2, input2]) + add_node.outputs[0].name = "output2" + add_node.outputs[0].shape = input2.shape + add_node.outputs[0].type = input2.type + + graph = ir.Graph( + inputs=[input1, input2], + outputs=[input1, add_node.outputs[0]], # input1 is directly used as output + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify one Identity node was added + nodes = list(result.model.graph) + self.assertEqual(len(nodes), 2) # Add + Identity + + identity_nodes = [n for n in nodes if n.op_type == "Identity"] + self.assertEqual(len(identity_nodes), 1) + + # Verify the Identity node uses input1 + identity_node = identity_nodes[0] + self.assertIs(identity_node.inputs[0], input1) + + # Verify outputs + self.assertEqual(len(result.model.graph.outputs), 2) + # First output should be the Identity node's output + self.assertIs(result.model.graph.outputs[0], identity_node.outputs[0]) + # Second output should still be the Add node's output + self.assertIs(result.model.graph.outputs[1], add_node.outputs[0]) + + def test_multiple_direct_outputs(self): + """Test: Add Identity nodes for multiple inputs used directly as outputs.""" + # Create inputs + input1 = ir.val( + "input1", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + input2 = ir.val( + "input2", shape=ir.Shape([3, 3]), type=ir.TensorType(ir.DataType.INT32) + ) + + graph = ir.Graph( + inputs=[input1, input2], + outputs=[input1, input2], # Both inputs directly used as outputs + nodes=[], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify two Identity nodes were added + nodes = list(result.model.graph) + self.assertEqual(len(nodes), 2) + self.assertTrue(all(n.op_type == "Identity" for n in nodes)) + + # Verify both inputs are used by Identity nodes + identity_inputs = [n.inputs[0] for n in nodes] + self.assertIn(input1, identity_inputs) + self.assertIn(input2, identity_inputs) + + # Verify outputs are now Identity nodes' outputs + self.assertEqual(len(result.model.graph.outputs), 2) + for output in result.model.graph.outputs: + self.assertIsNotNone(output.producer()) + self.assertEqual(output.producer().op_type, "Identity") + + def test_empty_graph(self): + """Test: Pass on an empty graph.""" + graph = ir.Graph(inputs=[], outputs=[], nodes=[], name="empty_graph") + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass did not modify the model + self.assertFalse(result.modified) + + # Verify structure is unchanged + self.assertEqual(len(list(result.model.graph)), 0) + + def test_graph_with_no_direct_input_output(self): + """Test: Graph with inputs and outputs but no direct connections.""" + input_value = ir.val( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a Constant node (output doesn't come from input) + const_node = ir.Node("", "Constant", inputs=[]) + const_node.outputs[0].name = "output" + const_node.outputs[0].shape = ir.Shape([2, 2]) + const_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + graph = ir.Graph( + inputs=[input_value], + outputs=[const_node.outputs[0]], + nodes=[const_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass did not modify the model + self.assertFalse(result.modified) + + # Verify structure is unchanged + nodes = list(result.model.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "Constant") + + def test_preserve_output_metadata(self): + """Test: Output metadata (shape, type, name) is preserved.""" + input_value = ir.val( + "my_input", shape=ir.Shape([5, 10]), type=ir.TensorType(ir.DataType.INT64) + ) + input_value.doc_string = "Test doc string" + input_value.metadata_props["custom_key"] = "custom_value" + + graph = ir.Graph( + inputs=[input_value], + outputs=[input_value], + nodes=[], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify metadata is preserved + output = result.model.graph.outputs[0] + self.assertEqual(output.name, "my_input") + self.assertEqual(output.shape, ir.Shape([5, 10])) + self.assertEqual(output.type, ir.TensorType(ir.DataType.INT64)) + self.assertEqual(output.doc_string, "Test doc string") + self.assertEqual(output.metadata_props.get("custom_key"), "custom_value") + + def test_subgraph_with_direct_input_output(self): + """Test: Add Identity in subgraphs (e.g., in If node).""" + # Create main graph input + main_input = ir.val( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a condition input for If + condition = ir.val("condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL)) + + # Create then_branch subgraph with direct input->output + then_input = ir.val( + "then_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + then_branch = ir.Graph( + inputs=[then_input], + outputs=[then_input], # Direct input->output + nodes=[], + name="then_branch", + ) + + # Create else_branch subgraph with a node + else_input = ir.val( + "else_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + else_add = ir.Node("", "Add", inputs=[else_input, else_input]) + else_add.outputs[0].name = "else_output" + else_add.outputs[0].shape = else_input.shape + else_add.outputs[0].type = else_input.type + else_branch = ir.Graph( + inputs=[else_input], + outputs=[else_add.outputs[0]], + nodes=[else_add], + name="else_branch", + ) + + # Create If node with subgraphs + if_node = ir.Node("", "If", inputs=[condition]) + if_node.attributes["then_branch"] = ir.AttrGraph("then_branch", then_branch) + if_node.attributes["else_branch"] = ir.AttrGraph("else_branch", else_branch) + if_node.outputs[0].name = "if_output" + if_node.outputs[0].shape = main_input.shape + if_node.outputs[0].type = main_input.type + + # Create main graph + main_graph = ir.Graph( + inputs=[main_input, condition], + outputs=[if_node.outputs[0]], + nodes=[if_node], + name="main_graph", + ) + + model = ir.Model(main_graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify Identity was added in then_branch + if_node = list(result.model.graph)[0] + then_branch_after = if_node.attributes["then_branch"].value + then_nodes = list(then_branch_after) + self.assertEqual(len(then_nodes), 1) + self.assertEqual(then_nodes[0].op_type, "Identity") + + # Verify else_branch was not modified + else_branch_after = if_node.attributes["else_branch"].value + else_nodes = list(else_branch_after) + self.assertEqual(len(else_nodes), 1) + self.assertEqual(else_nodes[0].op_type, "Add") + + def test_function_with_direct_input_output(self): + """Test: Add Identity in functions.""" + # Create function with direct input->output + func_input = ir.val( + "func_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + func_graph = ir.Graph( + inputs=[func_input], + outputs=[func_input], # Direct input->output + nodes=[], + name="test_function_graph", + ) + + function = ir.Function( + domain="test_domain", + name="test_function", + graph=func_graph, + attributes=[], + ) + + # Create main graph that calls the function + main_input = ir.val( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + call_node = ir.Node("test_domain", "test_function", inputs=[main_input]) + call_node.outputs[0].name = "main_output" + call_node.outputs[0].shape = main_input.shape + call_node.outputs[0].type = main_input.type + + main_graph = ir.Graph( + inputs=[main_input], + outputs=[call_node.outputs[0]], + nodes=[call_node], + name="main_graph", + ) + + model = ir.Model(main_graph, ir_version=10, functions=[function]) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify Identity was added in the function + func_after = result.model.functions[function.identifier()] + func_nodes = list(func_after) + self.assertEqual(len(func_nodes), 1) + self.assertEqual(func_nodes[0].op_type, "Identity") + + # Verify the function output is now the Identity node's output + identity_node = func_nodes[0] + self.assertIs(func_after.outputs[0], identity_node.outputs[0]) + self.assertIs(identity_node.inputs[0], func_input) + + def test_same_input_used_multiple_times_as_output(self): + """Test: Same input used multiple times as output.""" + input_value = ir.val( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + graph = ir.Graph( + inputs=[input_value], + outputs=[input_value, input_value], # Same input used twice + nodes=[], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify two Identity nodes were added + nodes = list(result.model.graph) + self.assertEqual(len(nodes), 2) + self.assertTrue(all(n.op_type == "Identity" for n in nodes)) + + # Verify both use the same input + for node in nodes: + self.assertIs(node.inputs[0], input_value) + + # Verify outputs are now different Identity nodes' outputs + self.assertEqual(len(result.model.graph.outputs), 2) + self.assertIsNot( + result.model.graph.outputs[0], result.model.graph.outputs[1] + ) + + def test_nested_subgraphs(self): + """Test: Handle nested subgraphs (subgraph within subgraph).""" + # Create innermost graph with direct input->output + inner_input = ir.val( + "inner_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + inner_graph = ir.Graph( + inputs=[inner_input], + outputs=[inner_input], # Direct input->output + nodes=[], + name="inner_graph", + ) + + # Create middle graph with an If node containing the inner graph + middle_input = ir.val( + "middle_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + middle_condition = ir.val( + "middle_condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL) + ) + + middle_if = ir.Node("", "If", inputs=[middle_condition]) + middle_if.attributes["then_branch"] = ir.AttrGraph("then_branch", inner_graph) + middle_if.attributes["else_branch"] = ir.AttrGraph("else_branch", inner_graph) + middle_if.outputs[0].name = "middle_output" + middle_if.outputs[0].shape = middle_input.shape + middle_if.outputs[0].type = middle_input.type + + middle_graph = ir.Graph( + inputs=[middle_input, middle_condition], + outputs=[middle_if.outputs[0]], + nodes=[middle_if], + name="middle_graph", + ) + + # Create outer graph with an If node containing the middle graph + outer_input = ir.val( + "outer_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + outer_condition = ir.val( + "outer_condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL) + ) + + outer_if = ir.Node("", "If", inputs=[outer_condition]) + outer_if.attributes["then_branch"] = ir.AttrGraph("then_branch", middle_graph) + outer_if.attributes["else_branch"] = ir.AttrGraph("else_branch", middle_graph) + outer_if.outputs[0].name = "outer_output" + outer_if.outputs[0].shape = outer_input.shape + outer_if.outputs[0].type = outer_input.type + + main_graph = ir.Graph( + inputs=[outer_input, outer_condition], + outputs=[outer_if.outputs[0]], + nodes=[outer_if], + name="main_graph", + ) + + model = ir.Model(main_graph, ir_version=10) + + # Run the pass + pass_instance = identity_fix.IdentityFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Navigate to innermost graphs and verify Identity nodes were added + outer_if = list(result.model.graph)[0] + outer_then = outer_if.attributes["then_branch"].value + middle_if_node = list(outer_then)[0] + inner_then = middle_if_node.attributes["then_branch"].value + + inner_nodes = list(inner_then) + self.assertEqual(len(inner_nodes), 1) + self.assertEqual(inner_nodes[0].op_type, "Identity") + + +if __name__ == "__main__": + unittest.main() From e848c91d8c8887d9e2091edb51301316f70e7b65 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:36:56 +0000 Subject: [PATCH 03/16] Fix linting issues in identity_fix pass Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/passes/common/identity_fix.py | 8 ++++---- src/onnx_ir/passes/common/identity_fix_test.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/onnx_ir/passes/common/identity_fix.py b/src/onnx_ir/passes/common/identity_fix.py index 662b4af8..477d697e 100644 --- a/src/onnx_ir/passes/common/identity_fix.py +++ b/src/onnx_ir/passes/common/identity_fix.py @@ -64,7 +64,7 @@ def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: # Create an Identity node identity_node = ir.Node("", "Identity", inputs=[output]) identity_output = identity_node.outputs[0] - + # Copy metadata from the original output identity_output.name = output.name identity_output.shape = output.shape @@ -72,15 +72,15 @@ def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: if output.metadata_props: identity_output.metadata_props.update(output.metadata_props) identity_output.doc_string = output.doc_string - + # Add the node to the graph graph_like.append(identity_node) - + # Replace the output with the Identity node's output # Find the index of the output in the graph outputs output_index = graph_like.outputs.index(output) graph_like.outputs[output_index] = identity_output - + logger.debug( "Added Identity node for graph input '%s' used as output", output.name ) diff --git a/src/onnx_ir/passes/common/identity_fix_test.py b/src/onnx_ir/passes/common/identity_fix_test.py index 5ea47575..2779613d 100644 --- a/src/onnx_ir/passes/common/identity_fix_test.py +++ b/src/onnx_ir/passes/common/identity_fix_test.py @@ -298,7 +298,9 @@ def test_subgraph_with_direct_input_output(self): ) # Create a condition input for If - condition = ir.val("condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL)) + condition = ir.val( + "condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL) + ) # Create then_branch subgraph with direct input->output then_input = ir.val( @@ -352,7 +354,7 @@ def test_subgraph_with_direct_input_output(self): self.assertTrue(result.modified) # Verify Identity was added in then_branch - if_node = list(result.model.graph)[0] + if_node = next(iter(result.model.graph)) then_branch_after = if_node.attributes["then_branch"].value then_nodes = list(then_branch_after) self.assertEqual(len(then_nodes), 1) @@ -455,9 +457,7 @@ def test_same_input_used_multiple_times_as_output(self): # Verify outputs are now different Identity nodes' outputs self.assertEqual(len(result.model.graph.outputs), 2) - self.assertIsNot( - result.model.graph.outputs[0], result.model.graph.outputs[1] - ) + self.assertIsNot(result.model.graph.outputs[0], result.model.graph.outputs[1]) def test_nested_subgraphs(self): """Test: Handle nested subgraphs (subgraph within subgraph).""" @@ -526,11 +526,11 @@ def test_nested_subgraphs(self): self.assertTrue(result.modified) # Navigate to innermost graphs and verify Identity nodes were added - outer_if = list(result.model.graph)[0] + outer_if = next(iter(result.model.graph)) outer_then = outer_if.attributes["then_branch"].value - middle_if_node = list(outer_then)[0] + middle_if_node = next(iter(outer_then)) inner_then = middle_if_node.attributes["then_branch"].value - + inner_nodes = list(inner_then) self.assertEqual(len(inner_nodes), 1) self.assertEqual(inner_nodes[0].op_type, "Identity") From cec1a44f192121de50ce4020811047a94ca7dffe Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Dec 2025 21:40:41 +0000 Subject: [PATCH 04/16] Address code review feedback - simplify metadata_props update Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com> --- src/onnx_ir/passes/common/identity_fix.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/onnx_ir/passes/common/identity_fix.py b/src/onnx_ir/passes/common/identity_fix.py index 477d697e..6e66827c 100644 --- a/src/onnx_ir/passes/common/identity_fix.py +++ b/src/onnx_ir/passes/common/identity_fix.py @@ -69,8 +69,7 @@ def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: identity_output.name = output.name identity_output.shape = output.shape identity_output.type = output.type - if output.metadata_props: - identity_output.metadata_props.update(output.metadata_props) + identity_output.metadata_props.update(output.metadata_props) identity_output.doc_string = output.doc_string # Add the node to the graph From 371c1dc33e333743596fae38d6467f8a3d0768e2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 15:45:27 -0800 Subject: [PATCH 05/16] Update Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/__init__.py | 8 +- src/onnx_ir/passes/common/identity_fix.py | 100 ------------------ src/onnx_ir/passes/common/output_fix.py | 87 +++++++++++++++ ...dentity_fix_test.py => output_fix_test.py} | 30 +++--- 4 files changed, 106 insertions(+), 119 deletions(-) delete mode 100644 src/onnx_ir/passes/common/identity_fix.py create mode 100644 src/onnx_ir/passes/common/output_fix.py rename src/onnx_ir/passes/common/{identity_fix_test.py => output_fix_test.py} (95%) diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index cd436dec..78f819be 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -9,7 +9,7 @@ "DeduplicateHashedInitializersPass", "DeduplicateInitializersPass", "IdentityEliminationPass", - "IdentityFixPass", + "OutputFixPass", "InlinePass", "LiftConstantsToInitializersPass", "LiftSubgraphInitializersToMainGraphPass", @@ -37,9 +37,6 @@ from onnx_ir.passes.common.identity_elimination import ( IdentityEliminationPass, ) -from onnx_ir.passes.common.identity_fix import ( - IdentityFixPass, -) from onnx_ir.passes.common.initializer_deduplication import ( DeduplicateHashedInitializersPass, DeduplicateInitializersPass, @@ -47,6 +44,9 @@ from onnx_ir.passes.common.inliner import InlinePass from onnx_ir.passes.common.naming import NameFixPass from onnx_ir.passes.common.onnx_checker import CheckerPass +from onnx_ir.passes.common.output_fix import ( + OutputFixPass, +) from onnx_ir.passes.common.shape_inference import ShapeInferencePass from onnx_ir.passes.common.topological_sort import TopologicalSortPass from onnx_ir.passes.common.unused_removal import ( diff --git a/src/onnx_ir/passes/common/identity_fix.py b/src/onnx_ir/passes/common/identity_fix.py deleted file mode 100644 index 6e66827c..00000000 --- a/src/onnx_ir/passes/common/identity_fix.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) ONNX Project Contributors -# SPDX-License-Identifier: Apache-2.0 -"""Identity fix pass for adding Identity nodes when graph inputs are directly used as outputs.""" - -from __future__ import annotations - -__all__ = [ - "IdentityFixPass", -] - -import logging - -import onnx_ir as ir - -logger = logging.getLogger(__name__) - - -class IdentityFixPass(ir.passes.InPlacePass): - """Pass for adding Identity nodes when graph inputs are directly used as outputs. - - This pass adds Identity nodes according to the following rule: - - If a graph input is directly used as a graph output (without any intermediate nodes), - insert an Identity node between them. This turns an invalid ONNX graph into a valid one. - - Example transformation: - Before: input -> (direct connection) -> output - After: input -> Identity -> output - - This is required because ONNX specification does not allow a graph input to be - directly used as a graph output without any processing nodes in between. - """ - - def call(self, model: ir.Model) -> ir.passes.PassResult: - """Main entry point for the identity fix pass.""" - modified = False - - # Process the main graph - if self._process_graph(model.graph): - modified = True - - # Process functions - for function in model.functions.values(): - if self._process_graph(function): - modified = True - - if modified: - logger.info("Identity fix pass modified the model") - - return ir.passes.PassResult(model, modified=modified) - - def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: - """Process a single graph or function, returning True if modified.""" - modified = False - - # Check each output to see if it's directly a graph input - outputs_to_fix = [] - for output in graph_like.outputs: - if output.is_graph_input(): - outputs_to_fix.append(output) - - # Add Identity nodes for each output that needs fixing - for output in outputs_to_fix: - # Create an Identity node - identity_node = ir.Node("", "Identity", inputs=[output]) - identity_output = identity_node.outputs[0] - - # Copy metadata from the original output - identity_output.name = output.name - identity_output.shape = output.shape - identity_output.type = output.type - identity_output.metadata_props.update(output.metadata_props) - identity_output.doc_string = output.doc_string - - # Add the node to the graph - graph_like.append(identity_node) - - # Replace the output with the Identity node's output - # Find the index of the output in the graph outputs - output_index = graph_like.outputs.index(output) - graph_like.outputs[output_index] = identity_output - - logger.debug( - "Added Identity node for graph input '%s' used as output", output.name - ) - modified = True - - # Process subgraphs in nodes - for node in graph_like: - for attr in node.attributes.values(): - if isinstance(attr, ir.Attr): - if attr.type == ir.AttributeType.GRAPH and attr.value is not None: - if self._process_graph(attr.value): - modified = True - elif attr.type == ir.AttributeType.GRAPHS and attr.value is not None: - for subgraph in attr.value: - if subgraph is not None and self._process_graph(subgraph): - modified = True - - return modified diff --git a/src/onnx_ir/passes/common/output_fix.py b/src/onnx_ir/passes/common/output_fix.py new file mode 100644 index 00000000..76d04738 --- /dev/null +++ b/src/onnx_ir/passes/common/output_fix.py @@ -0,0 +1,87 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Identity fix pass for adding Identity nodes when graph inputs are directly used as outputs.""" + +from __future__ import annotations + +__all__ = [ + "OutputFixPass", +] + +import logging + +import onnx_ir as ir + +logger = logging.getLogger(__name__) + + +class OutputFixPass(ir.passes.InPlacePass): + """Pass for adding Identity nodes when graph inputs are directly used as outputs. + + This pass adds Identity nodes according to the following rule: + + If a graph input is directly used as a graph output (without any intermediate nodes), + insert an Identity node between them. This turns an invalid ONNX graph into a valid one. + + Example transformation: + Before: input -> (direct connection) -> output + After: input -> Identity -> output + + This is required because ONNX specification does not allow a graph input to be + directly used as a graph output without any processing nodes in between. + """ + + def call(self, model: ir.Model) -> ir.passes.PassResult: + """Main entry point for the identity fix pass.""" + modified = False + + # Process the main graph + if self._process_graph(model.graph): + modified = True + + # Process functions + for function in model.functions.values(): + if self._process_graph(function): + modified = True + + return ir.passes.PassResult(model, modified=modified) + + def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: + """Process a single graph or function, returning True if modified.""" + modified = False + + for graph in (graph_like, *graph_like.subgraphs()): + + # Check each output to see if it's directly a graph input + outputs_to_fix: list[tuple[ir.Value, int]] = [] + for i, output in enumerate(graph.outputs): + if output.is_graph_input(): + outputs_to_fix.append((output, i)) + + # Add Identity nodes for each output that needs fixing + for output, index in outputs_to_fix: + # Create an Identity node + identity_node = ir.node("Identity", inputs=[output]) + identity_output = identity_node.outputs[0] + + # Copy metadata from the original output + identity_output.name = output.name + identity_output.shape = output.shape + identity_output.type = output.type + identity_output.metadata_props.update(output.metadata_props) + identity_output.doc_string = output.doc_string + + # Create a unique name for the old output to avoid name conflicts + # TODO: Use a better unique naming strategy if needed + output.name = f"{output.name}_orig" + + # Add the node to the graph + graph.append(identity_node) + graph.outputs[index] = identity_output + + logger.debug( + "Added Identity node for graph input '%s' used as output", output + ) + modified = True + + return modified diff --git a/src/onnx_ir/passes/common/identity_fix_test.py b/src/onnx_ir/passes/common/output_fix_test.py similarity index 95% rename from src/onnx_ir/passes/common/identity_fix_test.py rename to src/onnx_ir/passes/common/output_fix_test.py index 2779613d..959bbd43 100644 --- a/src/onnx_ir/passes/common/identity_fix_test.py +++ b/src/onnx_ir/passes/common/output_fix_test.py @@ -7,11 +7,11 @@ import unittest import onnx_ir as ir -from onnx_ir.passes.common import identity_fix +from onnx_ir.passes.common import output_fix -class TestIdentityFixPass(unittest.TestCase): - """Test cases for IdentityFixPass.""" +class TestOutputFixPass(unittest.TestCase): + """Test cases for OutputFixPass.""" def test_add_identity_when_input_is_direct_output(self): """Test: Add Identity node when graph input is directly used as output.""" @@ -30,7 +30,7 @@ def test_add_identity_when_input_is_direct_output(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass was applied @@ -74,7 +74,7 @@ def test_no_modification_when_identity_exists(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass did NOT modify the model @@ -107,7 +107,7 @@ def test_no_modification_when_node_exists_between_input_and_output(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass did NOT modify the model @@ -144,7 +144,7 @@ def test_multiple_inputs_one_direct_output(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass was applied @@ -188,7 +188,7 @@ def test_multiple_direct_outputs(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass was applied @@ -216,7 +216,7 @@ def test_empty_graph(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass did not modify the model @@ -247,7 +247,7 @@ def test_graph_with_no_direct_input_output(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass did not modify the model @@ -276,7 +276,7 @@ def test_preserve_output_metadata(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass was applied @@ -347,7 +347,7 @@ def test_subgraph_with_direct_input_output(self): model = ir.Model(main_graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass was applied @@ -407,7 +407,7 @@ def test_function_with_direct_input_output(self): model = ir.Model(main_graph, ir_version=10, functions=[function]) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass was applied @@ -440,7 +440,7 @@ def test_same_input_used_multiple_times_as_output(self): model = ir.Model(graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass was applied @@ -519,7 +519,7 @@ def test_nested_subgraphs(self): model = ir.Model(main_graph, ir_version=10) # Run the pass - pass_instance = identity_fix.IdentityFixPass() + pass_instance = output_fix.OutputFixPass() result = pass_instance(model) # Verify the pass was applied From d2478856c68b1739d49c200b4ca01e8b7af1eefc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 15:45:37 -0800 Subject: [PATCH 06/16] lint Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/onnx_ir/passes/common/output_fix.py b/src/onnx_ir/passes/common/output_fix.py index 76d04738..1175f8f7 100644 --- a/src/onnx_ir/passes/common/output_fix.py +++ b/src/onnx_ir/passes/common/output_fix.py @@ -51,7 +51,6 @@ def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: modified = False for graph in (graph_like, *graph_like.subgraphs()): - # Check each output to see if it's directly a graph input outputs_to_fix: list[tuple[ir.Value, int]] = [] for i, output in enumerate(graph.outputs): @@ -79,9 +78,7 @@ def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: graph.append(identity_node) graph.outputs[index] = identity_output - logger.debug( - "Added Identity node for graph input '%s' used as output", output - ) + logger.debug("Added Identity node for graph input '%s' used as output", output) modified = True return modified From 4c90386ccb02ea82793e8fb43a00ba25f6c489d6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 15:52:46 -0800 Subject: [PATCH 07/16] Tests Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix_test.py | 262 +++++++++++++++++++ 1 file changed, 262 insertions(+) diff --git a/src/onnx_ir/passes/common/output_fix_test.py b/src/onnx_ir/passes/common/output_fix_test.py index 959bbd43..e1a6ac80 100644 --- a/src/onnx_ir/passes/common/output_fix_test.py +++ b/src/onnx_ir/passes/common/output_fix_test.py @@ -535,6 +535,268 @@ def test_nested_subgraphs(self): self.assertEqual(len(inner_nodes), 1) self.assertEqual(inner_nodes[0].op_type, "Identity") + def test_pass_is_idempotent(self): + """Test: Running the pass twice should not modify the model again.""" + input_value = ir.val( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + graph = ir.Graph( + inputs=[input_value], + outputs=[input_value], + nodes=[], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass first time + pass_instance = output_fix.OutputFixPass() + result1 = pass_instance(model) + self.assertTrue(result1.modified) + + # Run the pass second time on the result + result2 = pass_instance(result1.model) + self.assertFalse(result2.modified) + + # Verify structure remains the same + nodes = list(result2.model.graph) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0].op_type, "Identity") + + def test_output_order_preserved(self): + """Test: The order of outputs is preserved after transformation.""" + input1 = ir.val( + "input1", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + input2 = ir.val( + "input2", shape=ir.Shape([3, 3]), type=ir.TensorType(ir.DataType.INT32) + ) + input3 = ir.val( + "input3", shape=ir.Shape([4, 4]), type=ir.TensorType(ir.DataType.INT64) + ) + + # Create a processing node for input2 + add_node = ir.Node("", "Add", inputs=[input2, input2]) + add_node.outputs[0].name = "processed_input2" + add_node.outputs[0].shape = input2.shape + add_node.outputs[0].type = input2.type + + graph = ir.Graph( + inputs=[input1, input2, input3], + outputs=[input1, add_node.outputs[0], input3], # input1 and input3 are direct + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = output_fix.OutputFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify output order is preserved (by checking names) + output_names = [output.name for output in result.model.graph.outputs] + self.assertEqual(output_names, ["input1", "processed_input2", "input3"]) + + # Verify first and third outputs are now Identity outputs + self.assertEqual(result.model.graph.outputs[0].producer().op_type, "Identity") + self.assertEqual(result.model.graph.outputs[1].producer().op_type, "Add") + self.assertEqual(result.model.graph.outputs[2].producer().op_type, "Identity") + + def test_name_collision_avoided(self): + """Test: Verify that renaming original inputs doesn't cause name collisions.""" + input_value = ir.val( + "my_value", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a node that produces a value with name that could collide + add_node = ir.Node("", "Add", inputs=[input_value, input_value]) + add_node.outputs[0].name = "my_value_orig" # This name could collide + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[input_value], # Direct input as output + nodes=[add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = output_fix.OutputFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify no assertion errors or issues occurred + # (The implementation should handle this gracefully) + self.assertEqual(len(list(result.model.graph)), 2) # Add + Identity + + def test_mixed_outputs_with_initializer_and_input(self): + """Test: Handle case where outputs include both inputs and computed values.""" + input_value = ir.val( + "input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Create a Constant node (not from input) + const_node = ir.Node("", "Constant", inputs=[]) + const_node.outputs[0].name = "constant" + const_node.outputs[0].shape = ir.Shape([2, 2]) + const_node.outputs[0].type = ir.TensorType(ir.DataType.FLOAT) + + # Create an Add node + add_node = ir.Node("", "Add", inputs=[input_value, const_node.outputs[0]]) + add_node.outputs[0].name = "sum" + add_node.outputs[0].shape = input_value.shape + add_node.outputs[0].type = input_value.type + + graph = ir.Graph( + inputs=[input_value], + outputs=[ + input_value, + const_node.outputs[0], + add_node.outputs[0], + ], # Mixed outputs + nodes=[const_node, add_node], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = output_fix.OutputFixPass() + result = pass_instance(model) + + # Verify the pass was applied (only for the direct input) + self.assertTrue(result.modified) + + # Verify only one Identity was added (for the direct input) + identity_nodes = [n for n in result.model.graph if n.op_type == "Identity"] + self.assertEqual(len(identity_nodes), 1) + + # Verify the first output is now an Identity node output + self.assertEqual(result.model.graph.outputs[0].producer().op_type, "Identity") + # The second output should still be from Constant + self.assertEqual(result.model.graph.outputs[1].producer().op_type, "Constant") + # The third output should still be from Add + self.assertEqual(result.model.graph.outputs[2].producer().op_type, "Add") + + def test_shape_metadata_types_preserved(self): + """Test: Various shape types (dynamic, symbolic) are preserved correctly.""" + # Input with concrete shape + input_concrete = ir.val( + "concrete", + shape=ir.Shape([2, 3, 4]), + type=ir.TensorType(ir.DataType.FLOAT), + ) + + # Input with dynamic shape + input_dynamic = ir.val( + "dynamic", + shape=ir.Shape([None, 3, None]), + type=ir.TensorType(ir.DataType.FLOAT), + ) + + # Input with symbolic shape + input_symbolic = ir.val( + "symbolic", + shape=ir.Shape(["batch", "seq", "hidden"]), + type=ir.TensorType(ir.DataType.FLOAT), + ) + + graph = ir.Graph( + inputs=[input_concrete, input_dynamic, input_symbolic], + outputs=[input_concrete, input_dynamic, input_symbolic], + nodes=[], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = output_fix.OutputFixPass() + result = pass_instance(model) + + # Verify shapes are preserved + outputs = result.model.graph.outputs + self.assertEqual(outputs[0].shape, ir.Shape([2, 3, 4])) + self.assertEqual(outputs[1].shape, ir.Shape([None, 3, None])) + self.assertEqual(outputs[2].shape, ir.Shape(["batch", "seq", "hidden"])) + + def test_loop_subgraph_with_direct_input_output(self): + """Test: Add Identity in Loop node subgraphs.""" + # Create loop body with direct input->output + iter_num = ir.val( + "iter_num", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.INT64) + ) + cond_in = ir.val("cond_in", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL)) + loop_var = ir.val( + "loop_var", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + + # Loop body with direct input->output for both cond_in and loop_var + loop_body = ir.Graph( + inputs=[iter_num, cond_in, loop_var], + outputs=[cond_in, loop_var], # Both directly passed through + nodes=[], + name="loop_body", + ) + + # Create main graph with Loop node + main_input = ir.val( + "main_input", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + max_trip_count = ir.val( + "max_trip_count", + shape=ir.Shape([]), + type=ir.TensorType(ir.DataType.INT64), + ) + condition = ir.val( + "condition", shape=ir.Shape([]), type=ir.TensorType(ir.DataType.BOOL) + ) + + loop_node = ir.Node("", "Loop", inputs=[max_trip_count, condition, main_input]) + loop_node.attributes["body"] = ir.AttrGraph("body", loop_body) + loop_node.outputs[0].name = "loop_output" + loop_node.outputs[0].shape = main_input.shape + loop_node.outputs[0].type = main_input.type + + main_graph = ir.Graph( + inputs=[main_input, max_trip_count, condition], + outputs=[loop_node.outputs[0]], + nodes=[loop_node], + name="main_graph", + ) + + model = ir.Model(main_graph, ir_version=10) + + # Run the pass + pass_instance = output_fix.OutputFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify Identity was added in loop body for direct pass-throughs + loop_node_after = next(iter(result.model.graph)) + loop_body_after = loop_node_after.attributes["body"].value + loop_body_nodes = list(loop_body_after) + + # Should have two Identity nodes (one for cond_in, one for loop_var) + identity_nodes = [n for n in loop_body_nodes if n.op_type == "Identity"] + self.assertEqual(len(identity_nodes), 2) + + # Verify both outputs are now from Identity nodes + self.assertEqual(loop_body_after.outputs[0].producer().op_type, "Identity") + self.assertEqual(loop_body_after.outputs[1].producer().op_type, "Identity") + if __name__ == "__main__": unittest.main() From c4b8c79aa2ed36abcd5c03fe0dd4e5e24c826624 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 15:54:30 -0800 Subject: [PATCH 08/16] Update src/onnx_ir/passes/common/output_fix_test.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/passes/common/output_fix_test.py b/src/onnx_ir/passes/common/output_fix_test.py index e1a6ac80..7b977163 100644 --- a/src/onnx_ir/passes/common/output_fix_test.py +++ b/src/onnx_ir/passes/common/output_fix_test.py @@ -1,6 +1,6 @@ # Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 -"""Tests for the identity fix pass.""" +"""Tests for the output fix pass.""" from __future__ import annotations From ed5baf71a794711674fce9acd583f9f48f130f59 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 15:54:37 -0800 Subject: [PATCH 09/16] Update src/onnx_ir/passes/common/output_fix.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/passes/common/output_fix.py b/src/onnx_ir/passes/common/output_fix.py index 1175f8f7..25d4d26e 100644 --- a/src/onnx_ir/passes/common/output_fix.py +++ b/src/onnx_ir/passes/common/output_fix.py @@ -32,7 +32,7 @@ class OutputFixPass(ir.passes.InPlacePass): """ def call(self, model: ir.Model) -> ir.passes.PassResult: - """Main entry point for the identity fix pass.""" + """Main entry point for the output fix pass.""" modified = False # Process the main graph From 74c56ffc3dafb3bf09e5327c26855f3ad972210d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 15:54:45 -0800 Subject: [PATCH 10/16] Update src/onnx_ir/passes/common/output_fix.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/passes/common/output_fix.py b/src/onnx_ir/passes/common/output_fix.py index 25d4d26e..ac097bd6 100644 --- a/src/onnx_ir/passes/common/output_fix.py +++ b/src/onnx_ir/passes/common/output_fix.py @@ -1,6 +1,6 @@ # Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 -"""Identity fix pass for adding Identity nodes when graph inputs are directly used as outputs.""" +"""Output fix pass for adding Identity nodes when graph inputs are directly used as outputs.""" from __future__ import annotations From 5576b95f8754db96f56c188d021b373aad0c702d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 16:32:49 -0800 Subject: [PATCH 11/16] Update Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix.py | 96 +++++++++++++++----- src/onnx_ir/passes/common/output_fix_test.py | 61 +++++++++++++ 2 files changed, 133 insertions(+), 24 deletions(-) diff --git a/src/onnx_ir/passes/common/output_fix.py b/src/onnx_ir/passes/common/output_fix.py index ac097bd6..6b2845f1 100644 --- a/src/onnx_ir/passes/common/output_fix.py +++ b/src/onnx_ir/passes/common/output_fix.py @@ -8,6 +8,7 @@ "OutputFixPass", ] +import collections import logging import onnx_ir as ir @@ -18,10 +19,15 @@ class OutputFixPass(ir.passes.InPlacePass): """Pass for adding Identity nodes when graph inputs are directly used as outputs. - This pass adds Identity nodes according to the following rule: + This pass adds Identity nodes according to the following rules: If a graph input is directly used as a graph output (without any intermediate nodes), - insert an Identity node between them. This turns an invalid ONNX graph into a valid one. + insert an Identity node between them. + + If a value is used multiple times as graph outputs, insert Identity nodes for each usage + except the first one. + + This turns an invalid ONNX graph into a valid one. Example transformation: Before: input -> (direct connection) -> output @@ -36,49 +42,91 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: modified = False # Process the main graph - if self._process_graph(model.graph): + if _alias_multi_used_outputs(model.graph): + modified = True + if _alias_direct_outputs(model.graph): modified = True # Process functions for function in model.functions.values(): - if self._process_graph(function): + if _alias_multi_used_outputs(function): + modified = True + if _alias_direct_outputs(function): modified = True return ir.passes.PassResult(model, modified=modified) - def _process_graph(self, graph_like: ir.Graph | ir.Function) -> bool: - """Process a single graph or function, returning True if modified.""" - modified = False - for graph in (graph_like, *graph_like.subgraphs()): - # Check each output to see if it's directly a graph input - outputs_to_fix: list[tuple[ir.Value, int]] = [] - for i, output in enumerate(graph.outputs): - if output.is_graph_input(): - outputs_to_fix.append((output, i)) +def _alias_multi_used_outputs(graph_like: ir.Graph | ir.Function) -> bool: + """Insert Identity nodes for graph outputs that are used multiple times.""" + modified = False - # Add Identity nodes for each output that needs fixing - for output, index in outputs_to_fix: + for graph in (graph_like, *graph_like.subgraphs()): + # Count usage of each output + output_usages = collections.Counter(graph.outputs) + seen: set[ir.Value] = set() + + # Add Identity nodes for outputs used multiple times + for i, output in enumerate(graph.outputs): + if output not in seen: + # Skip the first occurrence + seen.add(output) + continue + if output_usages[output] > 1: # Create an Identity node identity_node = ir.node("Identity", inputs=[output]) identity_output = identity_node.outputs[0] # Copy metadata from the original output - identity_output.name = output.name + identity_output.name = f"{output.name}_alias_{i}" identity_output.shape = output.shape identity_output.type = output.type identity_output.metadata_props.update(output.metadata_props) identity_output.doc_string = output.doc_string - # Create a unique name for the old output to avoid name conflicts - # TODO: Use a better unique naming strategy if needed - output.name = f"{output.name}_orig" - # Add the node to the graph graph.append(identity_node) - graph.outputs[index] = identity_output - - logger.debug("Added Identity node for graph input '%s' used as output", output) + graph.outputs[i] = identity_output + logger.debug( + "Added Identity node for graph output '%s' used multiple times", output + ) modified = True + return modified + + +def _alias_direct_outputs(graph_like: ir.Graph | ir.Function) -> bool: + """Insert Identity nodes for graph inputs used directly as outputs.""" + modified = False + + for graph in (graph_like, *graph_like.subgraphs()): + # Check each output to see if it's directly a graph input + outputs_to_fix: list[tuple[ir.Value, int]] = [] + for i, output in enumerate(graph.outputs): + if output.is_graph_input(): + outputs_to_fix.append((output, i)) + + # Add Identity nodes for each output that needs fixing + for output, index in outputs_to_fix: + # Create an Identity node + identity_node = ir.node("Identity", inputs=[output]) + identity_output = identity_node.outputs[0] + + # Copy metadata from the original output + identity_output.name = output.name + identity_output.shape = output.shape + identity_output.type = output.type + identity_output.metadata_props.update(output.metadata_props) + identity_output.doc_string = output.doc_string + + # Create a unique name for the old output to avoid name conflicts + # TODO: Use a better unique naming strategy if needed + output.name = f"{output.name}_orig" + + # Add the node to the graph + graph.append(identity_node) + graph.outputs[index] = identity_output + + logger.debug("Added Identity node for graph input '%s' used as output", output) + modified = True - return modified + return modified diff --git a/src/onnx_ir/passes/common/output_fix_test.py b/src/onnx_ir/passes/common/output_fix_test.py index 7b977163..808c6ff1 100644 --- a/src/onnx_ir/passes/common/output_fix_test.py +++ b/src/onnx_ir/passes/common/output_fix_test.py @@ -459,6 +459,67 @@ def test_same_input_used_multiple_times_as_output(self): self.assertEqual(len(result.model.graph.outputs), 2) self.assertIsNot(result.model.graph.outputs[0], result.model.graph.outputs[1]) + # Verify that output names are unique + # First occurrence keeps original name, second gets _alias suffix + output_names = [output.name for output in result.model.graph.outputs] + self.assertEqual(output_names[0], "input") + self.assertEqual(output_names[1], "input_alias_1") + # Ensure names are unique + self.assertEqual(len(set(output_names)), 2) + + def test_mixed_multiple_usage_and_single_usage(self): + """Test: Mix of inputs used once and multiple times as outputs.""" + input1 = ir.val( + "input1", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT) + ) + input2 = ir.val( + "input2", shape=ir.Shape([3, 3]), type=ir.TensorType(ir.DataType.INT32) + ) + input3 = ir.val( + "input3", shape=ir.Shape([4, 4]), type=ir.TensorType(ir.DataType.INT64) + ) + + # Create outputs: input1 (once), input2 (three times), input3 (twice) + graph = ir.Graph( + inputs=[input1, input2, input3], + outputs=[input1, input2, input3, input2, input3, input2], + nodes=[], + name="test_graph", + ) + + model = ir.Model(graph, ir_version=10) + + # Run the pass + pass_instance = output_fix.OutputFixPass() + result = pass_instance(model) + + # Verify the pass was applied + self.assertTrue(result.modified) + + # Verify Identity nodes were added + # All outputs are graph inputs, so _alias_direct_outputs adds Identity for each + # Additionally, multi-used outputs get extra Identity nodes from _alias_multi_used_outputs + # input1: 1 Identity (direct output) + # input2: 1 (direct) + 2 (for 2nd & 3rd occurrence) = 3 Identity nodes + # input3: 1 (direct) + 1 (for 2nd occurrence) = 2 Identity nodes + # Total: 1 + 3 + 2 = 6 Identity nodes + nodes = list(result.model.graph) + identity_nodes = [n for n in nodes if n.op_type == "Identity"] + self.assertEqual(len(identity_nodes), 6) + + # Verify output names + output_names = [output.name for output in result.model.graph.outputs] + # First occurrences keep original names: input1, input2, input3 + # Subsequent occurrences of multi-used values get _alias suffixes + self.assertEqual(output_names[0], "input1") # input1 first occurrence + self.assertEqual(output_names[1], "input2") # input2 first occurrence + self.assertEqual(output_names[2], "input3") # input3 first occurrence + self.assertEqual(output_names[3], "input2_alias_3") # input2 second + self.assertEqual(output_names[4], "input3_alias_4") # input3 second + self.assertEqual(output_names[5], "input2_alias_5") # input2 third + # Verify all names are unique + self.assertEqual(len(set(output_names)), 6) + def test_nested_subgraphs(self): """Test: Handle nested subgraphs (subgraph within subgraph).""" # Create innermost graph with direct input->output From da8c261b7f6a31b59480b937a73c8c001911ac2f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 16:35:57 -0800 Subject: [PATCH 12/16] docs Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix.py | 31 ++++++++++++++----------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/onnx_ir/passes/common/output_fix.py b/src/onnx_ir/passes/common/output_fix.py index 6b2845f1..4cee43ee 100644 --- a/src/onnx_ir/passes/common/output_fix.py +++ b/src/onnx_ir/passes/common/output_fix.py @@ -17,24 +17,27 @@ class OutputFixPass(ir.passes.InPlacePass): - """Pass for adding Identity nodes when graph inputs are directly used as outputs. + """Pass for adding Identity nodes to fix invalid output configurations. This pass adds Identity nodes according to the following rules: - If a graph input is directly used as a graph output (without any intermediate nodes), - insert an Identity node between them. + - If a graph input is directly used as a graph output (without any intermediate nodes), + insert an Identity node between them. The ONNX specification does not allow a graph + input to be directly used as a graph output without any processing nodes in between. + - If a value is used multiple times as graph outputs, insert Identity nodes for each + duplicate usage (keeping the first usage unchanged). This ensures each output value + is unique, as required by the ONNX specification. - If a value is used multiple times as graph outputs, insert Identity nodes for each usage - except the first one. + This pass processes both the main graph and all subgraphs (e.g., in control flow operators). - This turns an invalid ONNX graph into a valid one. + Example transformations: + Direct input-to-output: + Before: input -> (direct connection) -> output + After: input -> Identity -> output - Example transformation: - Before: input -> (direct connection) -> output - After: input -> Identity -> output - - This is required because ONNX specification does not allow a graph input to be - directly used as a graph output without any processing nodes in between. + Duplicate outputs: + Before: value -> [output1, output2] + After: value -> output1, value -> Identity -> output2 """ def call(self, model: ir.Model) -> ir.passes.PassResult: @@ -78,6 +81,7 @@ def _alias_multi_used_outputs(graph_like: ir.Graph | ir.Function) -> bool: identity_output = identity_node.outputs[0] # Copy metadata from the original output + # TODO: Use a better unique naming strategy if needed identity_output.name = f"{output.name}_alias_{i}" identity_output.shape = output.shape identity_output.type = output.type @@ -112,13 +116,14 @@ def _alias_direct_outputs(graph_like: ir.Graph | ir.Function) -> bool: identity_output = identity_node.outputs[0] # Copy metadata from the original output + # Preserve the original output name identity_output.name = output.name identity_output.shape = output.shape identity_output.type = output.type identity_output.metadata_props.update(output.metadata_props) identity_output.doc_string = output.doc_string - # Create a unique name for the old output to avoid name conflicts + # Create a new name for the old output # TODO: Use a better unique naming strategy if needed output.name = f"{output.name}_orig" From fe5a5744a6c2bd51e3ecb60710ed0f22162d257b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 16:39:21 -0800 Subject: [PATCH 13/16] Apply suggestion from @justinchuby Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index 78f819be..884cf0ef 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -44,9 +44,7 @@ from onnx_ir.passes.common.inliner import InlinePass from onnx_ir.passes.common.naming import NameFixPass from onnx_ir.passes.common.onnx_checker import CheckerPass -from onnx_ir.passes.common.output_fix import ( - OutputFixPass, -) +from onnx_ir.passes.common.output_fix import OutputFixPass from onnx_ir.passes.common.shape_inference import ShapeInferencePass from onnx_ir.passes.common.topological_sort import TopologicalSortPass from onnx_ir.passes.common.unused_removal import ( From bc9e6507c6cf78a4202731240d69ffded1fa72ed Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 16:42:19 -0800 Subject: [PATCH 14/16] docs Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/onnx_ir/passes/common/output_fix.py b/src/onnx_ir/passes/common/output_fix.py index 4cee43ee..d4431090 100644 --- a/src/onnx_ir/passes/common/output_fix.py +++ b/src/onnx_ir/passes/common/output_fix.py @@ -61,7 +61,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: def _alias_multi_used_outputs(graph_like: ir.Graph | ir.Function) -> bool: - """Insert Identity nodes for graph outputs that are used multiple times.""" + """Insert Identity nodes for values that appear in the graph output list multiple times.""" modified = False for graph in (graph_like, *graph_like.subgraphs()): From 791cdec456910fdd33b6eadb82d5a0e369bceb84 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 12 Dec 2025 16:44:02 -0800 Subject: [PATCH 15/16] lint Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index 884cf0ef..44554802 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -34,9 +34,7 @@ LiftSubgraphInitializersToMainGraphPass, RemoveInitializersFromInputsPass, ) -from onnx_ir.passes.common.identity_elimination import ( - IdentityEliminationPass, -) +from onnx_ir.passes.common.identity_elimination import IdentityEliminationPass from onnx_ir.passes.common.initializer_deduplication import ( DeduplicateHashedInitializersPass, DeduplicateInitializersPass, From 9e6ab238d28ea9e242db6d4b7b34fbb5a97d11d6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 16 Dec 2025 17:05:47 -0800 Subject: [PATCH 16/16] Update Signed-off-by: Justin Chu --- src/onnx_ir/passes/common/output_fix.py | 50 +++++++++++++------------ 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/onnx_ir/passes/common/output_fix.py b/src/onnx_ir/passes/common/output_fix.py index d4431090..21b698c7 100644 --- a/src/onnx_ir/passes/common/output_fix.py +++ b/src/onnx_ir/passes/common/output_fix.py @@ -1,6 +1,12 @@ # Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 -"""Output fix pass for adding Identity nodes when graph inputs are directly used as outputs.""" +"""Output fix pass for adding Identity nodes. + +- Graph inputs are directly used as outputs (without any intermediate nodes). +- A value is used multiple times as a graph output (ensuring each output is unique). + +This ensures compliance with the ONNX specification for valid output configurations. +""" from __future__ import annotations @@ -8,7 +14,6 @@ "OutputFixPass", ] -import collections import logging import onnx_ir as ir @@ -66,7 +71,6 @@ def _alias_multi_used_outputs(graph_like: ir.Graph | ir.Function) -> bool: for graph in (graph_like, *graph_like.subgraphs()): # Count usage of each output - output_usages = collections.Counter(graph.outputs) seen: set[ir.Value] = set() # Add Identity nodes for outputs used multiple times @@ -75,26 +79,26 @@ def _alias_multi_used_outputs(graph_like: ir.Graph | ir.Function) -> bool: # Skip the first occurrence seen.add(output) continue - if output_usages[output] > 1: - # Create an Identity node - identity_node = ir.node("Identity", inputs=[output]) - identity_output = identity_node.outputs[0] - - # Copy metadata from the original output - # TODO: Use a better unique naming strategy if needed - identity_output.name = f"{output.name}_alias_{i}" - identity_output.shape = output.shape - identity_output.type = output.type - identity_output.metadata_props.update(output.metadata_props) - identity_output.doc_string = output.doc_string - - # Add the node to the graph - graph.append(identity_node) - graph.outputs[i] = identity_output - logger.debug( - "Added Identity node for graph output '%s' used multiple times", output - ) - modified = True + + # Create an Identity node + identity_node = ir.node("Identity", inputs=[output]) + identity_output = identity_node.outputs[0] + + # Copy metadata from the original output + # TODO: Use a better unique naming strategy if needed + identity_output.name = f"{output.name}_alias_{i}" + identity_output.shape = output.shape + identity_output.type = output.type + identity_output.metadata_props.update(output.metadata_props) + identity_output.doc_string = output.doc_string + + # Add the node to the graph + graph.append(identity_node) + graph.outputs[i] = identity_output + logger.debug( + "Added Identity node for graph output '%s' used multiple times", output + ) + modified = True return modified