python/ark/module.py (21 changes: 5 additions & 16 deletions)
@@ -4,12 +4,11 @@
 import logging
 import numpy as np
 from typing import Any, Dict, Union
-from .tensor import Parameter
+from .tensor import Tensor, Parameter
 from .torch import torch, _no_torch
 from .runtime import Runtime
 from .model import Model
 from .data_type import DataType
-from .ops import placeholder


 class Module:
@@ -36,10 +35,7 @@ def __setattr__(self, __name: str, __value: Any) -> None:
         elif isinstance(__value, Parameter):
             self.register_parameter(__name, __value)
         elif not _no_torch and isinstance(__value, torch.nn.Parameter):
-            shape, dtype = list(__value.shape), DataType.from_torch(
-                __value.dtype
-            )
-            __value = Parameter(placeholder(shape, dtype, data=__value), True)
+            __value = Parameter(__value)
             self.register_parameter(__name, __value)
         super().__setattr__(__name, __value)

@@ -147,16 +143,14 @@ def forward(ctx, ark_module, *args, **kwargs):
         input_requires_grad = 0
         for arg in args:
             if isinstance(arg, torch.Tensor):
-                shape, dtype = list(arg.shape), DataType.from_torch(arg.dtype)
-                input_args.append(placeholder(shape, dtype, data=arg))
+                input_args.append(Tensor.from_torch(arg))
                 if arg.requires_grad:
                     input_requires_grad += 1
             else:
                 input_args.append(arg)
         for k, v in kwargs.items():
             if isinstance(v, torch.Tensor):
-                shape, dtype = list(arg.shape), DataType.from_torch(arg.dtype)
-                input_kwargs[k] = placeholder(shape, dtype, data=v)
+                input_kwargs[k] = Tensor.from_torch(v)
                 if v.requires_grad:
                     input_requires_grad += 1
             else:
@@ -178,12 +172,7 @@ def backward(ctx, *grad_outputs):
         PyTorch parameters.
         """
         Model.reset()
-        # i think we should support placeholder initialization
-        # with just pytorch tensor
-        ark_grad_outputs = []
-        for grad in grad_outputs:
-            shape, dtype = list(grad.shape), DataType.from_torch(grad.dtype)
-            ark_grad_outputs.append(placeholder(shape, dtype, data=grad))
+        ark_grad_outputs = [Tensor.from_torch(grad) for grad in grad_outputs]
         grads = ctx.ark_module.backward(*ark_grad_outputs)
         grad_inputs, grad_weights = (
             grads[: ctx.num_inp_grad],
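The net effect in module.py is that torch tensors no longer need an explicit placeholder: assigning a torch.nn.Parameter to a Module attribute wraps it directly as Parameter(__value), and the autograd forward/backward hooks convert incoming torch tensors and gradients with Tensor.from_torch. Below is a minimal registration sketch, assuming the package-level ark.Module and ark.matmul APIs used elsewhere in the repo; SimpleLinear and its shapes are illustrative only and not part of this PR.

import torch
import ark

class SimpleLinear(ark.Module):
    def __init__(self):
        super().__init__()
        # Assigning a torch.nn.Parameter now takes the new __setattr__ branch:
        # the value is wrapped as Parameter(__value) and registered, instead of
        # going through an explicit placeholder(shape, dtype, data=...) call.
        self.weight = torch.nn.Parameter(torch.randn(64, 64))

    def forward(self, x):
        # When invoked through the torch autograd bridge, x would have been
        # produced by Tensor.from_torch(arg) rather than a manual placeholder.
        return ark.matmul(x, self.weight)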
python/ark/tensor.py (46 changes: 30 additions & 16 deletions)
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 import numpy as np
-from typing import Callable, Iterable, List, Union, Type
+from typing import Callable, Iterable, List, Union, Type, Dict

 from ._ark_core import _Dims, _Tensor, _NullTensor
 from .torch import torch, _no_torch
@@ -22,6 +22,9 @@ class Dims(_Dims):


 class Tensor:
+
+    _tensor_grads: Dict[int, "Tensor"] = {}
+
     def __init__(
         self,
         _tensor: _Tensor,
@@ -38,6 +41,8 @@ def __init__(
         self._tensor = _tensor
         self.initializer: Initializer = initializer
         self.requires_grad = requires_grad
+        if self.requires_grad:
+            Tensor._tensor_grads[self._tensor.id()] = self

     def __hash__(self):
         return self._tensor.id()
@@ -186,6 +191,8 @@ def to_torch(self) -> torch.Tensor:
         torch_view = torch.utils.dlpack.from_dlpack(dl_capsule)
         # Keep dl_capsule alive not to free the memory
         torch_view.__ark_buffer__ = dl_capsule
+        if self.requires_grad:
+            torch_view.requires_grad_(True)
         return torch_view

     @staticmethod
@@ -205,7 +212,8 @@ def from_torch(tensor: torch.Tensor) -> "Tensor":
                 shape=list(tensor.shape),
                 dtype=DataType.from_torch(tensor.dtype),
                 data=tensor.data_ptr(),
-            )
+            ),
+            requires_grad=tensor.requires_grad
         )
         # Share ownership of the memory with the torch tensor
         ark_tensor.__torch_buffer__ = tensor
@@ -259,37 +267,43 @@ def initialize(self) -> "Tensor":
         self.copy(data)
         return self

+    def requires_grad_(self, requires_grad: bool = True) -> "Tensor":
+        """
+        Sets the `requires_grad` attribute in-place for the tensor.
+        If `requires_grad` is True, the tensor will be tracked for gradient
+        updates.
+        """
+        self.requires_grad = requires_grad
+        if requires_grad:
+            Tensor._tensor_grads[self._tensor.id()] = self
+        elif self._tensor.id() in Tensor._tensor_grads:
+            del Tensor._tensor_grads[self._tensor.id()]
+        return self
+

-class Parameter(Tensor):
+class Parameter(Tensor, torch.nn.Parameter):
     """
     A tensor as a parameter.
     """

     def __init__(
         self,
-        tensor: _Tensor,
-        from_torch: bool,
+        tensor: Union[_Tensor, "torch.nn.Parameter"],
     ):
         """
         Initializes a new instance of the Parameter class.
         Args:
             _tensor (_ark_core._Tensor): The underlying _Tensor object.
-            from_torch: Indicates if the Parameter is tied to a torch.nn.Paramter
         """
-        if not _no_torch and from_torch:
-            _tensor = tensor._tensor
+        if not _no_torch and isinstance(tensor, torch.nn.Parameter):
+            ark_tensor = Tensor.from_torch(tensor)
+            self._tensor = ark_tensor._tensor
+            ark_tensor.requires_grad_(True)
-            self.torch_param = tensor
-            self.staged_tensor = None
-            Tensor.__init__(
-                self,
-                _tensor,
-                requires_grad=tensor.requires_grad,
-            )
         elif isinstance(tensor, _Tensor):
             _tensor = tensor
             self.torch_param = None
             self.staged_tensor = None
-            Tensor.__init__(self, _tensor, requires_grad=False)
+            Tensor.__init__(self, _tensor, requires_grad=True)
         else:
             raise TypeError(
                 "tensor must be an ARK tensor or a torch.nn.Parameter"
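Taken together, the tensor.py changes make gradient tracking follow the torch tensor across the boundary: from_torch carries requires_grad onto the ARK tensor, to_torch marks the returned DLPack view as requiring grad, and the new requires_grad_ method maintains the class-level Tensor._tensor_grads registry keyed by the underlying tensor id. The round-trip sketch below is illustrative only; it assumes ark.Tensor is exported at package level and that a CUDA torch tensor is a valid from_torch input here, neither of which is shown in this diff.

import torch
import ark

# Hypothetical input; from_torch wraps its storage without copying.
t = torch.randn(2, 3, device="cuda", requires_grad=True)

a = ark.Tensor.from_torch(t)
# After this PR, requires_grad is carried over from the torch tensor, so the
# new Tensor.__init__ branch also records the tensor in the class registry.
assert a.requires_grad
assert a in ark.Tensor._tensor_grads.values()

# requires_grad_ toggles tracking in place, mirroring torch's API, and adds
# or removes the Tensor._tensor_grads entry accordingly.
a.requires_grad_(False)

# to_torch (not called here) would return a DLPack view of the ARK buffer and,
# with requires_grad set, mark that view as requiring grad on the torch side;
# it typically needs a launched Runtime so the device buffer exists.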