From 0ec6ed6b169f4ce2d920c090c99e69e8765e2192 Mon Sep 17 00:00:00 2001
From: Dayoung Lee
Date: Fri, 1 Aug 2025 17:01:00 +0900
Subject: [PATCH] DRAFT: [serialize] Enable dynamic shape in copy operator

Let the copy (aten.copy_) serializer handle a dst tensor with a dynamic
shape. When src must be broadcast to dst, BroadcastTo now writes
directly to the node output instead of feeding a trailing Reshape, and
dynamic dims from the shape signature are encoded as -1 in the
BroadcastTo shape operand so the runtime can resolve them.
---
 example_dynamic_onnx.py             | 30 ++++++++++
 test/modules/op/copy.py             | 22 ++++++++
 test/pt2_to_circle_test/builder.py  |  1 +
 tico/serialize/circle_serializer.py |  3 +-
 tico/serialize/operators/op_copy.py | 88 ++++++++++++++---------------
 tico/utils/convert.py               | 10 +++-
 tico/utils/validate_args_kwargs.py  |  6 ++
 7 files changed, 111 insertions(+), 49 deletions(-)
 create mode 100644 example_dynamic_onnx.py

diff --git a/example_dynamic_onnx.py b/example_dynamic_onnx.py
new file mode 100644
index 00000000..0fcb5750
--- /dev/null
+++ b/example_dynamic_onnx.py
@@ -0,0 +1,30 @@
+import torch
+from torch.export import Dim
+
+
+class SimpleCopyWithBroadcastToDynamicShape(torch.nn.Module):
+    def forward(self, dst, src):
+        dst.copy_(src)
+        return dst
+
+    def get_example_inputs(self):
+        return (torch.randn(5, 5), torch.randn(1, 5)), {}
+
+    def get_dynamic_shapes(self):
+        # Only dim 0 of `dst` is dynamic; `src` stays (1, 5) so copy_ broadcasts.
+        dim = Dim("dim", min=1, max=128)
+        return {"dst": {0: dim}, "src": {}}
+
+
+model = SimpleCopyWithBroadcastToDynamicShape()
+example_args, example_kwargs = model.get_example_inputs()
+
+# Export with the dynamic dim taken from the module's own spec;
+# print(ep) shows the symbolic shapes (e.g. s0) in the exported graph.
+ep = torch.export.export(
+    model,
+    args=example_args,
+    kwargs=example_kwargs,
+    dynamic_shapes=model.get_dynamic_shapes(),
+)
+print(ep)
diff --git a/test/modules/op/copy.py b/test/modules/op/copy.py
index f773636c..c02568b6 100644
--- a/test/modules/op/copy.py
+++ b/test/modules/op/copy.py
@@ -15,6 +15,8 @@
 import torch
 
 from test.modules.base import TestModuleBase
+from torch.export import Dim
+from test.utils.tag import use_onert
 
 
 class SimpleCopy(TestModuleBase):
@@ -39,3 +41,23 @@ def forward(self, dst, src):
 
     def get_example_inputs(self):
         return (torch.randn(5, 5), torch.randn(1, 5)), {}
+
+@use_onert
+class SimpleCopyWithBroadcastToDynamicShape(TestModuleBase):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, dst, src):
+        dst.copy_(src)
+        return dst
+
+    def get_example_inputs(self):
+        return (torch.randn(5, 5), torch.randn(1, 5)), {}
+
+    def get_dynamic_shapes(self):
+        dim = Dim("dim", min=1, max=128)
+        dynamic_shapes = {
+            "dst": {0: dim},
+            "src": {},
+        }
+        return dynamic_shapes
diff --git a/test/pt2_to_circle_test/builder.py b/test/pt2_to_circle_test/builder.py
index b52c3ad6..44fff3cc 100644
--- a/test/pt2_to_circle_test/builder.py
+++ b/test/pt2_to_circle_test/builder.py
@@ -104,6 +104,7 @@ def _run(
             assert (
                 self.use_onert
             ), "Dynamic shapes are only supported with onert runtime. Please set 'use_onert' to True."
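+            # NOTE: get_dynamic_shapes() returns a torch.export-style spec,
+            # e.g. {"dst": {0: Dim("dim", min=1, max=128)}, "src": {}} for
+            # SimpleCopyWithBroadcastToDynamicShape in test/modules/op/copy.py.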
+            dynamic_shapes = self.nnmodule.get_dynamic_shapes()
 
         compile_config: Optional[CompileConfigBase] = None
         if hasattr(self.nnmodule, "get_compile_config"):
diff --git a/tico/serialize/circle_serializer.py b/tico/serialize/circle_serializer.py
index 5dd697dc..a89afe11 100644
--- a/tico/serialize/circle_serializer.py
+++ b/tico/serialize/circle_serializer.py
@@ -84,7 +84,7 @@ def build_circle(ep: ExportedProgram) -> bytes:
         graph.add_output(user_output)
         logger.debug(f"Registered output: {user_output}")
 
-    
+
     # Export operators
     logger.debug("---------------Export operators--------------")
     op_codes: Dict[OpCode, int] = {}
@@ -146,6 +146,7 @@ def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None:
                 raise RuntimeError(
                     f"Only support dense tensors (node layout: {node_val.layout})"
                 )
+
             graph.add_tensor_from_node(node)
             logger.debug(f"call_function: {node.name} tensor exported.")
diff --git a/tico/serialize/operators/op_copy.py b/tico/serialize/operators/op_copy.py
index 309f6363..cbbacc04 100644
--- a/tico/serialize/operators/op_copy.py
+++ b/tico/serialize/operators/op_copy.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Dict, List, Optional, TYPE_CHECKING, Union
+from copy import deepcopy
 
 if TYPE_CHECKING:
     import torch._ops
     import torch.fx
@@ -59,9 +59,8 @@ def check_to_do_broadcast(
         src: List[int],
         src_sig: Optional[List[int]],
     ) -> bool:
-        assert dst_sig is None
-        assert src_sig is None
-        return dst != src
+        # Broadcast unless both the static shapes and the shape signatures
+        # (dynamic-dim markers) match exactly.
+        exactly_same = (dst_sig == src_sig) and (dst == src)
+        return not exactly_same
 
     def define_broadcast_to_node(
         self,
@@ -98,54 +97,43 @@ def define_node(
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-        if len(node.args) == 3:
-            raise NotYetSupportedError("'non_blocking' is not supported yet.")
-
-        assert len(node.args) == 2, len(node.args)
-
-        args = CopyArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        # CopyArgs now accepts the optional non_blocking arg and rejects
+        # non_blocking=True in its __post_init__.
+        args = CopyArgs(*node.args, **node.kwargs)
         dst = args.dst
         src = args.src
 
-        # To connect 'dst' to Reshape node in the graph, 'dst' must be converted to Shape op.
-        dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst)
-        dst_shape: List[int] = dst_tensor.shape
-        dst_shape_signature: Optional[List[int]] = dst_tensor.shapeSignature
-
-        if dst_shape_signature is not None:
-            # TODO: support dynamic shape
-            raise NotYetSupportedError("Dynamic shape is not supported yet.")
-
-        dst_shape_tensor = torch.as_tensor(dst_shape, dtype=torch.int32)
-
-        dst_shape_shape = [len(dst_shape)]
-        dst_name: str = dst.name
-
-        shape_output = self.graph.add_tensor_from_scratch(
-            prefix=f"{dst_name}_shape_output",
-            shape=dst_shape_shape,
-            shape_signature=None,
-            dtype=circle.TensorType.TensorType.INT32,
-            source_node=node,
-        )
-
-        shape_operator = self.define_shape_node([dst], [shape_output])
-        self.graph.add_operator(shape_operator)
-
         src_tensor: circle.Tensor.TensorT = self.graph.get_tensor(src)
         src_shape: List[int] = src_tensor.shape
         src_shape_signature: Optional[List[int]] = src_tensor.shapeSignature
 
-        if src_shape_signature is not None:
-            # TODO: support dynamic shape
-            raise NotYetSupportedError("Dynamic shape is not supported yet.")
-
+        dst_tensor: circle.Tensor.TensorT = self.graph.get_tensor(dst)
+        dst_shape: List[int] = dst_tensor.shape
+        dst_shape_signature: Optional[List[int]] = dst_tensor.shapeSignature
+
+        # The src tensor must be broadcastable with the dst tensor.
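+        # e.g. dst shape [5, 5] with signature [-1, 5] vs. src shape [1, 5]
+        # with signature None: not exactly the same, so BroadcastTo is emitted.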
        do_broadcast = self.check_to_do_broadcast(
            dst_shape, dst_shape_signature, src_shape, src_shape_signature
        )
-        if do_broadcast:
-            # create braodcastTo output tensor
+
+        if not do_broadcast:
+            # To connect 'dst' to the Reshape node below, 'dst' must be
+            # converted to a Shape op; Shape yields the runtime shape, so this
+            # path also works when dst has dynamic dims.
+            dst_shape_shape = [len(dst_shape)]
+            dst_name: str = dst.name
+
+            shape_output = self.graph.add_tensor_from_scratch(
+                prefix=f"{dst_name}_shape_output",
+                shape=dst_shape_shape,
+                shape_signature=None,
+                dtype=circle.TensorType.TensorType.INT32,
+                source_node=node,
+            )
+
+            shape_operator = self.define_shape_node([dst], [shape_output])
+            self.graph.add_operator(shape_operator)
+            inputs = [src, shape_output]
+        else:
+            # TODO: the broadcast_to_output tensor created below is no longer
+            # consumed (BroadcastTo now writes straight to the node output);
+            # drop it in a follow-up.
             src_name: str = src.name
             src_type: int = src_tensor.type
 
@@ -159,15 +147,21 @@ def define_node(
                )
            )
 
+            # Keep -1 (from the shape signature) for dynamic dims so the
+            # runtime resolves them in the BroadcastTo shape operand.
+            dst_shape_merged = deepcopy(dst_shape)
+            if dst_shape_signature is not None:
+                for idx, sig in enumerate(dst_shape_signature):
+                    if sig == -1:
+                        dst_shape_merged[idx] = -1
+
+            dst_shape_tensor = torch.as_tensor(dst_shape_merged, dtype=torch.int32)
+            # BroadcastTo already materializes src in dst's (possibly dynamic)
+            # shape, so it maps straight onto the node output and no trailing
+            # Reshape is needed; as with the Reshape below, the returned
+            # operator is registered by the caller.
             broadcast_to_operator: circle.Operator.OperatorT = (
                 self.define_broadcast_to_node(
-                    [src_tensor, dst_shape_tensor], [broadcast_to_output]
+                    [src_tensor, dst_shape_tensor], [node]
                 )
             )
-            self.graph.add_operator(broadcast_to_operator)
-            inputs: List = [broadcast_to_output, shape_output]
-        else:
-            inputs = [src, shape_output]
+            return broadcast_to_operator
 
         outputs = [node]
         op_index = get_op_index(
diff --git a/tico/utils/convert.py b/tico/utils/convert.py
index ec039803..9926b9fc 100644
--- a/tico/utils/convert.py
+++ b/tico/utils/convert.py
@@ -197,6 +197,8 @@ def convert_exported_module_to_circle(
     logger.debug("Input ExportedProgram (must be core aten)")
     logger.debug(exported_program)
 
+    nodes = list(exported_program.graph.nodes)
+    logger.debug(f"Input placeholder val (expect symbolic dims): {nodes[0].meta['val']}")
     # PRE-EDGE PASSES
     #
     # Here are the passes that run before to_edge() conversion.
@@ -221,6 +223,8 @@
     # CompositeImplicitAutograd and have functional schema are safe to not decompose.
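+    # traced_run_decompositions should preserve symbolic dims (e.g. s0) in
+    # placeholder meta; the debug line below makes that visible while the
+    # dynamic-shape path is under development.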
    exported_program = traced_run_decompositions(exported_program)
+    nodes = list(exported_program.graph.nodes)
+    logger.debug(f"Placeholder val after decompositions: {nodes[0].meta['val']}")
     # TODO Distinguish legalize and optimize
     circle_legalize = PassManager(
         passes=[
@@ -259,7 +263,9 @@
         ]
     )
     circle_legalize.run(exported_program)
 
+    nodes = list(exported_program.graph.nodes)
+    logger.debug(f"Placeholder val after legalize passes: {nodes[0].meta['val']}")
     # After this stage, ExportedProgram invariant is broken, i.e.,
     # graph can have a constant torch.tensor not lifted to a placeholder
     circle_legalize = PassManager(
@@ -270,6 +276,8 @@
     )
     circle_legalize.run(exported_program)
 
+    nodes = list(exported_program.graph.nodes)
+    logger.debug(f"Placeholder val before quantization check: {nodes[0].meta['val']}")
     # TODO Give an option to enable quantiztion to user
     enable_quantization = has_quantization_ops(exported_program.graph)
     if enable_quantization:
diff --git a/tico/utils/validate_args_kwargs.py b/tico/utils/validate_args_kwargs.py
index 03f5e5b4..75e8086e 100644
--- a/tico/utils/validate_args_kwargs.py
+++ b/tico/utils/validate_args_kwargs.py
@@ -296,6 +296,12 @@ class CopyArgs:
 
     dst: torch.fx.Node
     src: torch.fx.Node
+    non_blocking: bool = False
+
+    def __post_init__(self):
+        # aten.copy_ may carry a third positional arg; accept it, reject True.
+        if self.non_blocking:
+            raise NotImplementedError("non_blocking option is not supported yet.")
 
 
 @enforce_type