diff --git a/python/ark/module.py b/python/ark/module.py
index b5744f10..c33623a1 100644
--- a/python/ark/module.py
+++ b/python/ark/module.py
@@ -4,12 +4,11 @@ import logging
 import numpy as np
 from typing import Any, Dict, Union
 
-from .tensor import Parameter
+from .tensor import Tensor, Parameter
 from .torch import torch, _no_torch
 from .runtime import Runtime
 from .model import Model
 from .data_type import DataType
-from .ops import placeholder
 
 
 class Module:
@@ -36,10 +35,7 @@ def __setattr__(self, __name: str, __value: Any) -> None:
         elif isinstance(__value, Parameter):
             self.register_parameter(__name, __value)
         elif not _no_torch and isinstance(__value, torch.nn.Parameter):
-            shape, dtype = list(__value.shape), DataType.from_torch(
-                __value.dtype
-            )
-            __value = Parameter(placeholder(shape, dtype, data=__value), True)
+            __value = Parameter(__value)
             self.register_parameter(__name, __value)
         super().__setattr__(__name, __value)
@@ -147,16 +143,14 @@ def forward(ctx, ark_module, *args, **kwargs):
         input_requires_grad = 0
         for arg in args:
             if isinstance(arg, torch.Tensor):
-                shape, dtype = list(arg.shape), DataType.from_torch(arg.dtype)
-                input_args.append(placeholder(shape, dtype, data=arg))
+                input_args.append(Tensor.from_torch(arg))
                 if arg.requires_grad:
                     input_requires_grad += 1
             else:
                 input_args.append(arg)
         for k, v in kwargs.items():
             if isinstance(v, torch.Tensor):
-                shape, dtype = list(arg.shape), DataType.from_torch(arg.dtype)
-                input_kwargs[k] = placeholder(shape, dtype, data=v)
+                input_kwargs[k] = Tensor.from_torch(v)
                 if v.requires_grad:
                     input_requires_grad += 1
             else:
@@ -178,12 +172,7 @@ def backward(ctx, *grad_outputs):
             PyTorch parameters.
         """
         Model.reset()
-        # i think we should support placeholder initialization
-        # with just pytorch tensor
-        ark_grad_outputs = []
-        for grad in grad_outputs:
-            shape, dtype = list(grad.shape), DataType.from_torch(grad.dtype)
-            ark_grad_outputs.append(placeholder(shape, dtype, data=grad))
+        ark_grad_outputs = [Tensor.from_torch(grad) for grad in grad_outputs]
         grads = ctx.ark_module.backward(*ark_grad_outputs)
         grad_inputs, grad_weights = (
             grads[: ctx.num_inp_grad],
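
Note (usage sketch, not part of the patch): the module.py hunks above replace the manual shape/dtype/placeholder plumbing with `Parameter(torch.nn.Parameter)` and `Tensor.from_torch(...)`. A minimal sketch of the new registration path; the `ark.Module` subclass, the CUDA/fp16 tensor, and the package-level `ark.Parameter` export are illustrative assumptions:

```python
import torch
import ark

class Scale(ark.Module):
    def __init__(self):
        super().__init__()
        # Module.__setattr__ now wraps this as Parameter(torch.nn.Parameter),
        # which goes through Tensor.from_torch and shares the device memory.
        self.weight = torch.nn.Parameter(
            torch.ones(16, 16, dtype=torch.float16, device="cuda")
        )

module = Scale()
assert isinstance(module.weight, ark.Parameter)
assert module.weight.torch_param is not None  # tied to the original torch parameter
```
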
diff --git a/python/ark/tensor.py b/python/ark/tensor.py
index a09b0af6..7b503cf3 100644
--- a/python/ark/tensor.py
+++ b/python/ark/tensor.py
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 
 import numpy as np
-from typing import Callable, Iterable, List, Union, Type
+from typing import Callable, Iterable, List, Union, Type, Dict
 
 from ._ark_core import _Dims, _Tensor, _NullTensor
 from .torch import torch, _no_torch
@@ -22,6 +22,9 @@ class Dims(_Dims):
 
 
 class Tensor:
+
+    _tensor_grads: Dict[int, "Tensor"] = {}
+
     def __init__(
         self,
         _tensor: _Tensor,
@@ -38,6 +41,8 @@ def __init__(
         self._tensor = _tensor
         self.initializer: Initializer = initializer
         self.requires_grad = requires_grad
+        if self.requires_grad:
+            Tensor._tensor_grads[self._tensor.id()] = self
 
     def __hash__(self):
         return self._tensor.id()
@@ -186,6 +191,8 @@ def to_torch(self) -> torch.Tensor:
         torch_view = torch.utils.dlpack.from_dlpack(dl_capsule)
         # Keep dl_capsule alive not to free the memory
         torch_view.__ark_buffer__ = dl_capsule
+        if self.requires_grad:
+            torch_view.requires_grad_(True)
         return torch_view
 
     @staticmethod
@@ -205,7 +212,8 @@ def from_torch(tensor: torch.Tensor) -> "Tensor":
                 shape=list(tensor.shape),
                 dtype=DataType.from_torch(tensor.dtype),
                 data=tensor.data_ptr(),
-            )
+            ),
+            requires_grad=tensor.requires_grad
         )
         # Share ownership of the memory with the torch tensor
         ark_tensor.__torch_buffer__ = tensor
@@ -259,37 +267,43 @@ def initialize(self) -> "Tensor":
             self.copy(data)
         return self
 
+    def requires_grad_(self, requires_grad: bool = True) -> "Tensor":
+        """
+        Sets the `requires_grad` attribute in-place for the tensor.
+        If `requires_grad` is True, the tensor will be tracked for gradient
+        updates.
+        """
+        self.requires_grad = requires_grad
+        if requires_grad:
+            Tensor._tensor_grads[self._tensor.id()] = self
+        elif self._tensor.id() in Tensor._tensor_grads:
+            del Tensor._tensor_grads[self._tensor.id()]
+        return self
+
 
 
-class Parameter(Tensor):
+class Parameter(Tensor, torch.nn.Parameter):
     """
     A tensor as a parameter.
     """
 
     def __init__(
         self,
-        tensor: _Tensor,
-        from_torch: bool,
+        tensor: Union[_Tensor, "torch.nn.Parameter"],
     ):
         """
         Initializes a new instance of the Parameter class.
-        Args:
-            _tensor (_ark_core._Tensor): The underlying _Tensor object.
-            from_torch: Indicates if the Parameter is tied to a torch.nn.Paramter
         """
-        if not _no_torch and from_torch:
-            _tensor = tensor._tensor
+        if not _no_torch and isinstance(tensor, torch.nn.Parameter):
+            ark_tensor = Tensor.from_torch(tensor)
+            self._tensor = ark_tensor._tensor
+            ark_tensor.requires_grad_(True)
             self.torch_param = tensor
             self.staged_tensor = None
-            Tensor.__init__(
-                self,
-                _tensor,
-                requires_grad=tensor.requires_grad,
-            )
         elif isinstance(tensor, _Tensor):
             _tensor = tensor
             self.torch_param = None
             self.staged_tensor = None
-            Tensor.__init__(self, _tensor, requires_grad=False)
+            Tensor.__init__(self, _tensor, requires_grad=True)
         else:
             raise TypeError(
                 "tensor must be an ARK tensor or a torch.nn.Parameter"
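
Note (usage sketch, not part of the patch): the tensor.py changes wire `requires_grad` through the torch interop path: `from_torch` inherits it, `to_torch` re-applies it to the dlpack view, and the new `requires_grad_` keeps the `Tensor._tensor_grads` registry in sync. A minimal sketch, assuming a CUDA tensor (ARK takes the raw `data_ptr()`) and that `Tensor` is re-exported as `ark.Tensor`:

```python
import torch
import ark

x = torch.randn(8, 8, device="cuda", requires_grad=True)

t = ark.Tensor.from_torch(x)   # shares x's memory; requires_grad is carried over
assert t.requires_grad          # so t is tracked in Tensor._tensor_grads

t.requires_grad_(False)         # removes t from the gradient registry
t.requires_grad_(True)          # adds it back

# Once the runtime has been launched and run, the dlpack view keeps the flag:
#   y = t.to_torch()
#   assert y.requires_grad
```
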
diff --git a/python/ark/torch/tracer.py b/python/ark/torch/tracer.py
index eb73d4e4..d62fdb8a 100644
--- a/python/ark/torch/tracer.py
+++ b/python/ark/torch/tracer.py
@@ -3,11 +3,12 @@
 try:
     import torch
+    from torch import fx
 except ImportError:
     raise ImportError("torch is required to use this module")
 
 import logging
-from typing import List, Dict, Optional, Callable, Any
+from typing import List, Dict, Optional, Callable, Union, Any
 
 from ..planner import Planner, Plan
 from ..tensor import Tensor
@@ -188,22 +189,29 @@ def handle_aten_mse_loss_backward(
     "aten::mse_loss_backward": handle_aten_mse_loss_backward,
 }
 
-
 class Tracer:
     def __init__(self):
         self.tensors: Dict[str, Tensor] = {}
         self.params: Optional[List[torch.nn.Parameter]] = None
         self.params_idx: int = 0
-        self.inputs_fw: List[Tensor] = []
-        self.inputs_bw: List[Tensor] = []
-        self.outputs_fw: List[Tensor] = []
-        self.outputs_bw: List[Tensor] = []
-        self.plan_fw: Optional[Plan] = None
-        self.plan_bw: Optional[Plan] = None
+        self.inputs_fw: List[List[Tensor]] = []
+        self.fw_plan_tns = {}
+        self.inputs_bw: List[List[Tensor]] = []
+        self.bw_plan_tns = {}
+        # List of output lists, one for each ARK plan
+        self.outputs_fw: List[List[Tensor]] = []
+        self.outputs_bw: List[List[Tensor]] = []
+        self.curr_plan_idx: int = -1
+        self.plan_fw: List[Optional[Plan]] = []
+        self.plan_bw: List[Optional[Plan]] = []
         self.device: Optional[torch.device] = None
         self.failed: bool = False
         self.launched_fw: bool = False
         self.launched_bw: bool = False
+        self.execution_segments_fw = []
+        self.ark_outputs: Dict[int, Tensor] = {}
+        self.execution_segments_bw = []
+        self.intermediate_results = {}
 
     def __call__(self, target: Callable) -> Callable:
         is_module = issubclass(target, torch.nn.Module)
@@ -216,26 +224,36 @@ def __call__(self, target: Callable) -> Callable:
         target.forward_torch = target.forward
 
         def forward_wrapper(instance: torch.nn.Module, *args, **kwargs) -> Any:
-            if self.plan_fw is None:
+            if self.plan_fw == []:
                 return instance.forward_torch(*args, **kwargs)
-            rt = Runtime.get_runtime()
-            if not self.launched_fw:
-                rt.launch(
-                    plan=self.plan_fw,
-                    device_id=self.device.index,
-                    loop_mode=False,
-                )
-                self.launched_fw = True
-                self.launched_bw = False
+            input_data = args
+            for i, backend in enumerate(self.plan_fw):
+                if isinstance(backend, Plan):
+                    # use ARK
+                    rt = Runtime.get_runtime()
+                    rt.launch(
+                        plan=backend,
+                        device_id=self.device.index,
+                        loop_mode=False,
+                    )
+                    ph_map = {ph: data for ph, data in zip(self.inputs_fw[i], input_data)}
+                    rt.run(tensor_mappings=ph_map)
+                    # all outputs (including intermediates): tuple(t.to_torch() for t in self.outputs_fw[i])
+                    input_data = self.outputs_fw[i][-1].to_torch()
+                else:
+                    # use pytorch
+                    input_data = [backend(input_data)]
+                    if i != len(self.execution_segments_fw) - 1:
+                        self.inputs_fw.append([Tensor.from_torch(input_data[0])])
+
-            ph_map = {ph: data for ph, data in zip(self.inputs_fw, args)}
-            rt.run(tensor_mappings=ph_map)
             # TODO: how to get the output tensor(s)?
-            return self.outputs_fw[0]
+            return input_data
 
         def backward_wrapper(instance: torch.nn.Module, *args, **kwargs):
-            if self.plan_bw is None:
+            if self.plan_bw == []:
                 return instance.forward_torch(*args, **kwargs)
+
             rt = Runtime.get_runtime()
             if not self.launched_bw:
                 rt.launch(
@@ -270,47 +288,116 @@ def call(*args, **kwargs):
         target.backward_ark = backward_wrapper
         target.__call__ = call_wrapper
         return target
+
+    def partition_graph(self, gm: torch.fx.GraphModule, is_fw: bool):
+        if is_fw:
+            exe_segment = self.execution_segments_fw
+        else:
+            exe_segment = self.execution_segments_bw
+        current_segment = []
+        backend = "ARK"
+        for node in gm.graph.nodes:
+            if node.op == "call_function":
+                target_name = node.target.name()
+                if target_name in _REGISTRY_FUNCTION_HANDLER:
+                    if backend == "PyTorch":
+                        # End the PyTorch segment and start a new ARK segment
+                        if current_segment:
+                            exe_segment.append((backend, current_segment))
+                        current_segment = []
+                        backend = "ARK"
+                else:
+                    if backend == "ARK":
+                        # End the ARK segment and start a new PyTorch segment
+                        if current_segment:
+                            exe_segment.append((backend, current_segment))
+                        current_segment = []
+                        backend = "PyTorch"
+
+            current_segment.append(node)
+
+        if current_segment:
+            exe_segment.append((backend, current_segment))
 
     def autograd_trace_(
-        self, gm: torch.nn.Module, _: List[torch.Tensor]
+        self, gm: torch.nn.Module, forward_inputs: List[torch.Tensor]
     ) -> Callable:
         def fw_compiler(gm: torch.fx.GraphModule, _):
-            logging.info("==== FW Starts ====")
+            print("==== FW Starts ====")
             return self.autograd_trace_impl_(gm, _, True)
 
         def bw_compiler(gm: torch.fx.GraphModule, _):
-            logging.info("==== BW Starts ====")
+            print("==== BW Starts ====")
             return self.autograd_trace_impl_(gm, _, False)
 
         return torch._dynamo.backends.common.aot_autograd(
             fw_compiler=fw_compiler, bw_compiler=bw_compiler
-        )(gm, _)
 
     def autograd_trace_impl_(
         self, gm: torch.fx.GraphModule, _: List[torch.Tensor], is_fw: bool
     ) -> Callable:
-
+        self.partition_graph(gm, is_fw)
+        if is_fw:
+            self.curr_plan_idx = -1
+            self.plan_fw = []
+            exe_seg = self.execution_segments_fw
+        else:
+            self.curr_plan_idx = -1
+            self.plan_bw = []
+            exe_seg = self.execution_segments_bw
+
+        print("gm: ", gm)
         def run(args) -> Any:
-            Model.reset()
             if not self.failed:
-                for node in gm.graph.nodes:
-                    logging.info(node.format_node(), node.meta)
-                    if not self.handle_node_(node, is_fw):
-                        print(f"Failed to handle node {node.format_node()}")
-                        self.failed = True
-                        break
-                if not self.failed:
-                    Model.set_device_id(self.device.index)
-                    if is_fw:
-                        self.plan_fw = Planner(self.device.index).plan()
-                    else:
-                        self.plan_bw = Planner(self.device.index).plan()
+                for backend, op_seq in exe_seg:
+                    if backend == "ARK":
+                        Model.reset()
+                        self.curr_plan_idx += 1
+                        curr_outputs, curr_inputs = [], []
+                        for node in op_seq:
+                            self.intermediate_results[node.name] = node
+                            if not self.handle_node_(node, is_fw, curr_outputs, curr_inputs):
+                                self.failed = True
+                                break
+                        if not self.failed:
+                            Model.set_device_id(self.device.index)
+                            if is_fw:
+                                self.inputs_fw.append(curr_inputs)
+                                self.outputs_fw.append(curr_outputs)
+                                self.plan_fw.append(Planner(self.device.index).plan())
+                            else:
+                                self.inputs_bw.append(curr_inputs)
+                                self.outputs_bw.append(curr_outputs)
+                                self.plan_bw.append(Planner(self.device.index).plan())
+                    else:  # PyTorch
+                        # we have our op list stored in op_seq
+                        self.curr_plan_idx += 1
+                        torch_module = self.construct_torch_module(gm, op_seq)
+                        if is_fw:
+                            self.plan_fw.append(torch_module)
+                        else:
+                            self.plan_bw.append(torch_module)
+                        # handle intermediate outputs computed by PyTorch
+                        if self.curr_plan_idx != len(exe_seg) - 1:
+                            for node in op_seq:
+                                meta = node.meta["tensor_meta"]
+                                data = 0
+                                t = ops.placeholder(
+                                    shape=meta.shape,
+                                    dtype=ops.DataType.from_torch(meta.dtype),
+                                    name=node.name,
+                                    data=data,
+                                )
+                                self.tensors[node.name] = t
+
             return torch.fx.Interpreter(gm).boxed_run(args)
-
         run._boxed_call = True
         return run
 
-    def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool:
+    def handle_node_(self, node: torch.fx.node.Node, is_fw: bool, curr_outputs, curr_inputs) -> bool:
         if node.op == "placeholder":
             t = self.tensors.get(node.name, None)
             if t is not None:
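
Note (illustration, not part of the patch): `partition_graph` above splits the traced graph into alternating "ARK" and "PyTorch" segments based on whether a `call_function` node's target has an entry in `_REGISTRY_FUNCTION_HANDLER`. The same splitting logic, shown standalone over plain op-name strings (the names and the `supported` set are invented; the real code walks `fx.Graph` nodes):

```python
# Standalone sketch of the alternating-segment partitioning used above.
supported = {"aten::mm", "aten::add", "aten::relu"}
ops = ["aten::mm", "aten::add", "aten::erf", "aten::gelu", "aten::relu"]

segments, current, backend = [], [], "ARK"
for op in ops:
    wanted = "ARK" if op in supported else "PyTorch"
    if wanted != backend and current:
        segments.append((backend, current))
        current = []
    backend = wanted
    current.append(op)
if current:
    segments.append((backend, current))

print(segments)
# [('ARK', ['aten::mm', 'aten::add']),
#  ('PyTorch', ['aten::erf', 'aten::gelu']),
#  ('ARK', ['aten::relu'])]
```
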
@@ -347,34 +434,67 @@ def handle_node_(self, node: torch.fx.node.Node, is_fw: bool) -> bool:
                 )
                 self.tensors[node.name] = t
                 if data == 0:
-                    if is_fw:
-                        self.inputs_fw.append(t)
-                    else:
-                        self.inputs_bw.append(t)
+                    curr_inputs.append(t)
         elif node.op == "output":
-            outputs_list = self.outputs_fw if is_fw else self.outputs_bw
-            if outputs_list:
-                raise ValueError("Multiple output nodes are unexpected")
             for out in node.args[0]:
                 if isinstance(out, torch.fx.node.Node):
                     if out.name not in self.tensors:
                         raise ValueError(f"Output tensor {out.name} not found")
-                    outputs_list.append(self.tensors[out.name])
+                    tns = self.tensors[out.name]
+                    curr_outputs.append(tns)
                 else:
-                    outputs_list.append(out)
+                    curr_outputs.append(out)
         elif node.op == "call_function":
             target_name = node.target.name()
             if target_name not in _REGISTRY_FUNCTION_HANDLER:
+                # should never happen now due to partitioning
                 logging.warning(
                     f"Unsupported function {target_name}. "
                     f"Usage: {node.format_node()}"
                 )
                 return False
             t = _REGISTRY_FUNCTION_HANDLER[target_name](node, self.tensors)
             self.tensors[node.name] = t
+            # Append to outputs_fw if this is an intermediate output for ARK
+            # if self.curr_plan_idx != len(self.execution_segments) - 1:
+            curr_outputs.append(t)
         else:
             raise ValueError(f"Unexpected node {node.format_node()}")
         return True
-
+
Usage: {node.format_node()}" ) return False t = _REGISTRY_FUNCTION_HANDLER[target_name](node, self.tensors) self.tensors[node.name] = t + # Append to outputs_fw if this is an intermediate output for ARK + # if self.curr_plan_idx != len(self.execution_segments) - 1: + curr_outputs.append(t) else: raise ValueError(f"Unexpected node {node.format_node()}") return True - + + def construct_torch_module(self, original_gm, op_seq): + graph = fx.Graph() + env = self.intermediate_results + def create_node(node): + if node.op == 'placeholder': + if node.name in self.intermediate_results: + return self.intermediate_results[node.name] + return graph.placeholder(node.target, type_expr=node.type) + elif node.op == 'get_attr': + return graph.get_attr(node.target) + elif node.op == 'call_function': + args = tuple(env[arg.name] if isinstance(arg, fx.Node) else arg for arg in node.args) + kwargs = {k: env[v.name] if isinstance(v, fx.Node) else v for k, v in node.kwargs.items()} + return graph.call_function(node.target, args, kwargs) + elif node.op == 'output': + args = tuple(env[arg.name] if isinstance(arg, fx.Node) else arg for arg in node.args[0]) + return graph.output(tuple(args)) + else: + raise ValueError(f"Unsupported node operation: {node.op}") + for node in op_seq: + if node.op == 'call_function' and node.args and isinstance(node.args[0], fx.Node): + input_name = node.args[0].name + if input_name in self.intermediate_results: + placeholder_node = graph.placeholder(input_name) + env[input_name] = placeholder_node + for node in op_seq: + env[node.name] = create_node(node) + + # TODO: Support multiple output for intermediae activations + if op_seq[-1].op != 'output': + last_node = env[op_seq[-1].name] + graph.output(last_node) + torch_module = fx.GraphModule(original_gm, graph) + return torch_module def tracer(target: Callable): - return Tracer()(target) + return Tracer()(target) \ No newline at end of file