From 6375f70a0ace8596e3e5a99c9745304d4a653821 Mon Sep 17 00:00:00 2001 From: noli Date: Tue, 27 Aug 2024 16:22:23 +0000 Subject: [PATCH 1/6] add ark tensor gradient tracking/updates --- python/ark/module.py | 21 +++++--------------- python/ark/tensor.py | 46 +++++++++++++++++++++++++++++--------------- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/python/ark/module.py b/python/ark/module.py index b5744f10..c33623a1 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -4,12 +4,11 @@ import logging import numpy as np from typing import Any, Dict, Union -from .tensor import Parameter +from .tensor import Tensor, Parameter from .torch import torch, _no_torch from .runtime import Runtime from .model import Model from .data_type import DataType -from .ops import placeholder class Module: @@ -36,10 +35,7 @@ def __setattr__(self, __name: str, __value: Any) -> None: elif isinstance(__value, Parameter): self.register_parameter(__name, __value) elif not _no_torch and isinstance(__value, torch.nn.Parameter): - shape, dtype = list(__value.shape), DataType.from_torch( - __value.dtype - ) - __value = Parameter(placeholder(shape, dtype, data=__value), True) + __value = Parameter(__value) self.register_parameter(__name, __value) super().__setattr__(__name, __value) @@ -147,16 +143,14 @@ def forward(ctx, ark_module, *args, **kwargs): input_requires_grad = 0 for arg in args: if isinstance(arg, torch.Tensor): - shape, dtype = list(arg.shape), DataType.from_torch(arg.dtype) - input_args.append(placeholder(shape, dtype, data=arg)) + input_args.append(Tensor.from_torch(arg)) if arg.requires_grad: input_requires_grad += 1 else: input_args.append(arg) for k, v in kwargs.items(): if isinstance(v, torch.Tensor): - shape, dtype = list(arg.shape), DataType.from_torch(arg.dtype) - input_kwargs[k] = placeholder(shape, dtype, data=v) + input_kwargs[k] = Tensor.from_torch(v) if v.requires_grad: input_requires_grad += 1 else: @@ -178,12 +172,7 @@ def backward(ctx, *grad_outputs): PyTorch parameters. """ Model.reset() - # i think we should support placeholder initialization - # with just pytorch tensor - ark_grad_outputs = [] - for grad in grad_outputs: - shape, dtype = list(grad.shape), DataType.from_torch(grad.dtype) - ark_grad_outputs.append(placeholder(shape, dtype, data=grad)) + ark_grad_outputs = [Tensor.from_torch(grad) for grad in grad_outputs] grads = ctx.ark_module.backward(*ark_grad_outputs) grad_inputs, grad_weights = ( grads[: ctx.num_inp_grad], diff --git a/python/ark/tensor.py b/python/ark/tensor.py index a09b0af6..7b503cf3 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. 
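The tensor.py hunk below registers every tensor that requires gradients in a class-level dict keyed by the tensor's id, so gradient state can be looked up later. A minimal standalone sketch of that bookkeeping, assuming nothing about the ARK API (TrackedTensor and _grad_registry are illustrative names):

    from typing import Dict

    class TrackedTensor:
        # Class-level map from tensor id to tensor, mirroring Tensor._tensor_grads.
        _grad_registry: Dict[int, "TrackedTensor"] = {}
        _next_id: int = 0

        def __init__(self, requires_grad: bool = False):
            self._id = TrackedTensor._next_id
            TrackedTensor._next_id += 1
            self.requires_grad_(requires_grad)

        def requires_grad_(self, requires_grad: bool = True) -> "TrackedTensor":
            # In-place toggle: register when tracking starts, drop the entry when it stops.
            self.requires_grad = requires_grad
            if requires_grad:
                TrackedTensor._grad_registry[self._id] = self
            else:
                TrackedTensor._grad_registry.pop(self._id, None)
            return self

    t = TrackedTensor(requires_grad=True)
    assert t._id in TrackedTensor._grad_registry
    t.requires_grad_(False)
    assert t._id not in TrackedTensor._grad_registry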
import numpy as np -from typing import Callable, Iterable, List, Union, Type +from typing import Callable, Iterable, List, Union, Type, Dict from ._ark_core import _Dims, _Tensor, _NullTensor from .torch import torch, _no_torch @@ -22,6 +22,9 @@ class Dims(_Dims): class Tensor: + + _tensor_grads: Dict[int, "Tensor"] = {} + def __init__( self, _tensor: _Tensor, @@ -38,6 +41,8 @@ def __init__( self._tensor = _tensor self.initializer: Initializer = initializer self.requires_grad = requires_grad + if self.requires_grad: + Tensor._tensor_grads[self._tensor.id()] = self def __hash__(self): return self._tensor.id() @@ -186,6 +191,8 @@ def to_torch(self) -> torch.Tensor: torch_view = torch.utils.dlpack.from_dlpack(dl_capsule) # Keep dl_capsule alive not to free the memory torch_view.__ark_buffer__ = dl_capsule + if self.requires_grad: + torch_view.requires_grad_(True) return torch_view @staticmethod @@ -205,7 +212,8 @@ def from_torch(tensor: torch.Tensor) -> "Tensor": shape=list(tensor.shape), dtype=DataType.from_torch(tensor.dtype), data=tensor.data_ptr(), - ) + ), + requires_grad=tensor.requires_grad ) # Share ownership of the memory with the torch tensor ark_tensor.__torch_buffer__ = tensor @@ -259,37 +267,43 @@ def initialize(self) -> "Tensor": self.copy(data) return self + def requires_grad_(self, requires_grad: bool = True) -> "Tensor": + """ + Sets the `requires_grad` attribute in-place for the tensor. + If `requires_grad` is True, the tensor will be tracked for gradient + updates. + """ + self.requires_grad = requires_grad + if requires_grad: + Tensor._tensor_grads[self._tensor.id()] = self + elif self._tensor.id() in Tensor._tensor_grads: + del Tensor._tensor_grads[self._tensor.id()] + return self + -class Parameter(Tensor): +class Parameter(Tensor, torch.nn.Parameter): """ A tensor as a parameter. """ def __init__( self, - tensor: _Tensor, - from_torch: bool, + tensor: Union[_Tensor, "torch.nn.Parameter"], ): """ Initializes a new instance of the Parameter class. - Args: - _tensor (_ark_core._Tensor): The underlying _Tensor object. 
- from_torch: Indicates if the Parameter is tied to a torch.nn.Paramter """ - if not _no_torch and from_torch: - _tensor = tensor._tensor + if not _no_torch and isinstance(tensor, torch.nn.Parameter): + ark_tensor = Tensor.from_torch(tensor) + self._tensor = ark_tensor._tensor + ark_tensor.requires_grad_(True) self.torch_param = tensor self.staged_tensor = None - Tensor.__init__( - self, - _tensor, - requires_grad=tensor.requires_grad, - ) elif isinstance(tensor, _Tensor): _tensor = tensor self.torch_param = None self.staged_tensor = None - Tensor.__init__(self, _tensor, requires_grad=False) + Tensor.__init__(self, _tensor, requires_grad=True) else: raise TypeError( "tensor must be an ARK tensor or a torch.nn.Parameter" From 2aa11550fdb90538e347fd0a1eb35885a1a0071d Mon Sep 17 00:00:00 2001 From: noli Date: Tue, 27 Aug 2024 23:45:27 +0000 Subject: [PATCH 2/6] small fix for aot autograd input handling --- python/ark/torch/tracer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ark/torch/tracer.py b/python/ark/torch/tracer.py index 9570fd97..65e3fded 100644 --- a/python/ark/torch/tracer.py +++ b/python/ark/torch/tracer.py @@ -247,7 +247,7 @@ def call(*args, **kwargs): return cls def autograd_trace_( - self, gm: torch.nn.Module, _: List[torch.Tensor] + self, gm: torch.nn.Module, forward_inputs: List[torch.Tensor] ) -> Callable: for _, param in gm.named_parameters(remove_duplicate=False): self.params.append(param) @@ -264,7 +264,7 @@ def bw_compiler(gm: torch.fx.GraphModule, _): return torch._dynamo.backends.common.aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler - )(gm, _) + )(gm, forward_inputs) def autograd_trace_impl_( self, gm: torch.fx.GraphModule, _: List[torch.Tensor], is_fw: bool From ce834495e030ae2b427fa34b8df5e2ae6fa2a837 Mon Sep 17 00:00:00 2001 From: noli Date: Sat, 31 Aug 2024 16:27:59 +0000 Subject: [PATCH 3/6] construct pytorch module to use as fallback forward/backward from recorded ops during tracing --- python/ark/torch/tracer.py | 40 ++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/python/ark/torch/tracer.py b/python/ark/torch/tracer.py index 42a1f1f8..d527b0a5 100644 --- a/python/ark/torch/tracer.py +++ b/python/ark/torch/tracer.py @@ -3,11 +3,12 @@ try: import torch + from torch import fx except ImportError: raise ImportError("torch is required to use this module") import logging -from typing import List, Dict, Optional, Callable, Any +from typing import List, Dict, Optional, Callable, Union, Any from ..planner import Planner, Plan from ..tensor import Tensor @@ -204,6 +205,10 @@ def __init__(self): self.failed: bool = False self.launched_fw: bool = False self.launched_bw: bool = False + self.execution_backend: List[Union[torch.nn.Module, Plan]] = [] + self.forward_ops = [] + self.backward_ops = [] + self.torch_module = None def __call__(self, target: Callable) -> Callable: is_module = issubclass(target, torch.nn.Module) @@ -293,8 +298,9 @@ def autograd_trace_impl_( def run(args) -> Any: Model.reset() if not self.failed: + op_list = self.forward_ops if is_fw else self.backward_ops for node in gm.graph.nodes: - logging.info(node.format_node(), node.meta) + op_list.append(node) if not self.handle_node_(node, is_fw): print(f"Failed to handle node {node.format_node()}") self.failed = True @@ -305,8 +311,13 @@ def run(args) -> Any: self.plan_fw = Planner(self.device.index).plan() else: self.plan_bw = Planner(self.device.index).plan() + if is_fw: + t1 = 
self.construct_torch_module(gm, self.forward_ops) + print("FORWARD TORCH: ", t1.code) + else: + t2 = self.construct_torch_module(gm, self.backward_ops) + print("BKACWARD TOCH: ", t2.code) return torch.fx.Interpreter(gm).boxed_run(args) - run._boxed_call = True return run @@ -374,7 +385,28 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool: else: raise ValueError(f"Unexpected node {node.format_node()}") return True - + + def construct_torch_module(self, original_gm, op_seq): + graph = fx.Graph() + env = {} + def create_node(node): + if node.op == 'placeholder': + return graph.placeholder(node.target, type_expr=node.type) + elif node.op == 'get_attr': + return graph.get_attr(node.target) + elif node.op == 'call_function': + args = tuple(env[arg.name] if isinstance(arg, fx.Node) else arg for arg in node.args) + kwargs = {k: env[v.name] if isinstance(v, fx.Node) else v for k, v in node.kwargs.items()} + return graph.call_function(node.target, args, kwargs) + elif node.op == 'output': + args = tuple(env[arg.name] if isinstance(arg, fx.Node) else arg for arg in node.args[0]) + return graph.output(tuple(args)) + else: + raise ValueError(f"Unsupported node operation: {node.op}") + for node in op_seq: + env[node.name] = create_node(node) + torch_module = fx.GraphModule(original_gm, graph) + return torch_module def tracer(target: Callable): return Tracer()(target) From 53e5ef4c4f16c3f9185e2d85ff293b0422528e2e Mon Sep 17 00:00:00 2001 From: noli Date: Sat, 31 Aug 2024 18:53:54 +0000 Subject: [PATCH 4/6] wip adds graph partitioning --- python/ark/torch/tracer.py | 123 ++++++++++++++++++++++++------------- 1 file changed, 82 insertions(+), 41 deletions(-) diff --git a/python/ark/torch/tracer.py b/python/ark/torch/tracer.py index d527b0a5..b0d9ca37 100644 --- a/python/ark/torch/tracer.py +++ b/python/ark/torch/tracer.py @@ -189,7 +189,6 @@ def handle_aten_mse_loss_backward( "aten::mse_loss_backward": handle_aten_mse_loss_backward, } - class Tracer: def __init__(self): self.tensors: Dict[str, Tensor] = {} @@ -199,16 +198,14 @@ def __init__(self): self.inputs_bw: List[Tensor] = [] self.outputs_fw: List[Tensor] = [] self.outputs_bw: List[Tensor] = [] - self.plan_fw: Optional[Plan] = None - self.plan_bw: Optional[Plan] = None + self.plan_fw: List[Optional[Plan]] = [] + self.plan_bw: List[Optional[Plan]] = [] self.device: Optional[torch.device] = None self.failed: bool = False self.launched_fw: bool = False self.launched_bw: bool = False - self.execution_backend: List[Union[torch.nn.Module, Plan]] = [] - self.forward_ops = [] - self.backward_ops = [] - self.torch_module = None + self.execution_segments = [] + self.intermediate_results = {} def __call__(self, target: Callable) -> Callable: is_module = issubclass(target, torch.nn.Module) @@ -221,26 +218,35 @@ def __call__(self, target: Callable) -> Callable: target.forward_torch = target.forward def forward_wrapper(instance: torch.nn.Module, *args, **kwargs) -> Any: - if self.plan_fw is None: + if self.plan_fw == []: return instance.forward_torch(*args, **kwargs) - rt = Runtime.get_runtime() - if not self.launched_fw: - rt.launch( - plan=self.plan_fw, - device_id=self.device.index, - loop_mode=False, - ) - self.launched_fw = True - self.launched_bw = False + input_data = args + for i, backend in enumerate(self.plan_fw): + if isinstance(backend, Plan): + # use ARK + rt = Runtime.get_runtime() + if not self.launched_fw: + rt.launch( + plan=self.plan_fw[i], + device_id=self.device.index, + loop_mode=False, + ) + self.launched_fw = True + 
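partition_graph, added later in this patch, walks the traced nodes in order and starts a new segment whenever support flips between ARK and PyTorch; the forward wrapper above then dispatches each segment to the matching backend. A standalone sketch of that splitting step, with SUPPORTED standing in for _REGISTRY_FUNCTION_HANDLER (the names and the torch-function membership check are illustrative, not the tracer API):

    import torch
    from torch import fx

    SUPPORTED = {torch.relu, torch.sigmoid}  # stand-in for the registered ARK handlers

    def split_segments(gm: fx.GraphModule):
        segments, current, current_backend = [], [], None
        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue  # placeholders/outputs are handled separately in this sketch
            backend = "ARK" if node.target in SUPPORTED else "PyTorch"
            if backend != current_backend and current:
                segments.append((current_backend, current))
                current = []
            current_backend = backend
            current.append(node)
        if current:
            segments.append((current_backend, current))
        return segments

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.tanh(torch.relu(x))  # relu is "supported", tanh falls back

    print(split_segments(fx.symbolic_trace(Toy())))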
self.launched_bw = False - ph_map = {ph: data for ph, data in zip(self.inputs_fw, args)} - rt.run(tensor_mappings=ph_map) + ph_map = {ph: data for ph, data in zip(self.inputs_fw, args)} + rt.run(tensor_mappings=ph_map) + input_data = self.outputs_fw[0] + else: + # use pytorch + input_data = backend(*input_data) # TODO: how to get the output tensor(s)? - return self.outputs_fw[0] + return input_data def backward_wrapper(instance: torch.nn.Module, *args, **kwargs): - if self.plan_bw is None: + if self.plan_bw == []: return instance.forward_torch(*args, **kwargs) + rt = Runtime.get_runtime() if not self.launched_bw: rt.launch( @@ -275,6 +281,32 @@ def call(*args, **kwargs): target.backward_ark = backward_wrapper target.__call__ = call_wrapper return target + + def partition_graph(self, gm: torch.fx.GraphModule): + current_segment = [] + backend = "ARK" + for node in gm.graph.nodes: + if node.op == "call_function": + target_name = node.target.name() + if target_name in _REGISTRY_FUNCTION_HANDLER: + if backend == "PyTorch": + # End the PyTorch segment and start a new ARK segment + if current_segment: + self.execution_segments.append((backend, current_segment)) + current_segment = [] + backend = "ARK" + else: + if backend == "ARK": + # End the ARK segment and start a new PyTorch segment + if current_segment: + self.execution_segments.append((backend, current_segment)) + current_segment = [] + backend = "PyTorch" + + current_segment.append(node) + + if current_segment: + self.execution_segments.append((backend, current_segment)) def autograd_trace_( self, gm: torch.nn.Module, forward_inputs: List[torch.Tensor] @@ -294,29 +326,34 @@ def bw_compiler(gm: torch.fx.GraphModule, _): def autograd_trace_impl_( self, gm: torch.fx.GraphModule, _: List[torch.Tensor], is_fw: bool ) -> Callable: + self.partition_graph(gm) def run(args) -> Any: Model.reset() if not self.failed: - op_list = self.forward_ops if is_fw else self.backward_ops - for node in gm.graph.nodes: - op_list.append(node) - if not self.handle_node_(node, is_fw): - print(f"Failed to handle node {node.format_node()}") - self.failed = True - break - if not self.failed: - Model.set_device_id(self.device.index) - if is_fw: - self.plan_fw = Planner(self.device.index).plan() - else: - self.plan_bw = Planner(self.device.index).plan() - if is_fw: - t1 = self.construct_torch_module(gm, self.forward_ops) - print("FORWARD TORCH: ", t1.code) - else: - t2 = self.construct_torch_module(gm, self.backward_ops) - print("BKACWARD TOCH: ", t2.code) + intermediate_results = {} + + for backend, ops in self.execution_segments: + if backend == "ARK": + Model.reset() + for node in ops: + self.intermediate_results[node.name] = node + if not self.handle_node_(node, is_fw): + self.failed = True + break + if not self.failed: + Model.set_device_id(self.device.index) + if is_fw: + self.plan_fw.append(Planner(self.device.index).plan()) + else: + self.plan_bw.append(Planner(self.device_index).plan()) + for node in ops: + intermediate_results[node.name] = self.tensors[node.name] + else: # PyTorch + # we have our op list stored in ops + torch_module = self.construct_torch_module(gm, ops) + self.plan_fw.append(torch_module) + self.plan_bw.append(torch_module) return torch.fx.Interpreter(gm).boxed_run(args) run._boxed_call = True return run @@ -376,6 +413,7 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool: elif node.op == "call_function": target_name = node.target.name() if target_name not in _REGISTRY_FUNCTION_HANDLER: + # should never happen now due to 
partitioning before logging.warning( f"Unsupported function {target_name}. Usage: {node.format_node()}" ) @@ -388,9 +426,12 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool: def construct_torch_module(self, original_gm, op_seq): graph = fx.Graph() - env = {} + env = self.intermediate_results + print("INTERM: ", self.intermediate_results) def create_node(node): if node.op == 'placeholder': + if node.name in self.intermediate_results: + return self.intermediate_results[node.name] return graph.placeholder(node.target, type_expr=node.type) elif node.op == 'get_attr': return graph.get_attr(node.target) From 9be6d992b32303fb527772a2ee0566ce35ee5291 Mon Sep 17 00:00:00 2001 From: noli Date: Sun, 1 Sep 2024 03:02:09 +0000 Subject: [PATCH 5/6] wip --- python/ark/torch/tracer.py | 81 +++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 24 deletions(-) diff --git a/python/ark/torch/tracer.py b/python/ark/torch/tracer.py index b0d9ca37..447b81f2 100644 --- a/python/ark/torch/tracer.py +++ b/python/ark/torch/tracer.py @@ -195,9 +195,13 @@ def __init__(self): self.params: Optional[List[torch.nn.Parameter]] = None self.params_idx: int = 0 self.inputs_fw: List[Tensor] = [] + self.fw_plan_tns = {} self.inputs_bw: List[Tensor] = [] - self.outputs_fw: List[Tensor] = [] - self.outputs_bw: List[Tensor] = [] + self.bw_plan_tns = {} + # List of output lists, one for each ARK plan + self.outputs_fw: List[List[Tensor]] = [] + self.outputs_bw: List[List[Tensor]] = [] + self.curr_plan_idx: int = -1 self.plan_fw: List[Optional[Plan]] = [] self.plan_bw: List[Optional[Plan]] = [] self.device: Optional[torch.device] = None @@ -227,19 +231,22 @@ def forward_wrapper(instance: torch.nn.Module, *args, **kwargs) -> Any: rt = Runtime.get_runtime() if not self.launched_fw: rt.launch( - plan=self.plan_fw[i], + plan=backend, device_id=self.device.index, loop_mode=False, ) self.launched_fw = True self.launched_bw = False - ph_map = {ph: data for ph, data in zip(self.inputs_fw, args)} + ph_map = {ph: data for ph, data in zip(self.inputs_fw, input_data)} rt.run(tensor_mappings=ph_map) - input_data = self.outputs_fw[0] + print(self.outputs_fw) + # all outputs (including intermediates): tuple(t.to_torch() for t in self.outputs_fw[i]) + input_data = self.outputs_fw[i][-1].to_torch() else: # use pytorch - input_data = backend(*input_data) + input_data = Tensor.from_torch(backend(input_data)) + # TODO: how to get the output tensor(s)? 
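Both conversions in the wrapper above are intended to be zero-copy: to_torch exports the ARK buffer through DLPack and Tensor.from_torch wraps the torch tensor's data pointer. A small round trip with plain torch tensors shows the aliasing behaviour that this sharing relies on:

    import torch
    from torch.utils import dlpack

    src = torch.arange(6, dtype=torch.float32)
    capsule = dlpack.to_dlpack(src)      # export the buffer, no copy
    view = dlpack.from_dlpack(capsule)   # import: aliases the same storage
    view[0] = 42.0
    assert src[0].item() == 42.0         # writes are visible through both handles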
return input_data @@ -327,38 +334,54 @@ def autograd_trace_impl_( self, gm: torch.fx.GraphModule, _: List[torch.Tensor], is_fw: bool ) -> Callable: self.partition_graph(gm) - def run(args) -> Any: Model.reset() if not self.failed: - intermediate_results = {} - - for backend, ops in self.execution_segments: + self.curr_plan_idx = -1 + for backend, op_seq in self.execution_segments: if backend == "ARK": Model.reset() - for node in ops: + self.curr_plan_idx += 1 + curr_outputs = [] + for node in op_seq: self.intermediate_results[node.name] = node - if not self.handle_node_(node, is_fw): + if not self.handle_node_(node, is_fw, curr_outputs): self.failed = True break + if is_fw: + self.outputs_fw.append(curr_outputs) + else: + self.outputs_bw.append(curr_outputs) + curr_outputs = [] if not self.failed: + self.outputs_fw.append(curr_outputs) Model.set_device_id(self.device.index) if is_fw: self.plan_fw.append(Planner(self.device.index).plan()) else: - self.plan_bw.append(Planner(self.device_index).plan()) - for node in ops: - intermediate_results[node.name] = self.tensors[node.name] + self.plan_bw.append(Planner(self.device.index).plan()) else: # PyTorch # we have our op list stored in ops - torch_module = self.construct_torch_module(gm, ops) + torch_module = self.construct_torch_module(gm, op_seq) self.plan_fw.append(torch_module) self.plan_bw.append(torch_module) + # handle intermediate outputs computed by ARK + for node in op_seq: + meta = node.meta["tensor_meta"] + data = 0 + t = ops.placeholder( + shape=meta.shape, + dtype=ops.DataType.from_torch(meta.dtype), + name=node.name, + data=data, + ) + self.tensors[node.name] = t + self.execution_segments.clear() return torch.fx.Interpreter(gm).boxed_run(args) run._boxed_call = True return run - def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool: + def handle_node_(self, node: torch.fx.node.Node, is_fw: bool, curr_outputs) -> bool: if node.op == "placeholder": t = self.tensors.get(node.name, None) if t is not None: @@ -400,16 +423,15 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool: else: self.inputs_bw.append(t) elif node.op == "output": - outputs_list = self.outputs_fw if is_fw else self.outputs_bw - if outputs_list: - raise ValueError("Multiple output nodes are unexpected") + outputs = [] for out in node.args[0]: if isinstance(out, torch.fx.node.Node): if out.name not in self.tensors: raise ValueError(f"Output tensor {out.name} not found") - outputs_list.append(self.tensors[out.name]) + tns = self.tensors[out.name] + curr_outputs.append(self.tensors[out.name]) else: - outputs_list.append(out) + curr_outputs.append(out) elif node.op == "call_function": target_name = node.target.name() if target_name not in _REGISTRY_FUNCTION_HANDLER: @@ -420,6 +442,8 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool: return False t = _REGISTRY_FUNCTION_HANDLER[target_name](node, self.tensors) self.tensors[node.name] = t + # Append to outputs_fw if this is an intermediate output for ARK + curr_outputs.append(t) else: raise ValueError(f"Unexpected node {node.format_node()}") return True @@ -427,7 +451,6 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool: def construct_torch_module(self, original_gm, op_seq): graph = fx.Graph() env = self.intermediate_results - print("INTERM: ", self.intermediate_results) def create_node(node): if node.op == 'placeholder': if node.name in self.intermediate_results: @@ -444,10 +467,20 @@ def create_node(node): return graph.output(tuple(args)) 
else: raise ValueError(f"Unsupported node operation: {node.op}") + for node in op_seq: + if node.op == 'call_function' and node.args and isinstance(node.args[0], fx.Node): + input_name = node.args[0].name + if input_name in self.intermediate_results: + placeholder_node = graph.placeholder(input_name) + env[input_name] = placeholder_node for node in op_seq: env[node.name] = create_node(node) + + # TODO: Support multiple output for intermediate activations + last_node = list(env.values())[-1] + graph.output(last_node) torch_module = fx.GraphModule(original_gm, graph) return torch_module def tracer(target: Callable): - return Tracer()(target) + return Tracer()(target) \ No newline at end of file From 72f7e7ff31ce6eed7a5ec6c66c0f3227b7ee8959 Mon Sep 17 00:00:00 2001 From: noli Date: Mon, 2 Sep 2024 06:51:38 +0000 Subject: [PATCH 6/6] wip --- python/ark/torch/tracer.py | 120 +++++++++++++++++++++---------------- 1 file changed, 67 insertions(+), 53 deletions(-) diff --git a/python/ark/torch/tracer.py b/python/ark/torch/tracer.py index 447b81f2..d62fdb8a 100644 --- a/python/ark/torch/tracer.py +++ b/python/ark/torch/tracer.py @@ -194,9 +194,9 @@ def __init__(self): self.tensors: Dict[str, Tensor] = {} self.params: Optional[List[torch.nn.Parameter]] = None self.params_idx: int = 0 - self.inputs_fw: List[Tensor] = [] + self.inputs_fw: List[List[Tensor]] = [] self.fw_plan_tns = {} - self.inputs_bw: List[Tensor] = [] + self.inputs_bw: List[List[Tensor]] = [] self.bw_plan_tns = {} # List of output lists, one for each ARK plan self.outputs_fw: List[List[Tensor]] = [] @@ -208,7 +208,9 @@ def __init__(self): self.failed: bool = False self.launched_fw: bool = False self.launched_bw: bool = False - self.execution_segments = [] + self.execution_segments_fw = [] + self.ark_outputs: Dict[int, Tensor] = {} + self.execution_segments_bw = [] self.intermediate_results = {} def __call__(self, target: Callable) -> Callable: @@ -229,23 +231,21 @@ def forward_wrapper(instance: torch.nn.Module, *args, **kwargs) -> Any: if isinstance(backend, Plan): # use ARK rt = Runtime.get_runtime() - if not self.launched_fw: - rt.launch( + rt.launch( plan=backend, device_id=self.device.index, loop_mode=False, - ) - self.launched_fw = True - self.launched_bw = False - - ph_map = {ph: data for ph, data in zip(self.inputs_fw, input_data)} + ) + ph_map = {ph: data for ph, data in zip(self.inputs_fw[i], input_data)} rt.run(tensor_mappings=ph_map) - print(self.outputs_fw) # all outputs (including intermediates): tuple(t.to_torch() for t in self.outputs_fw[i]) input_data = self.outputs_fw[i][-1].to_torch() else: # use pytorch - input_data = Tensor.from_torch(backend(input_data)) + input_data = [backend(input_data)] + if i != len(self.execution_segments_fw) - 1: + self.inputs_fw.append([Tensor.from_torch(input_data[0])]) + # TODO: how to get the output tensor(s)? 
return input_data @@ -289,7 +289,11 @@ def call(*args, **kwargs): target.__call__ = call_wrapper return target - def partition_graph(self, gm: torch.fx.GraphModule): + def partition_graph(self, gm: torch.fx.GraphModule, is_fw: bool): + if is_fw: + exe_segment = self.execution_segments_fw + else: + exe_segment = self.execution_segments_bw current_segment = [] backend = "ARK" for node in gm.graph.nodes: @@ -299,31 +303,31 @@ def partition_graph(self, gm: torch.fx.GraphModule): if backend == "PyTorch": # End the PyTorch segment and start a new ARK segment if current_segment: - self.execution_segments.append((backend, current_segment)) + exe_segment.append((backend, current_segment)) current_segment = [] backend = "ARK" else: if backend == "ARK": # End the ARK segment and start a new PyTorch segment if current_segment: - self.execution_segments.append((backend, current_segment)) + exe_segment.append((backend, current_segment)) current_segment = [] backend = "PyTorch" current_segment.append(node) if current_segment: - self.execution_segments.append((backend, current_segment)) + exe_segment.append((backend, current_segment)) def autograd_trace_( self, gm: torch.nn.Module, forward_inputs: List[torch.Tensor] ) -> Callable: def fw_compiler(gm: torch.fx.GraphModule, _): - logging.info("==== FW Starts ====") + print("==== FW Starts ====") return self.autograd_trace_impl_(gm, _, True) def bw_compiler(gm: torch.fx.GraphModule, _): - logging.info("==== BW Starts ====") + print("==== BW Starts ====") return self.autograd_trace_impl_(gm, _, False) return torch._dynamo.backends.common.aot_autograd( @@ -333,55 +337,67 @@ def bw_compiler(gm: torch.fx.GraphModule, _): def autograd_trace_impl_( self, gm: torch.fx.GraphModule, _: List[torch.Tensor], is_fw: bool ) -> Callable: - self.partition_graph(gm) + self.partition_graph(gm, is_fw) + if is_fw: + self.curr_plan_idx = -1 + self.plan_fw = [] + exe_seg = self.execution_segments_fw + else: + self.curr_plan_idx = -1 + self.plan_bw = [] + exe_seg = self.execution_segments_bw + + print("gm: ", gm) def run(args) -> Any: - Model.reset() if not self.failed: - self.curr_plan_idx = -1 - for backend, op_seq in self.execution_segments: + for backend, op_seq in exe_seg: if backend == "ARK": Model.reset() self.curr_plan_idx += 1 - curr_outputs = [] + curr_outputs, curr_inputs = [], [] for node in op_seq: self.intermediate_results[node.name] = node - if not self.handle_node_(node, is_fw, curr_outputs): + if not self.handle_node_(node, is_fw, curr_outputs, curr_inputs): self.failed = True break - if is_fw: - self.outputs_fw.append(curr_outputs) - else: - self.outputs_bw.append(curr_outputs) - curr_outputs = [] if not self.failed: - self.outputs_fw.append(curr_outputs) Model.set_device_id(self.device.index) if is_fw: + self.inputs_fw.append(curr_inputs) + self.outputs_fw.append(curr_outputs) self.plan_fw.append(Planner(self.device.index).plan()) else: + self.inputs_bw.append(curr_inputs) + self.outputs_bw.append(curr_outputs) self.plan_bw.append(Planner(self.device.index).plan()) else: # PyTorch # we have our op list stored in ops + self.curr_plan_idx += 1 torch_module = self.construct_torch_module(gm, op_seq) - self.plan_fw.append(torch_module) - self.plan_bw.append(torch_module) - # handle intermediate outputs computed by ARK - for node in op_seq: - meta = node.meta["tensor_meta"] - data = 0 - t = ops.placeholder( - shape=meta.shape, - dtype=ops.DataType.from_torch(meta.dtype), - name=node.name, - data=data, - ) - self.tensors[node.name] = t - self.execution_segments.clear() 
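autograd_trace_ above hands these two compilers to AOT autograd, which delivers the forward and backward graphs separately. A hedged sketch of the same hookup outside the tracer, with compilers that only print the graphs and fall back to eager execution (make_boxed_func matches the boxed calling convention that run._boxed_call = True opts into):

    import torch
    from functorch.compile import make_boxed_func
    from torch._dynamo.backends.common import aot_autograd

    def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
        print("==== FW graph ====\n", gm.code)
        return make_boxed_func(gm.forward)  # eager fallback in boxed form

    def bw_compiler(gm: torch.fx.GraphModule, example_inputs):
        print("==== BW graph ====\n", gm.code)
        return make_boxed_func(gm.forward)

    backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
    model = torch.compile(torch.nn.Linear(4, 4), backend=backend)
    out = model(torch.randn(2, 4))
    out.sum().backward()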
+ if is_fw: + + self.plan_fw.append(torch_module) + else: + + self.plan_bw.append(torch_module) + # handle intermediate outputs computed by Torch + if self.curr_plan_idx != len(exe_seg) -1: + for node in op_seq: + meta = node.meta["tensor_meta"] + data = 0 + t = ops.placeholder( + shape=meta.shape, + dtype=ops.DataType.from_torch(meta.dtype), + name=node.name, + data=data, + ) + self.tensors[node.name] = t + return torch.fx.Interpreter(gm).boxed_run(args) run._boxed_call = True return run - def handle_node_(self, node: torch.fx.node.Node, is_fw: bool, curr_outputs) -> bool: + def handle_node_(self, node: torch.fx.node.Node, is_fw: bool, curr_outputs, curr_inputs) -> bool: if node.op == "placeholder": t = self.tensors.get(node.name, None) if t is not None: @@ -418,12 +434,8 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool, curr_outputs) -> b ) self.tensors[node.name] = t if data == 0: - if is_fw: - self.inputs_fw.append(t) - else: - self.inputs_bw.append(t) + curr_inputs.append(t) elif node.op == "output": - outputs = [] for out in node.args[0]: if isinstance(out, torch.fx.node.Node): if out.name not in self.tensors: @@ -443,6 +455,7 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool, curr_outputs) -> b t = _REGISTRY_FUNCTION_HANDLER[target_name](node, self.tensors) self.tensors[node.name] = t # Append to outputs_fw if this is an intermediate output for ARK + # if self.curr_plan_idx != len(self.execution_segments) - 1: curr_outputs.append(t) else: raise ValueError(f"Unexpected node {node.format_node()}") @@ -476,9 +489,10 @@ def create_node(node): for node in op_seq: env[node.name] = create_node(node) - # TODO: Support multiple output for intermediate activations - last_node = list(env.values())[-1] - graph.output(last_node) + # TODO: Support multiple output for intermediae activations + if op_seq[-1].op != 'output': + last_node = env[op_seq[-1].name] + graph.output(last_node) torch_module = fx.GraphModule(original_gm, graph) return torch_module
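construct_torch_module above rebuilds the fallback graph node by node and, as the TODO notes, currently wires up a single output. For comparison, a self-contained sketch of the same extraction using fx.Graph.node_copy, which remaps arguments through an environment dict and turns values produced outside the segment into placeholders (extract_segment and Small are illustrative names, not part of the tracer):

    import torch
    from torch import fx

    def extract_segment(gm: fx.GraphModule, nodes) -> fx.GraphModule:
        new_graph = fx.Graph()
        env = {}
        for node in nodes:
            # Values produced outside the segment become placeholders of the new graph.
            for arg in node.all_input_nodes:
                if arg.name not in env:
                    env[arg.name] = new_graph.placeholder(arg.name)
            env[node.name] = new_graph.node_copy(node, lambda n: env[n.name])
        # Single output for now; returning a tuple here would expose intermediates too.
        new_graph.output(env[nodes[-1].name])
        return fx.GraphModule(gm, new_graph)

    class Small(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    gm = fx.symbolic_trace(Small())
    segment = [n for n in gm.graph.nodes if n.op == "call_function"]
    print(extract_segment(gm, segment).code)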