From 89ad9186e831a6ae765583242ca065da6ce3330e Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 27 May 2024 04:35:00 +0000 Subject: [PATCH 01/54] torch wip --- python/ark/data_type.py | 33 +++++++++++++++++++++++---------- python/ark/torch_mock.py | 11 +++++++++++ 2 files changed, 34 insertions(+), 10 deletions(-) create mode 100644 python/ark/torch_mock.py diff --git a/python/ark/data_type.py b/python/ark/data_type.py index fe95d0d88..de64c1d7d 100644 --- a/python/ark/data_type.py +++ b/python/ark/data_type.py @@ -3,26 +3,29 @@ import numpy from . import _ark_core - +try: + import torch +except ImportError: + from . import torch_mock as torch _REGISTRY_DATA_TYPE = { - "fp32": {"np": numpy.float32}, - "fp16": {"np": numpy.float16}, - "bf16": {"np": None}, - "int32": {"np": numpy.int32}, - "uint32": {"np": numpy.uint32}, - "int8": {"np": numpy.int8}, - "uint8": {"np": numpy.uint8}, - "byte": {"np": numpy.ubyte}, + "fp32": {"np": numpy.float32, "torch": torch.float32}, + "fp16": {"np": numpy.float16, "torch": torch.float16}, + "bf16": {"np": None, "torch": torch.bfloat16}, + "int32": {"np": numpy.int32, "torch": torch.int32}, + "uint32": {"np": numpy.uint32, "torch": None}, + "int8": {"np": numpy.int8, "torch": torch.int8}, + "uint8": {"np": numpy.uint8, "torch": torch.uint8}, + "byte": {"np": numpy.ubyte, "torch": torch.uint8}, } - class MetaDataType(type): def __new__(cls, name, bases, attrs): new_class = super().__new__(cls, name, bases, attrs) if name in _REGISTRY_DATA_TYPE: reg = _REGISTRY_DATA_TYPE[name] new_class.to_numpy = staticmethod(lambda: reg["np"]) + new_class.to_torch = staticmethod(lambda: reg["torch"]) new_class.ctype = staticmethod( lambda: getattr(_ark_core, name.upper()) ) @@ -104,6 +107,16 @@ def to_numpy() -> numpy.dtype: """ ... + @staticmethod + def to_torch() -> torch.dtype: + """ + Return the corresponding torch data type. + + Returns: + torch.dtype: The corresponding torch data type. + """ + ... + @staticmethod def ctype() -> _ark_core._DataType: """ diff --git a/python/ark/torch_mock.py b/python/ark/torch_mock.py new file mode 100644 index 000000000..e58a3eda8 --- /dev/null +++ b/python/ark/torch_mock.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +class dtype: ... +class float32: ... +class float16: ... +class bfloat16: ... +class int32: ... +class int8: ... +class uint8: ... +class ubyte: ... From ab1998ecef18116bd92f4ea91b14c69becc66655 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Sun, 26 May 2024 21:43:10 -0700 Subject: [PATCH 02/54] Update ut-cuda.yml --- .github/workflows/ut-cuda.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ut-cuda.yml b/.github/workflows/ut-cuda.yml index e938ca877..5a78818ff 100644 --- a/.github/workflows/ut-cuda.yml +++ b/.github/workflows/ut-cuda.yml @@ -7,6 +7,8 @@ on: pull_request: branches: - main + types: + - ready_for_review jobs: UnitTest: From ece4f553f62dc2da591321be3f7d5e34bff2c80d Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 27 May 2024 07:24:41 +0000 Subject: [PATCH 03/54] torch wip --- python/ark/data_type.py | 2 ++ python/ark/module.py | 33 ++++++++++++++++++++++++++++----- python/ark/tensor.py | 35 +++++++++++++++++++++++++++++++++++ python/ark/torch_mock.py | 18 ++++++++++++++++++ 4 files changed, 83 insertions(+), 5 deletions(-) diff --git a/python/ark/data_type.py b/python/ark/data_type.py index de64c1d7d..f5ccd9e5b 100644 --- a/python/ark/data_type.py +++ b/python/ark/data_type.py @@ -3,6 +3,7 @@ import numpy from . import _ark_core + try: import torch except ImportError: @@ -19,6 +20,7 @@ "byte": {"np": numpy.ubyte, "torch": torch.uint8}, } + class MetaDataType(type): def __new__(cls, name, bases, attrs): new_class = super().__new__(cls, name, bases, attrs) diff --git a/python/ark/module.py b/python/ark/module.py index 62b941281..459beeda6 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -3,9 +3,14 @@ import logging import numpy as np -from typing import Any, Dict +from typing import Any, Dict, Union from .tensor import Parameter +try: + import torch +except ImportError: + from . import torch_mock as torch + class Module: """ @@ -57,7 +62,9 @@ def params_dict(self, prefix="") -> Dict[str, Parameter]: return params_dict def load_state_dict( - self, state_dict: Dict[str, np.ndarray], prefix: str = "" + self, + state_dict: Dict[str, Union[np.ndarray, torch.Tensor]], + prefix: str = "", ): """ Loads a model from a state_dict and copy the parameters to the device GPU. @@ -68,20 +75,36 @@ def load_state_dict( all_keys = set(state_dict.keys()) pd = self.params_dict(prefix) for name, param in pd.items(): - param.from_numpy(state_dict[name]) + data = state_dict.get(name, None) + if isinstance(data, np.ndarray): + param.from_numpy(data) + elif isinstance(data, torch.Tensor): + param.from_torch(data) + else: + continue all_keys.remove(name) if all_keys: logging.warning( f"{len(all_keys)} unused parameter(s) in state_dict" ) - def state_dict(self, prefix: str = "") -> Dict[str, np.ndarray]: + def state_dict( + self, prefix: str = "", mode: str = "numpy" + ) -> Dict[str, Union[np.ndarray, torch.Tensor]]: """ Copies the parameters from the device GPU to the host and saves the model to a state_dict. Must be called after the executor is launched. """ - return {k: v.to_numpy() for k, v in self.params_dict(prefix).items()} + if mode == "numpy": + return { + k: v.to_numpy() for k, v in self.params_dict(prefix).items() + } + elif mode == "torch": + return { + k: v.to_torch() for k, v in self.params_dict(prefix).items() + } + raise ValueError(f"Unsupported mode: {mode}") def forward(self, *args: Any, **kwargs: Any) -> Any: ... diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 316d18566..625f82bce 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -8,6 +8,15 @@ from .data_type import DataType from .runtime import Runtime +try: + import torch + + _no_torch = False +except ImportError: + from . import torch_mock as torch + + _no_torch = True + NullTensor = _NullTensor @@ -89,6 +98,32 @@ def from_numpy(self, ndarray: np.ndarray) -> "Tensor": rt.executor.tensor_write(self._tensor, ndarray) return self + def to_torch(self, tensor: torch.Tensor = None) -> torch.Tensor: + """ """ + if _no_torch: + raise ImportError("torch is not available") + torch_type = self.dtype().to_torch() + if tensor is None: + return torch.from_numpy(self.to_numpy()) + elif tensor.shape != self.shape(): + raise ValueError("torch tensor shape does not match the tensor") + elif tensor.dtype != torch_type: + raise ValueError("torch tensor dtype does not match the tensor") + elif not tensor.is_contiguous(): + raise ValueError("torch tensor is not contiguous in memory") + elif tensor.numel() != self.nelems(): + raise ValueError("torch tensor size does not match the tensor") + tensor.copy_(torch.from_numpy(self.to_numpy())) + return tensor + + def from_torch(self, tensor: torch.Tensor) -> "Tensor": + """ """ + if _no_torch: + raise ImportError("torch is not available") + if tensor.is_cuda: + tensor = tensor.cpu() + return self.from_numpy(tensor.numpy()) + class Parameter(Tensor): """ diff --git a/python/ark/torch_mock.py b/python/ark/torch_mock.py index e58a3eda8..68333e431 100644 --- a/python/ark/torch_mock.py +++ b/python/ark/torch_mock.py @@ -1,11 +1,29 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. + class dtype: ... + + class float32: ... + + class float16: ... + + class bfloat16: ... + + class int32: ... + + class int8: ... + + class uint8: ... + + class ubyte: ... + + +class Tensor: ... From 952b7610c31288cc8851aa6466461f2ba7a2393f Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 27 May 2024 23:14:40 +0000 Subject: [PATCH 04/54] runtime module --- ark/api/planner.cpp | 4 +- examples/tutorial/torch_tutorial.py | 23 ++++++++ python/ark/__init__.py | 2 +- python/ark/data_type.py | 22 +++++++ python/ark/module.py | 71 +++++++++++++++++++++- python/ark/tensor.py | 91 ++++++++++++++++++++--------- 6 files changed, 181 insertions(+), 32 deletions(-) create mode 100644 examples/tutorial/torch_tutorial.py diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index ad5048c0e..5c9d09f2e 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -56,8 +56,8 @@ static void check_config_field(const ModelOpRef op, const Json &config, std::string DefaultPlanner::Impl::plan(bool pretty) const { const auto gpu_info = GpuManager::get_instance(gpu_id_)->info(); size_t num_sm = gpu_info.num_sm; - Json task_infos; - Json processor_groups; + Json task_infos = Json::array(); + Json processor_groups = Json::array(); size_t max_num_warps = 1; size_t max_num_processors = 1; size_t next_node_id = 0; diff --git a/examples/tutorial/torch_tutorial.py b/examples/tutorial/torch_tutorial.py new file mode 100644 index 000000000..5677d41cd --- /dev/null +++ b/examples/tutorial/torch_tutorial.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import ark +import torch + + +class ArkAddModule(ark.RuntimeModule): + def build_forward(self, x: ark.Tensor, y: ark.Tensor) -> ark.Tensor: + return ark.add(x, y) + +# ARK module for addition +module = ArkAddModule() + +# Define two torch arrays +x = torch.ones(64) * 2 +y = torch.ones(64) * 3 + +# Run the ARK module +z = module(x, y) + +# Print the result +print(z) diff --git a/python/ark/__init__.py b/python/ark/__init__.py index 92e9c39c3..2a4d164e4 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -38,7 +38,7 @@ def set_world_size(world_size): from .init import init from .tensor import Dims, Tensor, Parameter -from .module import Module +from .module import Module, RuntimeModule from .runtime import Runtime, DefaultPlanner from .serialize import save, load from .data_type import ( diff --git a/python/ark/data_type.py b/python/ark/data_type.py index f5ccd9e5b..8ab982106 100644 --- a/python/ark/data_type.py +++ b/python/ark/data_type.py @@ -64,6 +64,28 @@ def from_numpy(np_type: numpy.dtype) -> "DataType": f" to ark data type." ) + @staticmethod + def from_torch(torch_type: torch.dtype) -> "DataType": + """ + Return the corresponding ark data type. + + Parameters: + torch_type (torch.dtype): The torch data type. + + Returns: + DataType: The corresponding ark data type. + + Raises: + ValueError: If there is no defined conversion from torch data type to ark data type. + """ + for type_name, reg in _REGISTRY_DATA_TYPE.items(): + if reg["torch"] == torch_type: + return DataType.from_name(type_name) + raise ValueError( + f"Undefined conversion from torch data type {torch_type}" + f" to ark data type." + ) + @staticmethod def from_name(type_name: str) -> "DataType": """ diff --git a/python/ark/module.py b/python/ark/module.py index 459beeda6..b7919d2cd 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -3,14 +3,19 @@ import logging import numpy as np -from typing import Any, Dict, Union -from .tensor import Parameter +from typing import Any, Dict, List, Union +from .tensor import Tensor, Parameter +from .runtime import Runtime, DefaultPlanner try: import torch + + _no_torch = False except ImportError: from . import torch_mock as torch + _no_torch = True + class Module: """ @@ -109,3 +114,65 @@ def state_dict( def forward(self, *args: Any, **kwargs: Any) -> Any: ... def backward(self, *args: Any, **kwargs: Any) -> Any: ... + + def initialize(self): + for param in self.parameters.values(): + param.initialize() + for module in self.sub_modules.values(): + module.initialize() + + +def _recursive_ark_to_torch(object): + if isinstance(object, Tensor): + return object.to_torch() + if isinstance(object, dict): + return {k: _recursive_ark_to_torch(v) for k, v in object.items()} + if isinstance(object, list): + return [_recursive_ark_to_torch(v) for v in object] + return object + + +class RuntimeModule(Module): + def __init__(self): + if _no_torch: + raise ImportError("torch is not available") + super().__init__() + self.built_forward = False + self.built_backward = False + self.forward_input_tensor_args: List[Tensor] = [] + self.forward_input_tensor_kwargs: Dict[str, Tensor] = {} + self.forward_output = None + self.backward_tensor_args = [] + self.backward_tensor_kwargs = {} + + def build_forward(self, *args: Any, **kwargs: Any) -> Any: ... + + def build_backward(self, *args: Any, **kwargs: Any) -> Any: ... + + def forward(self, *args: Any, **kwargs: Any) -> Any: + if not self.built_forward: + for arg in args: + if isinstance(arg, torch.Tensor): + self.forward_input_tensor_args.append( + Tensor.from_torch(arg) + ) + for key, value in kwargs.items(): + if isinstance(value, torch.Tensor): + self.forward_input_tensor_kwargs[key] = Tensor.from_torch( + value + ) + self.forward_output = self.build_forward( + *self.forward_input_tensor_args, + **self.forward_input_tensor_kwargs, + ) + self.built_forward = True + + with Runtime.get_runtime() as rt: + rt.launch(plan=DefaultPlanner().plan()) + for arg in self.forward_input_tensor_args: + arg.initialize() + for value in self.forward_input_tensor_kwargs.values(): + value.initialize() + + rt.run() + return _recursive_ark_to_torch(self.forward_output) diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 625f82bce..f264bb440 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -2,11 +2,12 @@ # Licensed under the MIT license. import numpy as np -from typing import List +from typing import Callable, List, Union, Type from _ark_core import _Dims, _Tensor, _NullTensor from .data_type import DataType from .runtime import Runtime +from .model import Model try: import torch @@ -24,14 +25,19 @@ class Dims(_Dims): pass +Initializer = Type[Callable[[], Union[torch.Tensor, np.ndarray]]] + + class Tensor: - def __init__(self, _tensor: _Tensor): + def __init__(self, _tensor: _Tensor, initializer: Initializer = None): """ Initializes a new instance of the Tensor class. Args: _tensor (_ark_core._Tensor): The underlying _Tensor object. """ self._tensor = _tensor + self.initializer: Initializer = initializer + Model.get_model().add_tensor(self) def shape(self) -> List[int]: """ @@ -80,24 +86,6 @@ def to_numpy(self, ndarray: np.ndarray = None) -> np.ndarray: rt.executor.tensor_read(self._tensor, ndarray) return ndarray - def from_numpy(self, ndarray: np.ndarray) -> "Tensor": - """ - Copies the tensor from a host numpy array to the device. - """ - rt = Runtime.get_runtime() - if not rt.launched(): - raise RuntimeError( - "Tensor is not allocated yet. `Tensor.from_numpy()` is " - "usable only after you call `Runtime.launch()`." - ) - ndarray = ndarray.astype(self.dtype().to_numpy()) - if not ndarray.flags["C_CONTIGUOUS"]: - ndarray = np.ascontiguousarray(ndarray) - if ndarray.nbytes != self.nelems() * self.dtype().element_size(): - raise ValueError("ndarray size does not match the tensor") - rt.executor.tensor_write(self._tensor, ndarray) - return self - def to_torch(self, tensor: torch.Tensor = None) -> torch.Tensor: """ """ if _no_torch: @@ -116,13 +104,62 @@ def to_torch(self, tensor: torch.Tensor = None) -> torch.Tensor: tensor.copy_(torch.from_numpy(self.to_numpy())) return tensor - def from_torch(self, tensor: torch.Tensor) -> "Tensor": - """ """ - if _no_torch: - raise ImportError("torch is not available") - if tensor.is_cuda: - tensor = tensor.cpu() - return self.from_numpy(tensor.numpy()) + @staticmethod + def from_numpy(ndarray: np.ndarray): + return Tensor( + Model.get_model().tensor( + Dims(list(ndarray.shape)), + DataType.from_numpy(ndarray.dtype).ctype(), + Dims(), + Dims(), + Dims(), + "", + ), + lambda: ndarray, + ) + + @staticmethod + def from_torch(tensor: torch.Tensor): + return Tensor( + Model.get_model().tensor( + Dims(list(tensor.shape)), + DataType.from_torch(tensor.dtype).ctype(), + Dims(), + Dims(), + Dims(), + "", + ), + lambda: tensor, + ) + + def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": + """ + Copies the tensor from a host numpy array to the device. + """ + rt = Runtime.get_runtime() + if not rt.launched(): + raise RuntimeError( + "Tensor is not allocated yet. `Tensor.from_numpy()` is " + "usable only after you call `Runtime.launch()`." + ) + if isinstance(data, torch.Tensor): + data = data.cpu().numpy() + data = data.astype(self.dtype().to_numpy()) + if not data.flags["C_CONTIGUOUS"]: + data = np.ascontiguousarray(data) + if data.nbytes != self.nelems() * self.dtype().element_size(): + raise ValueError("data size does not match the tensor") + rt.executor.tensor_write(self._tensor, data) + return self + + def initialize(self) -> "Tensor": + """ + Initializes the tensor. + """ + if self.initializer is not None: + data = self.initializer() + self.copy(data) + return self class Parameter(Tensor): From a40926812f7b02f02e1e48a981c65e21c4dadfaa Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 27 May 2024 23:20:44 +0000 Subject: [PATCH 05/54] fix --- python/ark/tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ark/tensor.py b/python/ark/tensor.py index f264bb440..5168791a8 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -37,7 +37,6 @@ def __init__(self, _tensor: _Tensor, initializer: Initializer = None): """ self._tensor = _tensor self.initializer: Initializer = initializer - Model.get_model().add_tensor(self) def shape(self) -> List[int]: """ From 8e4622707b34cd4a71579bd65d7ba484e2424969 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 27 May 2024 23:52:16 +0000 Subject: [PATCH 06/54] fix --- ark/include/kernels/kernel_template.in | 5 ++++- examples/tutorial/torch_tutorial.py | 6 +++++- python/ark/module.py | 20 +++++++++++++------- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/ark/include/kernels/kernel_template.in b/ark/include/kernels/kernel_template.in index bc842ea4a..5bba320a5 100644 --- a/ark/include/kernels/kernel_template.in +++ b/ark/include/kernels/kernel_template.in @@ -59,9 +59,12 @@ void @NAME@(int *_iter) { sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); ark_loop_body(_buf, _i); } + if (threadIdx.x == 0) { + __threadfence_system(); + } + sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); if (threadIdx.x == 0 && blockIdx.x == 0) { atomicStoreRelaxed(_iter, 0); } - sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); } } diff --git a/examples/tutorial/torch_tutorial.py b/examples/tutorial/torch_tutorial.py index 5677d41cd..e9482a7cc 100644 --- a/examples/tutorial/torch_tutorial.py +++ b/examples/tutorial/torch_tutorial.py @@ -9,6 +9,7 @@ class ArkAddModule(ark.RuntimeModule): def build_forward(self, x: ark.Tensor, y: ark.Tensor) -> ark.Tensor: return ark.add(x, y) + # ARK module for addition module = ArkAddModule() @@ -19,5 +20,8 @@ def build_forward(self, x: ark.Tensor, y: ark.Tensor) -> ark.Tensor: # Run the ARK module z = module(x, y) +w = module(x, z) + # Print the result -print(z) +print(z) # 5 +print(w) # 7 diff --git a/python/ark/module.py b/python/ark/module.py index b7919d2cd..a266f522d 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -6,6 +6,8 @@ from typing import Any, Dict, List, Union from .tensor import Tensor, Parameter from .runtime import Runtime, DefaultPlanner +from .ops import tensor +from .data_type import DataType try: import torch @@ -154,12 +156,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: for arg in args: if isinstance(arg, torch.Tensor): self.forward_input_tensor_args.append( - Tensor.from_torch(arg) + tensor( + list(arg.shape), + DataType.from_torch(arg.dtype), + ) ) for key, value in kwargs.items(): if isinstance(value, torch.Tensor): - self.forward_input_tensor_kwargs[key] = Tensor.from_torch( - value + self.forward_input_tensor_kwargs[key] = tensor( + list(value.shape), + DataType.from_torch(value.dtype), ) self.forward_output = self.build_forward( *self.forward_input_tensor_args, @@ -169,10 +175,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: with Runtime.get_runtime() as rt: rt.launch(plan=DefaultPlanner().plan()) - for arg in self.forward_input_tensor_args: - arg.initialize() - for value in self.forward_input_tensor_kwargs.values(): - value.initialize() + for tns, arg in zip(self.forward_input_tensor_args, args): + tns.copy(arg) + for key, value in self.forward_input_tensor_kwargs.items(): + value.copy(kwargs[key]) rt.run() return _recursive_ark_to_torch(self.forward_output) From eee7ec2b4bb1cde335e99d780657c70e497542c9 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 28 May 2024 19:00:09 +0000 Subject: [PATCH 07/54] some fixes --- python/ark/module.py | 23 ++++++++++++++++------- python/ark/tensor.py | 28 +++++++++++++++++++++------- python/executor_py.cpp | 15 ++++++++++++++- 3 files changed, 51 insertions(+), 15 deletions(-) diff --git a/python/ark/module.py b/python/ark/module.py index a266f522d..faeeea40d 100644 --- a/python/ark/module.py +++ b/python/ark/module.py @@ -83,12 +83,9 @@ def load_state_dict( pd = self.params_dict(prefix) for name, param in pd.items(): data = state_dict.get(name, None) - if isinstance(data, np.ndarray): - param.from_numpy(data) - elif isinstance(data, torch.Tensor): - param.from_torch(data) - else: + if data is None: continue + param.copy(data) all_keys.remove(name) if all_keys: logging.warning( @@ -143,6 +140,8 @@ def __init__(self): self.built_backward = False self.forward_input_tensor_args: List[Tensor] = [] self.forward_input_tensor_kwargs: Dict[str, Tensor] = {} + self.forward_input_args = [] + self.forward_input_kwargs = {} self.forward_output = None self.backward_tensor_args = [] self.backward_tensor_kwargs = {} @@ -161,15 +160,25 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: DataType.from_torch(arg.dtype), ) ) + self.forward_input_args.append( + self.forward_input_tensor_args[-1] + ) + else: + self.forward_input_args.append(arg) for key, value in kwargs.items(): if isinstance(value, torch.Tensor): self.forward_input_tensor_kwargs[key] = tensor( list(value.shape), DataType.from_torch(value.dtype), ) + self.forward_input_kwargs[key] = ( + self.forward_input_tensor_kwargs[key] + ) + else: + self.forward_input_kwargs[key] = value self.forward_output = self.build_forward( - *self.forward_input_tensor_args, - **self.forward_input_tensor_kwargs, + *self.forward_input_args, + **self.forward_input_kwargs, ) self.built_forward = True diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 5168791a8..a567264d5 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -142,13 +142,27 @@ def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": "usable only after you call `Runtime.launch()`." ) if isinstance(data, torch.Tensor): - data = data.cpu().numpy() - data = data.astype(self.dtype().to_numpy()) - if not data.flags["C_CONTIGUOUS"]: - data = np.ascontiguousarray(data) - if data.nbytes != self.nelems() * self.dtype().element_size(): - raise ValueError("data size does not match the tensor") - rt.executor.tensor_write(self._tensor, data) + if data.dtype != self.dtype().to_torch(): + raise ValueError("data dtype does not match the tensor") + if not data.is_contiguous(): + data = data.contiguous() + if data.numel() != self.nelems(): + raise ValueError("data size does not match the tensor") + rt.executor.tensor_write( + self._tensor, + data.data_ptr(), + data.numel() * data.element_size(), + ) + elif isinstance(data, np.ndarray): + if data.dtype != self.dtype().to_numpy(): + raise ValueError("data dtype does not match the tensor") + if not data.flags["C_CONTIGUOUS"]: + data = np.ascontiguousarray(data) + if data.nbytes != self.nelems() * self.dtype().element_size(): + raise ValueError("data size does not match the tensor") + rt.executor.tensor_write(self._tensor, data) + else: + raise ValueError("data must be a numpy array or a torch tensor") return self def initialize(self) -> "Tensor": diff --git a/python/executor_py.cpp b/python/executor_py.cpp index dc2840329..13a81608e 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -17,6 +17,11 @@ static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, info.size * info.itemsize); } +static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, + size_t host_address, size_t bytes) { + exe->tensor_write(tensor, reinterpret_cast(host_address), bytes); +} + static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, py::buffer host_buffer) { py::buffer_info info = host_buffer.request(); @@ -39,5 +44,13 @@ void register_executor(py::module &m) { .def("destroy", &ark::Executor::destroy) .def("destroyed", &ark::Executor::destroyed) .def("tensor_read", &tensor_read, py::arg("tensor"), py::arg("data")) - .def("tensor_write", &tensor_write, py::arg("tensor"), py::arg("data")); + .def( + "tensor_write", + py::overload_cast( + &tensor_write), + py::arg("tensor"), py::arg("data")) + .def("tensor_write", + py::overload_cast(&tensor_write), + py::arg("tensor"), py::arg("address"), py::arg("bytes")); } From 87b9b0127de668f810847d04d4c2a08178439ee0 Mon Sep 17 00:00:00 2001 From: Noli Gerawork <86308445+naturalcandy@users.noreply.github.com> Date: Tue, 18 Jun 2024 11:20:45 -0400 Subject: [PATCH 08/54] Python API Multiple Runtime Support (#216) - Introduced support for multiple Runtime instances - Added utility functions for multi-runtime management - Ensured backward compatibility with existing usage patterns of Runtime - Added unit tests for multi-runtime functionality --------- Co-authored-by: noli --- ark/api/executor.cpp | 101 +++++++++++++++++++++ ark/include/ark/executor.hpp | 6 ++ python/ark/init.py | 5 +- python/ark/ops.py | 138 ++++++++++++++++++++++------ python/ark/runtime.py | 139 +++++++++++++++++++++++------ python/ark/tensor.py | 69 ++++++++++---- python/executor_py.cpp | 30 ++++++- python/unittest/test.py | 1 + python/unittest/test_conversion.py | 93 +++++++++++++++++++ python/unittest/test_runtime.py | 121 ++++++++++++++++++++++--- 10 files changed, 610 insertions(+), 93 deletions(-) create mode 100644 python/unittest/test_conversion.py diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 198d22e51..a0711bfe8 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -3,12 +3,15 @@ #include "ark/executor.hpp" +#include + #include #include #include #include #include +#include "ark/data_type.hpp" #include "ark/model.hpp" #include "ark/planner.hpp" #include "codegen.hpp" @@ -154,6 +157,8 @@ class Executor::Impl { void tensor_read(const Tensor tensor, void *data, size_t bytes) const; void tensor_write(const Tensor tensor, const void *data, size_t bytes) const; + DLDeviceType get_device_type() const; + DLManagedTensor *get_dl_tensor(const Tensor &tensor) const; private: void init_communicator(); @@ -783,6 +788,94 @@ void Executor::Impl::tensor_write(const Tensor tensor, const void *data, copy_stream_->sync(); } +DLDeviceType Executor::Impl::get_device_type() const { +#if defined(ARK_CUDA) + return kDLCUDA; +#elif defined(ARK_ROCM) + return kDLROCM; +#else + return kDLCPU; +#endif +} + +DLDataType get_dl_dtype(const DataType &ark_data_type) { + DLDataType dl_data_type; + dl_data_type.lanes = 1; + if (ark_data_type == FP32) { + dl_data_type.code = kDLFloat; + dl_data_type.bits = 32; + } else if (ark_data_type == FP16) { + dl_data_type.code = kDLFloat; + dl_data_type.bits = 16; + } else if (ark_data_type == BF16) { + dl_data_type.code = kDLBfloat; + dl_data_type.bits = 16; + } else if (ark_data_type == INT32) { + dl_data_type.code = kDLInt; + dl_data_type.bits = 32; + } else if (ark_data_type == UINT32) { + dl_data_type.code = kDLUInt; + dl_data_type.bits = 32; + } else if (ark_data_type == INT8) { + dl_data_type.code = kDLInt; + dl_data_type.bits = 8; + } else if (ark_data_type == UINT8) { + dl_data_type.code = kDLUInt; + dl_data_type.bits = 8; + } else if (ark_data_type == BYTE) { + dl_data_type.code = kDLUInt; + dl_data_type.bits = 8; + } else { + ERR(InvalidUsageError, "Unsupported data type"); + } + return dl_data_type; +} + +DLManagedTensor *Executor::Impl::get_dl_tensor(const Tensor &tensor) const { + DLTensor dl_tensor; + dl_tensor.data = + buffer_->ref(buffer_id_to_offset_.at(tensor.ref()->buffer()->id())); + size_t offset_in_elements = + tensor.offsets().is_no_dim() ? 0 : tensor.offsets().vector()[0]; + dl_tensor.byte_offset = offset_in_elements * tensor.data_type().bytes(); + dl_tensor.device.device_type = get_device_type(); + dl_tensor.device.device_id = static_cast(gpu_id_); + dl_tensor.ndim = static_cast(tensor.shape().ndims()); + dl_tensor.dtype = get_dl_dtype(tensor.data_type()); + + dl_tensor.shape = + tensor.shape().is_no_dim() ? nullptr : new int64_t[dl_tensor.ndim]; + dl_tensor.strides = + tensor.strides().is_no_dim() ? nullptr : new int64_t[dl_tensor.ndim]; + auto shape = tensor.shape(); + if (dl_tensor.shape) { + for (int i = 0; i < dl_tensor.ndim; ++i) { + dl_tensor.shape[i] = shape[i]; + } + } + if (dl_tensor.strides) { + dl_tensor.strides[dl_tensor.ndim - 1] = 1; + for (int i = dl_tensor.ndim - 2; i >= 0; --i) { + dl_tensor.strides[i] = + dl_tensor.shape[i + 1] * dl_tensor.strides[i + 1]; + } + } + DLManagedTensor *dl_managed_tensor = new DLManagedTensor(); + dl_managed_tensor->dl_tensor = dl_tensor; + dl_managed_tensor->manager_ctx = nullptr; + dl_managed_tensor->deleter = [](DLManagedTensor *self) { + if (self->dl_tensor.shape) { + delete[] self->dl_tensor.shape; + self->dl_tensor.shape = nullptr; + } + if (self->dl_tensor.strides) { + delete[] self->dl_tensor.strides; + self->dl_tensor.strides = nullptr; + } + }; + return dl_managed_tensor; +} + Executor::Executor(int rank, int world_size, int gpu_id, const std::string &name, const std::string &plan) : impl_(std::make_unique(rank, world_size, gpu_id, name, @@ -818,6 +911,14 @@ void Executor::tensor_write(const Tensor tensor, const void *data, impl_->tensor_write(tensor, data, bytes); } +DLDeviceType Executor::get_device_type() const { + return impl_->get_device_type(); +} + +DLManagedTensor *Executor::get_dl_tensor(const Tensor &tensor) const { + return impl_->get_dl_tensor(tensor); +} + DefaultExecutor::DefaultExecutor(const Model &model, int gpu_id, const std::string &name) : Executor( diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 4682af7d0..54c49cd29 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -4,6 +4,8 @@ #ifndef ARK_EXECUTOR_HPP #define ARK_EXECUTOR_HPP +#include + #include #include #include @@ -62,6 +64,10 @@ class Executor { void tensor_write(const Tensor tensor, const void *data, size_t bytes) const; + DLManagedTensor *get_dl_tensor(const Tensor &tensor) const; + + DLDeviceType get_device_type() const; + private: class Impl; std::unique_ptr impl_; diff --git a/python/ark/init.py b/python/ark/init.py index be71e8e02..dbf7c1569 100644 --- a/python/ark/init.py +++ b/python/ark/init.py @@ -9,7 +9,6 @@ def init(): """Initializes ARK.""" Model.reset() - if _RuntimeState.executor is not None: - if not _RuntimeState.executor.destroyed(): - _RuntimeState.executor.destroy() + if _RuntimeState.runtime: + _RuntimeState.delete_all() _ark_core.init() diff --git a/python/ark/ops.py b/python/ark/ops.py index bc1c3ed13..86b021aef 100644 --- a/python/ark/ops.py +++ b/python/ark/ops.py @@ -59,6 +59,8 @@ def add( tensor_add = ark.add(tensor1, tensor2) """ if isinstance(input, Tensor) and isinstance(other, Tensor): + if input.runtime_id != other.runtime_id: + raise ValueError("Tensors must be on the same runtime") a = input._tensor b = other._tensor elif isinstance(input, Tensor): @@ -75,7 +77,9 @@ def add( ) if output is not NullTensor: output = output._tensor - return Tensor(Model.get_model().add(a, b, output, name)) + return Tensor( + Model.get_model().add(a, b, output, name), runtime_id=input.runtime_id + ) def cast( @@ -88,7 +92,8 @@ def cast( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().cast(input._tensor, dtype.ctype(), output, name) + Model.get_model().cast(input._tensor, dtype.ctype(), output, name), + runtime_id=input.runtime_id, ) @@ -97,10 +102,12 @@ def constant( shape: Iterable[int], dtype: DataType = fp32, name: str = "constant", + runtime_id: int = -1, ) -> Tensor: """Constant.""" return Tensor( - Model.get_model().constant(value, Dims(shape), dtype.ctype(), name) + Model.get_model().constant(value, Dims(shape), dtype.ctype(), name), + runtime_id=runtime_id, ) @@ -112,7 +119,10 @@ def copy( output = output._tensor if isinstance(input, Tensor): intput = intput._tensor - return Tensor(Model.get_model().copy(intput, output, name)) + return Tensor( + Model.get_model().copy(intput, output, name), + runtime_id=input.runtime_id, + ) def div( @@ -130,8 +140,13 @@ def div( if output is not NullTensor: output = output._tensor if isinstance(other, Tensor): + if input.runtime_id != other.runtime_id: + raise ValueError("Tensors must be on the same runtime") other = other._tensor - return Tensor(Model.get_model().div(input._tensor, other, output, name)) + return Tensor( + Model.get_model().div(input._tensor, other, output, name), + runtime_id=input.runtime_id, + ) def embedding( @@ -141,10 +156,15 @@ def embedding( name: str = "embedding", ) -> Tensor: """Embedding layer.""" + if input.runtime_id != weight.runtime_id: + raise ValueError("Tensors must be on the same runtime") if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().embedding(input._tensor, weight._tensor, output, name) + Model.get_model().embedding( + input._tensor, weight._tensor, output, name + ), + runtime_id=input.runtime_id, ) @@ -158,7 +178,10 @@ def exp( """ if output is not NullTensor: output = output._tensor - return Tensor(Model.get_model().exp(input._tensor, output, name)) + return Tensor( + Model.get_model().exp(input._tensor, output, name), + runtime_id=input.runtime_id, + ) def gelu( @@ -174,7 +197,10 @@ def gelu( """ if output is not NullTensor: output = output._tensor - return Tensor(Model.get_model().gelu(input._tensor, output, name)) + return Tensor( + Model.get_model().gelu(input._tensor, output, name), + runtime_id=input.runtime_id, + ) def identity( @@ -189,8 +215,13 @@ def identity( for dep in deps: if not isinstance(dep, Tensor): raise TypeError("All dependencies should be a tensor") + if input.runtime_id != dep.runtime_id: + raise ValueError("All tensors must be on the same runtime") dep_tensors.append(dep._tensor) - return Tensor(Model.get_model().identity(input._tensor, dep_tensors, name)) + return Tensor( + Model.get_model().identity(input._tensor, dep_tensors, name), + runtime_id=input.runtime_id, + ) def matmul( @@ -210,6 +241,8 @@ def matmul( Usage: tensor_matmul = ark.matmul(tensor1, tensor2) """ + if input.runtime_id != other.runtime_id: + raise ValueError("Tensors must be on the same runtime") if output is not NullTensor: output = output._tensor return Tensor( @@ -220,7 +253,8 @@ def matmul( transpose_input, transpose_other, name, - ) + ), + runtime_id=input.runtime_id, ) @@ -239,8 +273,13 @@ def mul( if output is not NullTensor: output = output._tensor if isinstance(other, Tensor): + if input.runtime_id != other.runtime_id: + raise ValueError("Tensors must be on the same runtime") other = other._tensor - return Tensor(Model.get_model().mul(input._tensor, other, output, name)) + return Tensor( + Model.get_model().mul(input._tensor, other, output, name), + runtime_id=input.runtime_id, + ) def noop(input: Tensor, name: str = "noop"): @@ -268,7 +307,8 @@ def reduce_max( return Tensor( Model.get_model().reduce_max( input._tensor, axis, keepdims, output, name - ) + ), + runtime_id=input.runtime_id, ) @@ -290,7 +330,8 @@ def reduce_mean( return Tensor( Model.get_model().reduce_mean( input._tensor, axis, keepdims, output, name - ) + ), + runtime_id=input.runtime_id, ) @@ -314,7 +355,8 @@ def reduce_sum( return Tensor( Model.get_model().reduce_sum( input._tensor, axis, keepdims, output, name - ) + ), + runtime_id=input.runtime_id, ) @@ -329,7 +371,10 @@ def relu( """ if output is not NullTensor: output = output._tensor - return Tensor(Model.get_model().relu(input._tensor, output, name)) + return Tensor( + Model.get_model().relu(input._tensor, output, name), + runtime_id=input.runtime_id, + ) def reshape( @@ -357,7 +402,8 @@ def reshape( if len(shape) > 4: raise ValueError("Only support tensors with up to 4 dimensions") return Tensor( - Model.get_model().reshape(input._tensor, Dims(shape), allowzero, name) + Model.get_model().reshape(input._tensor, Dims(shape), allowzero, name), + runtime_id=input.runtime_id, ) @@ -374,8 +420,11 @@ def rope( """ if output is not NullTensor: output = output._tensor + if input.runtime_id != other.runtime_id: + raise ValueError("Tensors must be on the same runtime") return Tensor( - Model.get_model().rope(input._tensor, other._tensor, output, name) + Model.get_model().rope(input._tensor, other._tensor, output, name), + runtime_id=input.runtime_id, ) @@ -389,7 +438,10 @@ def rsqrt( """ if output is not NullTensor: output = output._tensor - return Tensor(Model.get_model().rsqrt(input._tensor, output, name)) + return Tensor( + Model.get_model().rsqrt(input._tensor, output, name), + runtime_id=input.runtime_id, + ) def sharding( @@ -407,7 +459,9 @@ def sharding( _tensor_list = Model.get_model().sharding( input._tensor, axis, dim_per_shard, name ) - return [Tensor(_tensor) for _tensor in _tensor_list] + return [ + Tensor(_tensor, runtime_id=input.runtime_id) for _tensor in _tensor_list + ] def sigmoid( @@ -421,7 +475,10 @@ def sigmoid( """ if output is not NullTensor: output = output._tensor - return Tensor(Model.get_model().sigmoid(input._tensor, output, name)) + return Tensor( + Model.get_model().sigmoid(input._tensor, output, name), + runtime_id=input.runtime_id, + ) def sqrt( @@ -434,7 +491,10 @@ def sqrt( """ if output is not NullTensor: output = output._tensor - return Tensor(Model.get_model().sqrt(input._tensor, output, name)) + return Tensor( + Model.get_model().sqrt(input._tensor, output, name), + runtime_id=input.runtime_id, + ) def sub( @@ -452,8 +512,13 @@ def sub( if output is not NullTensor: output = output._tensor if isinstance(other, Tensor): + if input.runtime_id != other.runtime_id: + raise ValueError("Tensors must be on the same runtime") other = other._tensor - return Tensor(Model.get_model().sub(input._tensor, other, output, name)) + return Tensor( + Model.get_model().sub(input._tensor, other, output, name), + runtime_id=input.runtime_id, + ) def tensor( @@ -463,6 +528,7 @@ def tensor( offsets: Iterable[int] = [], padded_shape: Iterable[int] = [], name: str = "", + runtime_id: int = -1, ) -> Tensor: """ Construct a tensor with given shape and data type. @@ -470,7 +536,10 @@ def tensor( tensor = ark.tensor([1, 2, 3, 4], dtype=ark.fp32) tensor = ark.tensor([1, 2], dtype=ark.fp16) """ - return Tensor(_tensor(shape, dtype, strides, offsets, padded_shape, name)) + return Tensor( + _tensor(shape, dtype, strides, offsets, padded_shape, name), + runtime_id=runtime_id, + ) def transpose( @@ -496,7 +565,8 @@ def transpose( if len(perm) > 4: raise ValueError("Only support perm up to 4 dimensions") return Tensor( - Model.get_model().transpose(input._tensor, perm, output, name) + Model.get_model().transpose(input._tensor, perm, output, name), + runtime_id=input.runtime_id, ) @@ -515,11 +585,15 @@ def mean( def ones( - shape: Iterable[int], dtype: DataType = fp32, name: str = "ones" + shape: Iterable[int], + dtype: DataType = fp32, + name: str = "ones", + runtime_id: int = -1, ) -> Tensor: """Ones.""" return Tensor( - Model.get_model().constant(1, Dims(shape), dtype.ctype(), name) + Model.get_model().constant(1, Dims(shape), dtype.ctype(), name), + runtime_id=runtime_id, ) @@ -530,12 +604,14 @@ def parameter( offsets: Iterable[int] = [], padded_shape: Iterable[int] = [], name: str = "", + runtime_id: int = -1, ) -> Parameter: """ Construct a parameter with given shape and data type. """ return Parameter( - _tensor(shape, dtype, strides, offsets, padded_shape, name) + _tensor(shape, dtype, strides, offsets, padded_shape, name), + runtime_id=runtime_id, ) @@ -569,11 +645,15 @@ def layernorm( def zeros( - shape: Iterable[int], dtype: DataType = fp32, name: str = "zeros" + shape: Iterable[int], + dtype: DataType = fp32, + name: str = "zeros", + runtime_id: int = -1, ) -> Tensor: """Zeros.""" return Tensor( - Model.get_model().constant(0, Dims(shape), dtype.ctype(), name) + Model.get_model().constant(0, Dims(shape), dtype.ctype(), name), + runtime_id=runtime_id, ) diff --git a/python/ark/runtime.py b/python/ark/runtime.py index 7480ce7da..798eaf9d5 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -3,7 +3,7 @@ import logging from enum import Enum -from typing import Callable +from typing import Callable, Dict, List from _ark_core import _Executor, _DefaultPlanner from .model import Model @@ -14,8 +14,36 @@ class _RuntimeState: The _RuntimeState class is used to store the state of the model. """ - runtime = None - executor = None + runtime: Dict[int, "Runtime"] = {} + + @staticmethod + def reset_all(): + """ + Resets all runtimes. + """ + runtime_ids = list(_RuntimeState.runtime.keys()) + for runtime_id in runtime_ids: + _RuntimeState.runtime[runtime_id].reset() + + @staticmethod + def delete_all(): + """ + Deletes all runtimes. + """ + runtime_ids = list(_RuntimeState.runtime.keys()) + for runtime_id in runtime_ids: + _RuntimeState.runtime[runtime_id].reset(delete=True) + + @staticmethod + def print_runtime_states(): + """ + Print runtimes and their corresponding states. + """ + print(f"{'Runtime ID':<12} | {'Status':<20}") + print(f"{'-'*12} | {'-'*20}") + for runtime_id, runtime in _RuntimeState.runtime.items(): + runtime_id = "-1(Default)" if runtime_id == -1 else runtime_id + print(f"{runtime_id:<12} | {runtime.state:<20}") class DefaultPlanner(_DefaultPlanner): @@ -61,22 +89,48 @@ class State(Enum): LaunchedNotRunning = 1 Running = 2 + def __init__(self, runtime_id: int = -1): + self.runtime_id = runtime_id + self.executor: Executor = None + self.state: Runtime.State = Runtime.State.Init + _RuntimeState.runtime[runtime_id] = self + + def get_state(self) -> "Runtime.State": + """ + Get the runtime state. + """ + return self.state + @staticmethod - def get_runtime() -> "Runtime": + def exists(runtime_id: int) -> bool: """ - Get the runtime. + Check if a runtime exists with the given ID. """ - if _RuntimeState.runtime is None: - _RuntimeState.runtime = Runtime() - return _RuntimeState.runtime + return runtime_id in _RuntimeState.runtime - def __init__(self): - self.executor: Executor = None - self.state: Runtime.State = Runtime.State.Init - _RuntimeState.runtime = self + @staticmethod + def get_all_ids() -> List[int]: + """ + Get a list of all existing runtime IDs. + """ + return list(_RuntimeState.runtime.keys()) - def __del__(self): - self.reset() + @staticmethod + def get_runtime(runtime_id=-1) -> "Runtime": + """ + Get the runtime by ID. If runtime_id is not provided, use a default ID of -1. + If the runtime does not exist, create a new runtime with the given ID. + """ + if runtime_id not in _RuntimeState.runtime: + _RuntimeState.runtime[runtime_id] = Runtime(runtime_id) + return _RuntimeState.runtime[runtime_id] + + @staticmethod + def see_runtime_statuses() -> "Dict[int, Runtime]": + """ + Returns the runtime dictionary containing all of the runtimes. + """ + return _RuntimeState.runtime def __enter__(self): return self @@ -113,7 +167,9 @@ def launch( initialized. The executor will compile the cuda kernels and launch the ARK runtime. """ if self.launched(): - logging.warn("Runtime is already launched, skip launching") + logging.warn( + f"Runtime {self.runtime_id} is already launched, skip launching" + ) return if not plan: if not plan_path: @@ -124,19 +180,19 @@ def launch( # If the RuntimeState is init, we need to create a new executor and # compile the kernels if self.state == Runtime.State.Init: - if _RuntimeState.executor is not None: - if not _RuntimeState.executor.destroyed(): - logging.warn("Destroying an old executor") - _RuntimeState.executor.destroy() - - _RuntimeState.executor = Executor( + if self.executor is not None: + if not self.executor.destroyed(): + logging.warn( + f"Runtime {self.runtime_id}, has already been launched. Destroying the old executor" + ) + self.executor.destroy() + self.executor = Executor( rank, world_size, gpu_id, "ArkRuntime", plan, ) - self.executor = _RuntimeState.executor self.executor.compile() self.executor.launch() self.state = Runtime.State.LaunchedNotRunning @@ -146,8 +202,8 @@ def run(self, iter=1, non_blocking=False): Run the ARK program for iter iterations and wait for the kernel to finish. """ if self.state != Runtime.State.LaunchedNotRunning: - logging.error("ARK runtime is not launched") - raise RuntimeError("ARK runtime is not launched") + logging.error(f"ARK runtime {self.runtime_id} is not launched") + raise RuntimeError(f"ARK runtime {self.runtime_id} is not launched") self.state = Runtime.State.Running self.executor.run(iter) if not non_blocking: @@ -158,7 +214,9 @@ def wait(self): Wait for the kernel to finish. """ if self.state != Runtime.State.Running: - logging.warn("ARK runtime is not running, skip waiting") + logging.warn( + f"ARK runtime {self.runtime_id} is not running, skip waiting" + ) return self.executor.wait() self.state = Runtime.State.LaunchedNotRunning @@ -169,15 +227,17 @@ def stop(self) -> float: Once this is called, we need to call `launch()` again to run the model again. """ if not self.launched(): - logging.warn("ARK runtime is never launched, skip stopping") + logging.warn( + f"ARK runtime {self.runtime_id} is never launched, skip stopping" + ) return elapsed = self.executor.stop() self.state = Runtime.State.LaunchedNotRunning return elapsed - def reset(self): + def reset(self, delete=False): """ - Reset the runtime. + Reset the runtime. If delete is True, delete the runtime associated with the runtime_id. """ if self.launched(): self.stop() @@ -186,3 +246,26 @@ def reset(self): self.executor.destroy() self.executor = None self.state = Runtime.State.Init + if delete: + del _RuntimeState.runtime[self.runtime_id] + + @staticmethod + def reset_all_runtimes(): + """ + Reset all runtimes. + """ + _RuntimeState.reset_all() + + @staticmethod + def delete_all_runtimes(): + """ + Delete all runtimes. + """ + _RuntimeState.delete_all() + + @staticmethod + def print_runtime_states(): + """ + Print runtimes and their corresponding states. + """ + _RuntimeState.print_runtime_states() diff --git a/python/ark/tensor.py b/python/ark/tensor.py index a567264d5..00e266929 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -29,14 +29,22 @@ class Dims(_Dims): class Tensor: - def __init__(self, _tensor: _Tensor, initializer: Initializer = None): + def __init__( + self, + _tensor: _Tensor, + initializer: Initializer = None, + runtime_id: int = -1, + ): """ Initializes a new instance of the Tensor class. Args: _tensor (_ark_core._Tensor): The underlying _Tensor object. + intializer (Initializer): The initializer for the Tensor. + runtime_id (int): The ID of the Runtime to use. Defaults to -1, which is the default Runtime. """ self._tensor = _tensor self.initializer: Initializer = initializer + self.runtime_id = runtime_id def shape(self) -> List[int]: """ @@ -69,7 +77,7 @@ def to_numpy(self, ndarray: np.ndarray = None) -> np.ndarray: an empty numpy array without the data buffer will be returned. """ np_type = self.dtype().to_numpy() - rt = Runtime.get_runtime() + rt = Runtime.get_runtime(self.runtime_id) if not rt.launched(): return np.ndarray(self.shape(), dtype=np_type, buffer=None) if ndarray is None: @@ -85,7 +93,9 @@ def to_numpy(self, ndarray: np.ndarray = None) -> np.ndarray: rt.executor.tensor_read(self._tensor, ndarray) return ndarray - def to_torch(self, tensor: torch.Tensor = None) -> torch.Tensor: + def to_torch( + self, tensor: torch.Tensor = None, runtime_id: int = -1 + ) -> torch.Tensor: """ """ if _no_torch: raise ImportError("torch is not available") @@ -100,22 +110,42 @@ def to_torch(self, tensor: torch.Tensor = None) -> torch.Tensor: raise ValueError("torch tensor is not contiguous in memory") elif tensor.numel() != self.nelems(): raise ValueError("torch tensor size does not match the tensor") - tensor.copy_(torch.from_numpy(self.to_numpy())) + tensor.copy_(torch.from_numpy(self.to_numpy(self.runtime_id))) return tensor - @staticmethod - def from_numpy(ndarray: np.ndarray): - return Tensor( - Model.get_model().tensor( - Dims(list(ndarray.shape)), - DataType.from_numpy(ndarray.dtype).ctype(), - Dims(), - Dims(), - Dims(), - "", - ), - lambda: ndarray, - ) + def get_torch_view(self) -> torch.Tensor: + """ + Returns a torch tensor that shares the same memory with the device tensor. + """ + if _no_torch: + raise ImportError("torch is not available") + rt = Runtime.get_runtime(self.runtime_id) + if not rt.launched(): + raise RuntimeError( + "Tensor is not allocated yet. `Tensor.get_torch_view()` is " + "usable only after you call `Runtime.launch()`." + ) + dl_tensor = rt.executor.get_dl_tensor(self._tensor) + torch_view = torch.utils.dlpack.from_dlpack(dl_tensor) + return torch_view + + def from_numpy(self, ndarray: np.ndarray) -> "Tensor": + """ + Copies the tensor from a host numpy array to the device. + """ + rt = Runtime.get_runtime(self.runtime_id) + if not rt.launched(): + raise RuntimeError( + "Tensor is not allocated yet. `Tensor.from_numpy()` is " + "usable only after you call `Runtime.launch()`." + ) + ndarray = ndarray.astype(self.dtype().to_numpy()) + if not ndarray.flags["C_CONTIGUOUS"]: + ndarray = np.ascontiguousarray(ndarray) + if ndarray.nbytes != self.nelems() * self.dtype().element_size(): + raise ValueError("ndarray size does not match the tensor") + rt.executor.tensor_write(self._tensor, ndarray) + return self @staticmethod def from_torch(tensor: torch.Tensor): @@ -135,7 +165,7 @@ def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": """ Copies the tensor from a host numpy array to the device. """ - rt = Runtime.get_runtime() + rt = Runtime.get_runtime(self.runtime_id) if not rt.launched(): raise RuntimeError( "Tensor is not allocated yet. `Tensor.from_numpy()` is " @@ -180,8 +210,9 @@ class Parameter(Tensor): A tensor as a parameter. """ - def __init__(self, _tensor: _Tensor): + def __init__(self, _tensor: _Tensor, runtime_id: int = -1): """ Initializes a new instance of the Parameter class. """ super().__init__(_tensor) + self.runtime_id = runtime_id diff --git a/python/executor_py.cpp b/python/executor_py.cpp index 13a81608e..59bee5a9b 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +#include #include #include #include #include #include - +#include namespace py = pybind11; static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, @@ -29,6 +30,29 @@ static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, info.size * info.itemsize); } +DLManagedTensor *to_dlpack(ark::Executor &exe, const ark::Tensor &tensor) { + DLManagedTensor *dl_tensor = exe.get_dl_tensor(tensor); + return dl_tensor; +} + +void free_capsule(PyObject *capsule) { + const char *name = PyCapsule_GetName(capsule); + auto *dl_managed_tensor = + static_cast(PyCapsule_GetPointer(capsule, name)); + if (dl_managed_tensor) { + dl_managed_tensor->deleter(dl_managed_tensor); + dl_managed_tensor = nullptr; + } +} + +py::capsule to_dlpack_capsule(ark::Executor &self, const ark::Tensor &tensor) { + DLManagedTensor *dl_managed_tensor = to_dlpack(self, tensor); + const char *capsule_name = "dltensor"; + PyObject *dl_capsule = PyCapsule_New(static_cast(dl_managed_tensor), + capsule_name, free_capsule); + return py::reinterpret_steal(dl_capsule); +} + void register_executor(py::module &m) { py::class_(m, "_Executor") .def( @@ -52,5 +76,7 @@ void register_executor(py::module &m) { .def("tensor_write", py::overload_cast(&tensor_write), - py::arg("tensor"), py::arg("address"), py::arg("bytes")); + py::arg("tensor"), py::arg("address"), py::arg("bytes")) + .def("get_dl_tensor", &to_dlpack_capsule), + py::arg("tensor"); } diff --git a/python/unittest/test.py b/python/unittest/test.py index f6f9b97af..e43ff11e2 100644 --- a/python/unittest/test.py +++ b/python/unittest/test.py @@ -9,3 +9,4 @@ from test_model import * from test_runtime import * +from test_conversion import * diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py new file mode 100644 index 000000000..8f941a033 --- /dev/null +++ b/python/unittest/test_conversion.py @@ -0,0 +1,93 @@ +import torch +import numpy as np +import ark + + +def initialize_tensor(dimensions, dtype): + tensor = ark.tensor(dimensions, dtype) + tensor_host = np.random.rand(*dimensions).astype(dtype.to_numpy()) + return tensor, tensor_host + + +# Test function to validate the integrity of the PyTorch view of the ARK tensor, +# including its data and attributes such as shape and data type. +def test_values_fixed_dims(num_dims: int, size: int, dtype: ark.DataType): + ark.init() + dimensions = [size] * num_dims + + input_tensor, input_tensor_host = initialize_tensor(dimensions, dtype) + other_tensor, other_tensor_host = initialize_tensor(dimensions, dtype) + output_tensor = ark.add(input_tensor, other_tensor) + + runtime = ark.Runtime() + runtime.launch() + + input_tensor.from_numpy(input_tensor_host) + other_tensor.from_numpy(other_tensor_host) + + input_view = input_tensor.get_torch_view() + other_view = other_tensor.get_torch_view() + output_view = output_tensor.get_torch_view() + + runtime.run() + + input_view_numpy = input_view.cpu().numpy() + other_view_numpy = other_view.cpu().numpy() + output_view_numpy = output_view.cpu().numpy() + + output_tensor_host = output_tensor.to_numpy() + + runtime.stop() + runtime.delete_all_runtimes() + + assert np.allclose(input_tensor_host, input_view_numpy) + assert np.allclose(other_tensor_host, other_view_numpy) + assert np.allclose(output_tensor_host, output_view_numpy) + + +# Function to check if there is a difference between two arrays at a specific index +def check_diff(input_tensor_host, input_view_numpy, value, index): + mask = np.ones(input_tensor_host.shape, dtype=bool) + mask[index] = False + if not np.allclose(input_tensor_host[mask], input_view_numpy[mask]): + print("Difference found at index: ", index) + return False + if input_view_numpy[index] != value: + print(input_view_numpy[index], value) + return False + return True + + +# Test function to check if changes to the torch views are reflected in the original tensors +def test_aliasing(dtype: ark.DataType): + ark.init() + dimensions = [4, 4] + input_tensor, input_tensor_host = initialize_tensor(dimensions, dtype) + other_tensor, other_tensor_host = initialize_tensor(dimensions, dtype) + output_tensor = ark.mul(input_tensor, other_tensor) + runtime = ark.Runtime() + runtime.launch() + input_tensor.from_numpy(input_tensor_host) + other_tensor.from_numpy(other_tensor_host) + + input_view = input_tensor.get_torch_view() + other_view = other_tensor.get_torch_view() + output_view = output_tensor.get_torch_view() + # make changes to the views + input_view[1, 1] = 20 + other_view[0, 0] = 30 + runtime.run() + output_view[3, 0] = 40 + + output_tensor_host = output_tensor.to_numpy() + input_view_numpy = input_view.cpu().numpy() + other_view_numpy = other_view.cpu().numpy() + output_view_numpy = output_view.cpu().numpy() + # Check if changes to the views are reflected in the original tensors + print(input_view_numpy) + assert check_diff(input_tensor_host, input_view_numpy, 20, (1, 1)) + assert check_diff(other_tensor_host, other_view_numpy, 30, (0, 0)) + assert check_diff(output_tensor_host, output_view_numpy, 40, (3, 0)) + + runtime.stop() + runtime.reset() diff --git a/python/unittest/test_runtime.py b/python/unittest/test_runtime.py index bd9098fe8..fd34bb96b 100644 --- a/python/unittest/test_runtime.py +++ b/python/unittest/test_runtime.py @@ -4,21 +4,20 @@ import ark import json +empty_plan = json.dumps( + { + "Rank": 0, + "WorldSize": 1, + "NumProcessors": 1, + "NumWarpsPerProcessor": 1, + "TaskInfos": [], + "ProcessorGroups": [], + } +) + def test_runtime_relaunch(): ark.init() - - empty_plan = json.dumps( - { - "Rank": 0, - "WorldSize": 1, - "NumProcessors": 1, - "NumWarpsPerProcessor": 1, - "TaskInfos": [], - "ProcessorGroups": [], - } - ) - with ark.Runtime.get_runtime() as rt: assert rt.launched() == False rt.launch(plan=empty_plan) @@ -28,3 +27,101 @@ def test_runtime_relaunch(): assert rt.launched() == False rt.launch(plan=empty_plan) assert rt.launched() == True + + +def test_multiple_runtime_launch(): + ark.init() + num_runtimes = 5 + for i in range(num_runtimes): + rt = ark.Runtime.get_runtime(i) + assert rt.launched() == False + rt.launch(gpu_id=i, plan=empty_plan) + assert rt.launched() == True + for i in range(num_runtimes): + rt = ark.Runtime.get_runtime(i) + assert rt.launched() == True + ark.Runtime.delete_all_runtimes() + + +def test_stop_runtime(): + ark.init() + rt1 = ark.Runtime.get_runtime(1) + rt1.launch(plan=empty_plan, gpu_id=1) + rt2 = ark.Runtime.get_runtime(2) + rt2.launch(plan=empty_plan, gpu_id=2) + rt1.stop() + rt1.reset() + assert rt1.state == ark.Runtime.State.Init + assert rt2.state == ark.Runtime.State.LaunchedNotRunning + ark.Runtime.delete_all_runtimes() + + +def test_reset_runtime(): + ark.init() + rt1 = ark.Runtime.get_runtime(0) + rt1.launch(plan=empty_plan, gpu_id=1) + rt2 = ark.Runtime.get_runtime(1) + rt2.launch(plan=empty_plan, gpu_id=2) + rt1.reset() + assert rt1.launched() == False + assert rt2.launched() == True + rt1.launch(plan=empty_plan) + assert rt1.launched() == True + ark.Runtime.delete_all_runtimes() + + +def test_multiple_runtimes_complex(): + ark.init() + num_runtimes = 3 + runtime_list = [ark.Runtime.get_runtime(i) for i in range(num_runtimes)] + default_runtime = ark.Runtime.get_runtime() + runtime_list.append(default_runtime) + for i, rt in enumerate(runtime_list): + rt.launch(plan=empty_plan, gpu_id=i) + assert rt.launched() == True + runtime_list[0].stop() + assert runtime_list[0].state == ark.Runtime.State.LaunchedNotRunning + for rt in runtime_list[1:]: + assert rt.launched() == True + runtime_list[1].reset() + assert runtime_list[1].state == ark.Runtime.State.Init + assert runtime_list[0].state == ark.Runtime.State.LaunchedNotRunning + assert runtime_list[2].state == ark.Runtime.State.LaunchedNotRunning + runtime_list[1].launch(plan=empty_plan, gpu_id=1) + for rt in runtime_list: + assert rt.launched() == True + ark.Runtime.delete_all_runtimes() + + +def test_runtime_state_after_reset(): + ark.init() + rt = ark.Runtime.get_runtime() + rt.launch(plan=empty_plan) + rt.reset() + assert rt.launched() == False + assert rt.running() == False + ark.Runtime.delete_all_runtimes() + + +def test_see_runtime_statuses(): + ark.init() + num_runtimes = 3 + runtimes = [ark.Runtime.get_runtime(i) for i in range(num_runtimes)] + runtime_statuses = ark.Runtime.see_runtime_statuses() + assert len(runtime_statuses) == num_runtimes + for i in range(num_runtimes): + assert i in runtime_statuses + for i, rt in enumerate(runtimes): + assert runtime_statuses[i] == rt + ark.Runtime.delete_all_runtimes() + + +def test_multiple_runtimes_init(): + ark.init() + runtimes = [ark.Runtime.get_runtime(i) for i in range(3)] + for rt in runtimes: + assert rt.state == ark.Runtime.State.Init + ark.init() + runtimes = ark.Runtime.see_runtime_statuses() + assert len(runtimes) == 0 + ark.Runtime.delete_all_runtimes() From 9a0556bde84a4dd6a76f39155d60957c9165ad52 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 18 Jun 2024 21:30:02 +0000 Subject: [PATCH 09/54] cmake dlpack --- .gitmodules | 4 ++++ ark/CMakeLists.txt | 1 + third_party/CMakeLists.txt | 13 +++++++++++++ third_party/dlpack | 1 + 4 files changed, 19 insertions(+) create mode 160000 third_party/dlpack diff --git a/.gitmodules b/.gitmodules index ced5dcf94..ec484eb61 100644 --- a/.gitmodules +++ b/.gitmodules @@ -17,3 +17,7 @@ [submodule "third_party/json"] path = third_party/json url = https://github.com/nlohmann/json + +[submodule "third_party/dlpack"] + path = third_party/dlpack + url = https://github.com/dmlc/dlpack diff --git a/ark/CMakeLists.txt b/ark/CMakeLists.txt index 4457d3c0b..ce03b65ed 100644 --- a/ark/CMakeLists.txt +++ b/ark/CMakeLists.txt @@ -17,6 +17,7 @@ set(COMMON_LIBS ARK::numa ARK::ibverbs pthread rt) target_include_directories(ark_obj PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) target_include_directories(ark_obj PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(ark_obj SYSTEM PRIVATE + ${DLPACK_INCLUDE_DIRS} ${JSON_INCLUDE_DIRS} ${MSCCLPP_INCLUDE_DIRS} ${IBVERBS_INCLUDE_DIRS} diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 75916d962..cc4b5eb5c 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -35,6 +35,19 @@ if (NOT json_POPULATED) endif() set(JSON_INCLUDE_DIRS ${json_SOURCE_DIR}/include PARENT_SCOPE) +# DLPack +FetchContent_Declare( + dlpack + GIT_REPOSITORY https://github.com/dmlc/dlpack + GIT_TAG v0.8 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/dlpack +) +FetchContent_GetProperties(dlpack) +if (NOT dlpack_POPULATED) + FetchContent_Populate(dlpack) +endif() +set(DLPACK_INCLUDE_DIRS ${dlpack_SOURCE_DIR}/include PARENT_SCOPE) + if(USE_CUDA) # Configure CUTLASS FetchContent_Declare( diff --git a/third_party/dlpack b/third_party/dlpack new file mode 160000 index 000000000..365b823ce --- /dev/null +++ b/third_party/dlpack @@ -0,0 +1 @@ +Subproject commit 365b823cedb281cd0240ca601aba9b78771f91a3 From 75f7831b700783e899beaa15f950f125a7520d6c Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 18 Jun 2024 22:38:35 +0000 Subject: [PATCH 10/54] include dlpack for pybind --- python/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index efb9aea3e..bd25d01e6 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -20,3 +20,4 @@ file(GLOB_RECURSE BIND_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.c pybind11_add_module(ark_py ${BIND_SOURCES}) set_target_properties(ark_py PROPERTIES OUTPUT_NAME _ark_core) target_link_libraries(ark_py PRIVATE ark_static) +target_include_directories(ark_py SYSTEM PRIVATE ${DLPACK_INCLUDE_DIRS}) From 94b44f20a15c892d5a47e1597d838891ca600553 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 24 Jun 2024 23:51:22 +0000 Subject: [PATCH 11/54] support d2d copy --- ark/api/executor.cpp | 99 ++++++++++++++++++++---------- ark/include/ark/executor.hpp | 10 ++- python/ark/tensor.py | 42 +++++++++---- python/executor_py.cpp | 33 +++++++--- python/unittest/test_conversion.py | 37 ++++++++++- 5 files changed, 162 insertions(+), 59 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index a0711bfe8..96e53c8cf 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -147,6 +147,8 @@ class Executor::Impl { const std::string &plan); ~Impl() = default; + int gpu_id() const { return gpu_id_; } + void compile(); void launch(int64_t max_spin_count); void run(int iter); @@ -154,9 +156,10 @@ class Executor::Impl { float stop(int64_t max_spin_count); void barrier(); - void tensor_read(const Tensor tensor, void *data, size_t bytes) const; + void tensor_read(const Tensor tensor, void *data, size_t bytes, + bool is_d2d) const; void tensor_write(const Tensor tensor, const void *data, - size_t bytes) const; + size_t bytes, bool is_d2d) const; DLDeviceType get_device_type() const; DLManagedTensor *get_dl_tensor(const Tensor &tensor) const; @@ -731,57 +734,83 @@ void Executor::Impl::barrier() { } void Executor::Impl::tensor_read(const Tensor tensor, void *data, - size_t bytes) const { + size_t bytes, bool is_d2d) const { GLOG(gpuSetDevice(gpu_id_)); size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); - if (bytes < tensor_data_bytes) { - ERR(InvalidUsageError, "Data buffer (", bytes, - ") is smaller than the tensor data (", tensor_data_bytes, ")."); + if (bytes != tensor_data_bytes) { + ERR(InvalidUsageError, "Destination bytes (", bytes, + ") mismatches the tensor data bytes (", tensor_data_bytes, ")."); } - size_t tensor_bytes = - tensor.strides().nelems() * tensor.data_type().bytes(); - void *src = - buffer_->ref(buffer_id_to_offset_.at(tensor.ref()->buffer()->id())); + size_t buffer_id = tensor.ref()->buffer()->id(); + if (buffer_id_to_offset_.find(buffer_id) == buffer_id_to_offset_.end()) { + ERR(NotFoundError, "Invalid buffer ID: ", buffer_id); + } + size_t offset = buffer_id_to_offset_.at(buffer_id); + auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyDeviceToHost; + void *src = buffer_->ref(offset); if (tensor.strides() == tensor.shape()) { - GLOG(gpuMemcpyAsync(data, src, bytes, gpuMemcpyDeviceToHost, - copy_stream_->get())); - copy_stream_->sync(); + GLOG(gpuMemcpyAsync(data, src, bytes, kind, copy_stream_->get())); } else { + size_t tensor_bytes = + tensor.strides().nelems() * tensor.data_type().bytes(); std::vector tensor_host(tensor_bytes); GLOG(gpuMemcpyAsync(tensor_host.data(), src, tensor_bytes, gpuMemcpyDeviceToHost, copy_stream_->get())); copy_stream_->sync(); - tensor_to_data(tensor_host.data(), static_cast(data), - tensor.shape(), tensor.strides(), tensor.offsets(), - tensor.data_type().bytes()); + if (!is_d2d) { + tensor_to_data(tensor_host.data(), static_cast(data), + tensor.shape(), tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); + return; + } + // TODO: convert data layout on the device directly + std::vector data_host(bytes); + tensor_to_data(tensor_host.data(), data_host.data(), + tensor.shape(), tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); + GLOG(gpuMemcpyAsync(data, data_host.data(), bytes, + gpuMemcpyHostToDevice, copy_stream_->get())); } + copy_stream_->sync(); } void Executor::Impl::tensor_write(const Tensor tensor, const void *data, - size_t bytes) const { + size_t bytes, bool is_d2d) const { GLOG(gpuSetDevice(gpu_id_)); size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); - if (bytes < tensor_data_bytes) { - ERR(InvalidUsageError, "Data buffer (", bytes, - ") is smaller than the tensor data (", tensor_data_bytes, ")."); + if (bytes != tensor_data_bytes) { + ERR(InvalidUsageError, "Source bytes (", bytes, + ") mismatches the tensor data bytes (", tensor_data_bytes, ")."); + } + size_t buffer_id = tensor.ref()->buffer()->id(); + if (buffer_id_to_offset_.find(buffer_id) == buffer_id_to_offset_.end()) { + ERR(NotFoundError, "Invalid buffer ID: ", buffer_id); } + size_t offset = buffer_id_to_offset_.at(buffer_id); size_t tensor_bytes = tensor.strides().nelems() * tensor.data_type().bytes(); - void *dst = - buffer_->ref(buffer_id_to_offset_.at(tensor.ref()->buffer()->id())); + auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyHostToDevice; + void *dst = buffer_->ref(offset); if (tensor.strides() == tensor.shape()) { - GLOG(gpuMemcpyAsync(dst, data, tensor_bytes, gpuMemcpyHostToDevice, - copy_stream_->get())); + GLOG(gpuMemcpyAsync(dst, data, tensor_bytes, kind, copy_stream_->get())); } else { std::vector tensor_host(tensor_bytes); - GLOG(gpuMemcpyAsync(tensor_host.data(), dst, tensor_bytes, - gpuMemcpyDeviceToHost, copy_stream_->get())); - copy_stream_->sync(); - data_to_tensor(tensor_host.data(), static_cast(data), - tensor.shape(), tensor.strides(), tensor.offsets(), - tensor.data_type().bytes()); + if (!is_d2d) { + data_to_tensor(tensor_host.data(), static_cast(data), + tensor.shape(), tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); + } else { + // TODO: convert data layout on the device directly + std::vector tmp(bytes); + GLOG(gpuMemcpyAsync(tmp.data(), data, bytes, + gpuMemcpyDeviceToHost, copy_stream_->get())); + copy_stream_->sync(); + data_to_tensor(tensor_host.data(), tmp.data(), + tensor.shape(), tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); + } GLOG(gpuMemcpyAsync(dst, tensor_host.data(), tensor_bytes, gpuMemcpyHostToDevice, copy_stream_->get())); } @@ -883,6 +912,8 @@ Executor::Executor(int rank, int world_size, int gpu_id, Executor::~Executor() = default; +int Executor::gpu_id() const { return impl_->gpu_id(); } + void Executor::compile() { impl_->compile(); } void Executor::launch(int64_t max_spin_count) { impl_->launch(max_spin_count); } @@ -902,13 +933,13 @@ void Executor::destroy() { impl_.reset(nullptr); } bool Executor::destroyed() const { return impl_.get() == nullptr; } void Executor::tensor_read(const Tensor tensor, void *data, - size_t bytes) const { - impl_->tensor_read(tensor, data, bytes); + size_t bytes, bool is_d2d) const { + impl_->tensor_read(tensor, data, bytes, is_d2d); } void Executor::tensor_write(const Tensor tensor, const void *data, - size_t bytes) const { - impl_->tensor_write(tensor, data, bytes); + size_t bytes, bool is_d2d) const { + impl_->tensor_write(tensor, data, bytes, is_d2d); } DLDeviceType Executor::get_device_type() const { diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 54c49cd29..a5d6f0273 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -23,6 +23,9 @@ class Executor { ~Executor(); + /// Return the GPU ID. + int gpu_id() const; + /// Compile the model. This must be called before `launch()`. void compile(); @@ -59,10 +62,11 @@ class Executor { data.size() * sizeof(T)); } - void tensor_read(const Tensor tensor, void *data, size_t bytes) const; + void tensor_read(const Tensor tensor, void *data, size_t bytes, + bool is_d2d = false) const; - void tensor_write(const Tensor tensor, const void *data, - size_t bytes) const; + void tensor_write(const Tensor tensor, const void *data, size_t bytes, + bool is_d2d = false) const; DLManagedTensor *get_dl_tensor(const Tensor &tensor) const; diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 00e266929..eff1bf20e 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -77,10 +77,17 @@ def to_numpy(self, ndarray: np.ndarray = None) -> np.ndarray: an empty numpy array without the data buffer will be returned. """ np_type = self.dtype().to_numpy() + if np_type is None: + raise ValueError( + f"Tensor data type {self.dtype().__name__} is not supported by numpy." + ) rt = Runtime.get_runtime(self.runtime_id) if not rt.launched(): - return np.ndarray(self.shape(), dtype=np_type, buffer=None) - if ndarray is None: + raise RuntimeError( + "Tensor is not allocated yet. `Tensor.to_numpy()` is " + "usable only after you call `Runtime.launch()`." + ) + elif ndarray is None: ndarray = np.zeros(self.shape(), dtype=np_type) elif not ndarray.flags["C_CONTIGUOUS"]: raise ValueError("ndarray is not contiguous in memory") @@ -99,9 +106,18 @@ def to_torch( """ """ if _no_torch: raise ImportError("torch is not available") + rt = Runtime.get_runtime(self.runtime_id) + if not rt.launched(): + raise RuntimeError( + "Tensor is not allocated yet. `Tensor.to_torch()` is " + "usable only after you call `Runtime.launch()`." + ) torch_type = self.dtype().to_torch() if tensor is None: - return torch.from_numpy(self.to_numpy()) + dev_name = f"cuda:{rt.executor.gpu_id()}" + tensor = torch.zeros( + self.shape(), dtype=torch_type, device=torch.device(dev_name) + ) elif tensor.shape != self.shape(): raise ValueError("torch tensor shape does not match the tensor") elif tensor.dtype != torch_type: @@ -110,7 +126,10 @@ def to_torch( raise ValueError("torch tensor is not contiguous in memory") elif tensor.numel() != self.nelems(): raise ValueError("torch tensor size does not match the tensor") - tensor.copy_(torch.from_numpy(self.to_numpy(self.runtime_id))) + tensor_bytes = self.nelems() * self.dtype().element_size() + rt.executor.tensor_read( + self._tensor, tensor.data_ptr(), tensor_bytes, True + ) return tensor def get_torch_view(self) -> torch.Tensor: @@ -163,7 +182,8 @@ def from_torch(tensor: torch.Tensor): def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": """ - Copies the tensor from a host numpy array to the device. + Copies data into this tensor. The data type may differ, + but the size must match. """ rt = Runtime.get_runtime(self.runtime_id) if not rt.launched(): @@ -171,24 +191,22 @@ def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": "Tensor is not allocated yet. `Tensor.from_numpy()` is " "usable only after you call `Runtime.launch()`." ) + tensor_bytes = self.nelems() * self.dtype().element_size() if isinstance(data, torch.Tensor): - if data.dtype != self.dtype().to_torch(): - raise ValueError("data dtype does not match the tensor") if not data.is_contiguous(): data = data.contiguous() - if data.numel() != self.nelems(): + if data.numel() * data.element_size() != tensor_bytes: raise ValueError("data size does not match the tensor") rt.executor.tensor_write( self._tensor, data.data_ptr(), - data.numel() * data.element_size(), + tensor_bytes, + data.device.type == "cuda", ) elif isinstance(data, np.ndarray): - if data.dtype != self.dtype().to_numpy(): - raise ValueError("data dtype does not match the tensor") if not data.flags["C_CONTIGUOUS"]: data = np.ascontiguousarray(data) - if data.nbytes != self.nelems() * self.dtype().element_size(): + if data.nbytes != tensor_bytes: raise ValueError("data size does not match the tensor") rt.executor.tensor_write(self._tensor, data) else: diff --git a/python/executor_py.cpp b/python/executor_py.cpp index 59bee5a9b..b6cf8a7a8 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -15,19 +15,24 @@ static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, py::buffer host_buffer) { py::buffer_info info = host_buffer.request(); exe->tensor_write(tensor, reinterpret_cast(info.ptr), - info.size * info.itemsize); + info.size * info.itemsize, false); } static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, - size_t host_address, size_t bytes) { - exe->tensor_write(tensor, reinterpret_cast(host_address), bytes); + size_t address, size_t bytes, bool is_d2d) { + exe->tensor_write(tensor, reinterpret_cast(address), bytes, is_d2d); } static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, py::buffer host_buffer) { py::buffer_info info = host_buffer.request(); exe->tensor_read(tensor, reinterpret_cast(info.ptr), - info.size * info.itemsize); + info.size * info.itemsize, false); +} + +static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, + size_t address, size_t bytes, bool is_d2d) { + exe->tensor_read(tensor, reinterpret_cast(address), bytes, is_d2d); } DLManagedTensor *to_dlpack(ark::Executor &exe, const ark::Tensor &tensor) { @@ -59,6 +64,7 @@ void register_executor(py::module &m) { py::init(), py::arg("rank"), py::arg("world_size"), py::arg("gpu_id"), py::arg("name"), py::arg("plan")) + .def("gpu_id", &ark::Executor::gpu_id) .def("compile", &ark::Executor::compile) .def("launch", &ark::Executor::launch, py::arg("max_spin_count") = -1) .def("run", &ark::Executor::run, py::arg("iter")) @@ -67,7 +73,16 @@ void register_executor(py::module &m) { .def("barrier", &ark::Executor::barrier) .def("destroy", &ark::Executor::destroy) .def("destroyed", &ark::Executor::destroyed) - .def("tensor_read", &tensor_read, py::arg("tensor"), py::arg("data")) + .def( + "tensor_read", + py::overload_cast( + &tensor_read), + py::arg("tensor"), py::arg("data")) + .def("tensor_read", + py::overload_cast(&tensor_read), + py::arg("tensor"), py::arg("address"), py::arg("bytes"), + py::arg("is_d2d")) .def( "tensor_write", py::overload_cast( @@ -75,8 +90,8 @@ void register_executor(py::module &m) { py::arg("tensor"), py::arg("data")) .def("tensor_write", py::overload_cast(&tensor_write), - py::arg("tensor"), py::arg("address"), py::arg("bytes")) - .def("get_dl_tensor", &to_dlpack_capsule), - py::arg("tensor"); + size_t, bool>(&tensor_write), + py::arg("tensor"), py::arg("address"), py::arg("bytes"), + py::arg("is_d2d")) + .def("get_dl_tensor", &to_dlpack_capsule); } diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py index 8f941a033..5befa1c34 100644 --- a/python/unittest/test_conversion.py +++ b/python/unittest/test_conversion.py @@ -1,7 +1,14 @@ -import torch +import pytest import numpy as np import ark +try: + import torch + + _no_torch = False +except ImportError: + _no_torch = True + def initialize_tensor(dimensions, dtype): tensor = ark.tensor(dimensions, dtype) @@ -11,6 +18,8 @@ def initialize_tensor(dimensions, dtype): # Test function to validate the integrity of the PyTorch view of the ARK tensor, # including its data and attributes such as shape and data type. +@pytest.mark.parametrize("num_dims,size", [(1, 5), (1, 1024), (2, 5), (2, 32)]) +@pytest.mark.parametrize("dtype", [ark.fp16, ark.fp32]) def test_values_fixed_dims(num_dims: int, size: int, dtype: ark.DataType): ark.init() dimensions = [size] * num_dims @@ -59,6 +68,7 @@ def check_diff(input_tensor_host, input_view_numpy, value, index): # Test function to check if changes to the torch views are reflected in the original tensors +@pytest.mark.parametrize("dtype", [ark.fp16, ark.fp32]) def test_aliasing(dtype: ark.DataType): ark.init() dimensions = [4, 4] @@ -91,3 +101,28 @@ def test_aliasing(dtype: ark.DataType): runtime.stop() runtime.reset() + + +def test_conversion_torch(): + if _no_torch: + pytest.skip("PyTorch not available") + + dimensions = [4, 4] + + ark.init() + t = ark.constant(7, dimensions) + + with ark.Runtime() as rt: + rt.launch() + + torch_tensor = t.to_torch() + + assert torch_tensor.shape == (4, 4) + assert torch_tensor.dtype == torch.float32 + assert torch_tensor.device.type == "cuda" + assert torch.all(torch_tensor == 0) + + rt.run() + + torch_tensor = t.to_torch() + assert torch.all(torch_tensor == 7) From 20c23f34b17ecfa24d96ffa8799c3c173b468c53 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 24 Jun 2024 23:58:59 +0000 Subject: [PATCH 12/54] lint --- ark/api/executor.cpp | 46 +++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 96e53c8cf..ae3e5f499 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -158,8 +158,8 @@ class Executor::Impl { void tensor_read(const Tensor tensor, void *data, size_t bytes, bool is_d2d) const; - void tensor_write(const Tensor tensor, const void *data, - size_t bytes, bool is_d2d) const; + void tensor_write(const Tensor tensor, const void *data, size_t bytes, + bool is_d2d) const; DLDeviceType get_device_type() const; DLManagedTensor *get_dl_tensor(const Tensor &tensor) const; @@ -733,8 +733,8 @@ void Executor::Impl::barrier() { } } -void Executor::Impl::tensor_read(const Tensor tensor, void *data, - size_t bytes, bool is_d2d) const { +void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes, + bool is_d2d) const { GLOG(gpuSetDevice(gpu_id_)); size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); @@ -760,15 +760,15 @@ void Executor::Impl::tensor_read(const Tensor tensor, void *data, copy_stream_->sync(); if (!is_d2d) { tensor_to_data(tensor_host.data(), static_cast(data), - tensor.shape(), tensor.strides(), tensor.offsets(), - tensor.data_type().bytes()); + tensor.shape(), tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); return; } // TODO: convert data layout on the device directly std::vector data_host(bytes); - tensor_to_data(tensor_host.data(), data_host.data(), - tensor.shape(), tensor.strides(), tensor.offsets(), - tensor.data_type().bytes()); + tensor_to_data(tensor_host.data(), data_host.data(), tensor.shape(), + tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); GLOG(gpuMemcpyAsync(data, data_host.data(), bytes, gpuMemcpyHostToDevice, copy_stream_->get())); } @@ -794,22 +794,24 @@ void Executor::Impl::tensor_write(const Tensor tensor, const void *data, auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyHostToDevice; void *dst = buffer_->ref(offset); if (tensor.strides() == tensor.shape()) { - GLOG(gpuMemcpyAsync(dst, data, tensor_bytes, kind, copy_stream_->get())); + GLOG( + gpuMemcpyAsync(dst, data, tensor_bytes, kind, copy_stream_->get())); } else { std::vector tensor_host(tensor_bytes); if (!is_d2d) { - data_to_tensor(tensor_host.data(), static_cast(data), - tensor.shape(), tensor.strides(), tensor.offsets(), - tensor.data_type().bytes()); + data_to_tensor(tensor_host.data(), + static_cast(data), tensor.shape(), + tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); } else { // TODO: convert data layout on the device directly std::vector tmp(bytes); - GLOG(gpuMemcpyAsync(tmp.data(), data, bytes, - gpuMemcpyDeviceToHost, copy_stream_->get())); + GLOG(gpuMemcpyAsync(tmp.data(), data, bytes, gpuMemcpyDeviceToHost, + copy_stream_->get())); copy_stream_->sync(); - data_to_tensor(tensor_host.data(), tmp.data(), - tensor.shape(), tensor.strides(), tensor.offsets(), - tensor.data_type().bytes()); + data_to_tensor(tensor_host.data(), tmp.data(), tensor.shape(), + tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); } GLOG(gpuMemcpyAsync(dst, tensor_host.data(), tensor_bytes, gpuMemcpyHostToDevice, copy_stream_->get())); @@ -932,13 +934,13 @@ void Executor::destroy() { impl_.reset(nullptr); } bool Executor::destroyed() const { return impl_.get() == nullptr; } -void Executor::tensor_read(const Tensor tensor, void *data, - size_t bytes, bool is_d2d) const { +void Executor::tensor_read(const Tensor tensor, void *data, size_t bytes, + bool is_d2d) const { impl_->tensor_read(tensor, data, bytes, is_d2d); } -void Executor::tensor_write(const Tensor tensor, const void *data, - size_t bytes, bool is_d2d) const { +void Executor::tensor_write(const Tensor tensor, const void *data, size_t bytes, + bool is_d2d) const { impl_->tensor_write(tensor, data, bytes, is_d2d); } From ebe85604cb7249b4e0d7d6c3eed69758c4c6825f Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 25 Jun 2024 01:21:42 +0000 Subject: [PATCH 13/54] Seperate DLPack from C++ interfaces --- ark/api/executor.cpp | 127 +++++------------------------------ ark/include/ark/executor.hpp | 8 +-- python/executor_py.cpp | 90 ++++++++++++++++++++++++- 3 files changed, 106 insertions(+), 119 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index ae3e5f499..ebfa7016d 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -156,12 +156,12 @@ class Executor::Impl { float stop(int64_t max_spin_count); void barrier(); + uintptr_t tensor_address(const Tensor tensor) const; + void tensor_read(const Tensor tensor, void *data, size_t bytes, bool is_d2d) const; void tensor_write(const Tensor tensor, const void *data, size_t bytes, bool is_d2d) const; - DLDeviceType get_device_type() const; - DLManagedTensor *get_dl_tensor(const Tensor &tensor) const; private: void init_communicator(); @@ -733,6 +733,15 @@ void Executor::Impl::barrier() { } } +uintptr_t Executor::Impl::tensor_address(const Tensor tensor) const { + size_t buffer_id = tensor.ref()->buffer()->id(); + if (buffer_id_to_offset_.find(buffer_id) == buffer_id_to_offset_.end()) { + ERR(NotFoundError, "Invalid buffer ID: ", buffer_id); + } + size_t offset = buffer_id_to_offset_.at(buffer_id); + return reinterpret_cast(buffer_->ref(offset)); +} + void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes, bool is_d2d) const { GLOG(gpuSetDevice(gpu_id_)); @@ -742,13 +751,8 @@ void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes, ERR(InvalidUsageError, "Destination bytes (", bytes, ") mismatches the tensor data bytes (", tensor_data_bytes, ")."); } - size_t buffer_id = tensor.ref()->buffer()->id(); - if (buffer_id_to_offset_.find(buffer_id) == buffer_id_to_offset_.end()) { - ERR(NotFoundError, "Invalid buffer ID: ", buffer_id); - } - size_t offset = buffer_id_to_offset_.at(buffer_id); auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyDeviceToHost; - void *src = buffer_->ref(offset); + void *src = reinterpret_cast(tensor_address(tensor)); if (tensor.strides() == tensor.shape()) { GLOG(gpuMemcpyAsync(data, src, bytes, kind, copy_stream_->get())); } else { @@ -784,15 +788,10 @@ void Executor::Impl::tensor_write(const Tensor tensor, const void *data, ERR(InvalidUsageError, "Source bytes (", bytes, ") mismatches the tensor data bytes (", tensor_data_bytes, ")."); } - size_t buffer_id = tensor.ref()->buffer()->id(); - if (buffer_id_to_offset_.find(buffer_id) == buffer_id_to_offset_.end()) { - ERR(NotFoundError, "Invalid buffer ID: ", buffer_id); - } - size_t offset = buffer_id_to_offset_.at(buffer_id); size_t tensor_bytes = tensor.strides().nelems() * tensor.data_type().bytes(); auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyHostToDevice; - void *dst = buffer_->ref(offset); + void *dst = reinterpret_cast(tensor_address(tensor)); if (tensor.strides() == tensor.shape()) { GLOG( gpuMemcpyAsync(dst, data, tensor_bytes, kind, copy_stream_->get())); @@ -819,94 +818,6 @@ void Executor::Impl::tensor_write(const Tensor tensor, const void *data, copy_stream_->sync(); } -DLDeviceType Executor::Impl::get_device_type() const { -#if defined(ARK_CUDA) - return kDLCUDA; -#elif defined(ARK_ROCM) - return kDLROCM; -#else - return kDLCPU; -#endif -} - -DLDataType get_dl_dtype(const DataType &ark_data_type) { - DLDataType dl_data_type; - dl_data_type.lanes = 1; - if (ark_data_type == FP32) { - dl_data_type.code = kDLFloat; - dl_data_type.bits = 32; - } else if (ark_data_type == FP16) { - dl_data_type.code = kDLFloat; - dl_data_type.bits = 16; - } else if (ark_data_type == BF16) { - dl_data_type.code = kDLBfloat; - dl_data_type.bits = 16; - } else if (ark_data_type == INT32) { - dl_data_type.code = kDLInt; - dl_data_type.bits = 32; - } else if (ark_data_type == UINT32) { - dl_data_type.code = kDLUInt; - dl_data_type.bits = 32; - } else if (ark_data_type == INT8) { - dl_data_type.code = kDLInt; - dl_data_type.bits = 8; - } else if (ark_data_type == UINT8) { - dl_data_type.code = kDLUInt; - dl_data_type.bits = 8; - } else if (ark_data_type == BYTE) { - dl_data_type.code = kDLUInt; - dl_data_type.bits = 8; - } else { - ERR(InvalidUsageError, "Unsupported data type"); - } - return dl_data_type; -} - -DLManagedTensor *Executor::Impl::get_dl_tensor(const Tensor &tensor) const { - DLTensor dl_tensor; - dl_tensor.data = - buffer_->ref(buffer_id_to_offset_.at(tensor.ref()->buffer()->id())); - size_t offset_in_elements = - tensor.offsets().is_no_dim() ? 0 : tensor.offsets().vector()[0]; - dl_tensor.byte_offset = offset_in_elements * tensor.data_type().bytes(); - dl_tensor.device.device_type = get_device_type(); - dl_tensor.device.device_id = static_cast(gpu_id_); - dl_tensor.ndim = static_cast(tensor.shape().ndims()); - dl_tensor.dtype = get_dl_dtype(tensor.data_type()); - - dl_tensor.shape = - tensor.shape().is_no_dim() ? nullptr : new int64_t[dl_tensor.ndim]; - dl_tensor.strides = - tensor.strides().is_no_dim() ? nullptr : new int64_t[dl_tensor.ndim]; - auto shape = tensor.shape(); - if (dl_tensor.shape) { - for (int i = 0; i < dl_tensor.ndim; ++i) { - dl_tensor.shape[i] = shape[i]; - } - } - if (dl_tensor.strides) { - dl_tensor.strides[dl_tensor.ndim - 1] = 1; - for (int i = dl_tensor.ndim - 2; i >= 0; --i) { - dl_tensor.strides[i] = - dl_tensor.shape[i + 1] * dl_tensor.strides[i + 1]; - } - } - DLManagedTensor *dl_managed_tensor = new DLManagedTensor(); - dl_managed_tensor->dl_tensor = dl_tensor; - dl_managed_tensor->manager_ctx = nullptr; - dl_managed_tensor->deleter = [](DLManagedTensor *self) { - if (self->dl_tensor.shape) { - delete[] self->dl_tensor.shape; - self->dl_tensor.shape = nullptr; - } - if (self->dl_tensor.strides) { - delete[] self->dl_tensor.strides; - self->dl_tensor.strides = nullptr; - } - }; - return dl_managed_tensor; -} - Executor::Executor(int rank, int world_size, int gpu_id, const std::string &name, const std::string &plan) : impl_(std::make_unique(rank, world_size, gpu_id, name, @@ -934,6 +845,10 @@ void Executor::destroy() { impl_.reset(nullptr); } bool Executor::destroyed() const { return impl_.get() == nullptr; } +uintptr_t Executor::tensor_address(const Tensor tensor) const { + return impl_->tensor_address(tensor); +} + void Executor::tensor_read(const Tensor tensor, void *data, size_t bytes, bool is_d2d) const { impl_->tensor_read(tensor, data, bytes, is_d2d); @@ -944,14 +859,6 @@ void Executor::tensor_write(const Tensor tensor, const void *data, size_t bytes, impl_->tensor_write(tensor, data, bytes, is_d2d); } -DLDeviceType Executor::get_device_type() const { - return impl_->get_device_type(); -} - -DLManagedTensor *Executor::get_dl_tensor(const Tensor &tensor) const { - return impl_->get_dl_tensor(tensor); -} - DefaultExecutor::DefaultExecutor(const Model &model, int gpu_id, const std::string &name) : Executor( diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index a5d6f0273..b8cdaf273 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -4,8 +4,6 @@ #ifndef ARK_EXECUTOR_HPP #define ARK_EXECUTOR_HPP -#include - #include #include #include @@ -50,6 +48,8 @@ class Executor { bool destroyed() const; + uintptr_t tensor_address(const Tensor tensor) const; + template void tensor_read(const Tensor tensor, std::vector &data) const { tensor_read(tensor, reinterpret_cast(data.data()), @@ -68,10 +68,6 @@ class Executor { void tensor_write(const Tensor tensor, const void *data, size_t bytes, bool is_d2d = false) const; - DLManagedTensor *get_dl_tensor(const Tensor &tensor) const; - - DLDeviceType get_device_type() const; - private: class Impl; std::unique_ptr impl_; diff --git a/python/executor_py.cpp b/python/executor_py.cpp index b6cf8a7a8..e5ab4f964 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace py = pybind11; static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, @@ -35,9 +36,92 @@ static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, exe->tensor_read(tensor, reinterpret_cast(address), bytes, is_d2d); } -DLManagedTensor *to_dlpack(ark::Executor &exe, const ark::Tensor &tensor) { - DLManagedTensor *dl_tensor = exe.get_dl_tensor(tensor); - return dl_tensor; +static DLDataType get_dl_dtype(const ark::DataType &ark_data_type) { + DLDataType dl_data_type; + dl_data_type.lanes = 1; + if (ark_data_type == ark::FP32) { + dl_data_type.code = kDLFloat; + dl_data_type.bits = 32; + } else if (ark_data_type == ark::FP16) { + dl_data_type.code = kDLFloat; + dl_data_type.bits = 16; + } else if (ark_data_type == ark::BF16) { + dl_data_type.code = kDLBfloat; + dl_data_type.bits = 16; + } else if (ark_data_type == ark::INT32) { + dl_data_type.code = kDLInt; + dl_data_type.bits = 32; + } else if (ark_data_type == ark::UINT32) { + dl_data_type.code = kDLUInt; + dl_data_type.bits = 32; + } else if (ark_data_type == ark::INT8) { + dl_data_type.code = kDLInt; + dl_data_type.bits = 8; + } else if (ark_data_type == ark::UINT8) { + dl_data_type.code = kDLUInt; + dl_data_type.bits = 8; + } else if (ark_data_type == ark::BYTE) { + dl_data_type.code = kDLUInt; + dl_data_type.bits = 8; + } else { + throw std::runtime_error("unexpected error"); + } + return dl_data_type; +} + +static DLDeviceType get_device_type() { +#if defined(ARK_CUDA) + return kDLCUDA; +#elif defined(ARK_ROCM) + return kDLROCM; +#else + return kDLCPU; +#endif +} + +static DLManagedTensor *to_dlpack(ark::Executor &exe, + const ark::Tensor &tensor) { + DLTensor dl_tensor; + dl_tensor.data = reinterpret_cast(exe.tensor_address(tensor)); + size_t offset_in_elements = + tensor.offsets().is_no_dim() ? 0 : tensor.offsets().vector()[0]; + dl_tensor.byte_offset = offset_in_elements * tensor.data_type().bytes(); + dl_tensor.device.device_type = get_device_type(); + dl_tensor.device.device_id = static_cast(exe.gpu_id()); + dl_tensor.ndim = static_cast(tensor.shape().ndims()); + dl_tensor.dtype = get_dl_dtype(tensor.data_type()); + + dl_tensor.shape = + tensor.shape().is_no_dim() ? nullptr : new int64_t[dl_tensor.ndim]; + dl_tensor.strides = + tensor.strides().is_no_dim() ? nullptr : new int64_t[dl_tensor.ndim]; + auto shape = tensor.shape(); + if (dl_tensor.shape) { + for (int i = 0; i < dl_tensor.ndim; ++i) { + dl_tensor.shape[i] = shape[i]; + } + } + if (dl_tensor.strides) { + dl_tensor.strides[dl_tensor.ndim - 1] = 1; + for (int i = dl_tensor.ndim - 2; i >= 0; --i) { + dl_tensor.strides[i] = + dl_tensor.shape[i + 1] * dl_tensor.strides[i + 1]; + } + } + DLManagedTensor *dl_managed_tensor = new DLManagedTensor(); + dl_managed_tensor->dl_tensor = dl_tensor; + dl_managed_tensor->manager_ctx = nullptr; + dl_managed_tensor->deleter = [](DLManagedTensor *self) { + if (self->dl_tensor.shape) { + delete[] self->dl_tensor.shape; + self->dl_tensor.shape = nullptr; + } + if (self->dl_tensor.strides) { + delete[] self->dl_tensor.strides; + self->dl_tensor.strides = nullptr; + } + }; + return dl_managed_tensor; } void free_capsule(PyObject *capsule) { From 08c9b899c22b759a6f4f194b7932f48d08eeb8f4 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 25 Jun 2024 01:30:50 +0000 Subject: [PATCH 14/54] Update workflow trigger --- .github/workflows/ut-cuda.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/ut-cuda.yml b/.github/workflows/ut-cuda.yml index 5a78818ff..918c1a4a8 100644 --- a/.github/workflows/ut-cuda.yml +++ b/.github/workflows/ut-cuda.yml @@ -7,8 +7,7 @@ on: pull_request: branches: - main - types: - - ready_for_review + types: [opened, synchronize, reopened, ready_for_review] jobs: UnitTest: From 1fa08afa36010116cdcd6d89e64db104f3fa23d1 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 25 Jun 2024 20:53:29 +0000 Subject: [PATCH 15/54] expose exceptions --- ark/api/dims.cpp | 1 - ark/include/ark.hpp | 1 + ark/{ => include/ark}/error.hpp | 15 ++++++++++----- ark/logging.h | 2 +- python/ark/__init__.py | 12 ++++++++++++ python/ark/error.py | 12 ++++++++++++ python/ark_py.cpp | 2 ++ python/error_py.cpp | 25 +++++++++++++++++++++++++ python/unittest/test_error.py | 12 ++++++++++++ 9 files changed, 75 insertions(+), 7 deletions(-) rename ark/{ => include/ark}/error.hpp (70%) create mode 100644 python/ark/error.py create mode 100644 python/error_py.cpp create mode 100644 python/unittest/test_error.py diff --git a/ark/api/dims.cpp b/ark/api/dims.cpp index a2830a060..a1f03b426 100644 --- a/ark/api/dims.cpp +++ b/ark/api/dims.cpp @@ -5,7 +5,6 @@ #include -#include "error.hpp" #include "logging.h" namespace ark { diff --git a/ark/include/ark.hpp b/ark/include/ark.hpp index a7b2f7f70..2ca796172 100644 --- a/ark/include/ark.hpp +++ b/ark/include/ark.hpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include diff --git a/ark/error.hpp b/ark/include/ark/error.hpp similarity index 70% rename from ark/error.hpp rename to ark/include/ark/error.hpp index e08acd975..78d02cab3 100644 --- a/ark/error.hpp +++ b/ark/include/ark/error.hpp @@ -1,17 +1,21 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_ERROR_HPP_ -#define ARK_ERROR_HPP_ +#ifndef ARK_ERROR_HPP +#define ARK_ERROR_HPP #include #include namespace ark { -class BaseError : public std::runtime_error { +class BaseError : public std::exception { + private: + std::string msg_; + public: - BaseError(const std::string &msg) : std::runtime_error(msg) {} + BaseError(const std::string &msg) : msg_(msg) {} + const char *what() const noexcept override { return msg_.c_str(); } }; #define REGISTER_ERROR_TYPE(_name) \ @@ -20,6 +24,7 @@ class BaseError : public std::runtime_error { _name(const std::string &msg) : BaseError(msg) {} \ }; +REGISTER_ERROR_TYPE(InternalError) REGISTER_ERROR_TYPE(InvalidUsageError) REGISTER_ERROR_TYPE(NotFoundError) REGISTER_ERROR_TYPE(ModelError) @@ -32,4 +37,4 @@ REGISTER_ERROR_TYPE(UnitTestError) } // namespace ark -#endif // ARK_ERROR_HPP_ +#endif // ARK_ERROR_HPP diff --git a/ark/logging.h b/ark/logging.h index d29793ff7..6eb8aaf91 100644 --- a/ark/logging.h +++ b/ark/logging.h @@ -8,7 +8,7 @@ #include #include -#include "error.hpp" +#include "ark/error.hpp" namespace ark { diff --git a/python/ark/__init__.py b/python/ark/__init__.py index 2a4d164e4..3d162c3e4 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -91,3 +91,15 @@ def set_world_size(world_size): ones, zeros, ) +from .error import ( + InternalError, + InvalidUsageError, + NotFoundError, + ModelError, + SchedulerError, + ExecutorError, + SystemError, + GpuError, + RuntimeError, +) + diff --git a/python/ark/error.py b/python/ark/error.py new file mode 100644 index 000000000..d3ac3aee8 --- /dev/null +++ b/python/ark/error.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from _ark_core import _InternalError as InternalError +from _ark_core import _InvalidUsageError as InvalidUsageError +from _ark_core import _NotFoundError as NotFoundError +from _ark_core import _ModelError as ModelError +from _ark_core import _SchedulerError as SchedulerError +from _ark_core import _ExecutorError as ExecutorError +from _ark_core import _SystemError as SystemError +from _ark_core import _GpuError as GpuError +from _ark_core import _RuntimeError as RuntimeError diff --git a/python/ark_py.cpp b/python/ark_py.cpp index 35c3b21c3..1bc4255d6 100644 --- a/python/ark_py.cpp +++ b/python/ark_py.cpp @@ -9,6 +9,7 @@ namespace py = pybind11; extern void register_data_type(py::module &m); extern void register_dims(py::module &m); +extern void register_error(py::module &m); extern void register_executor(py::module &m); extern void register_init(py::module &m); extern void register_model_graph(py::module &m); @@ -23,6 +24,7 @@ PYBIND11_MODULE(_ark_core, m) { register_data_type(m); register_dims(m); + register_error(m); register_executor(m); register_init(m); register_model_graph(m); diff --git a/python/error_py.cpp b/python/error_py.cpp new file mode 100644 index 000000000..863d8423d --- /dev/null +++ b/python/error_py.cpp @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include +#include + +#include + +namespace py = pybind11; + +#define REGISTER_ERROR_PY(_name) \ + py::register_exception(m, "_" #_name) + +void register_error(py::module &m) { + REGISTER_ERROR_PY(InternalError); + REGISTER_ERROR_PY(InvalidUsageError); + REGISTER_ERROR_PY(NotFoundError); + REGISTER_ERROR_PY(ModelError); + REGISTER_ERROR_PY(SchedulerError); + REGISTER_ERROR_PY(ExecutorError); + REGISTER_ERROR_PY(SystemError); + REGISTER_ERROR_PY(GpuError); + REGISTER_ERROR_PY(RuntimeError); +} diff --git a/python/unittest/test_error.py b/python/unittest/test_error.py new file mode 100644 index 000000000..c063c05c5 --- /dev/null +++ b/python/unittest/test_error.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import ark + + +def test_error(): + ark.init() + try: + ark.tensor([0]) + except Exception as e: + assert isinstance(e, ark.InvalidUsageError) From 59caff1eddb0a01c4f7bdf6e082b96d22e10ad6e Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 26 Jun 2024 23:25:35 +0000 Subject: [PATCH 16/54] Build python module by default --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index ee1e3566e..9ba2f2c55 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,7 @@ option(USE_CUDA "Use NVIDIA/CUDA." OFF) option(USE_ROCM "Use AMD/ROCm." OFF) option(BYPASS_GPU_CHECK "Bypass GPU check." OFF) option(BUILD_TESTS "Build unit tests." ON) +option(BUILD_PYTHON "Build Python module." ON) if(BYPASS_GPU_CHECK) if(USE_CUDA) From efb2c78145cab0832971205911320264bbe74870 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Sat, 29 Jun 2024 03:51:19 +0000 Subject: [PATCH 17/54] revert --- ark/include/kernels/kernel_template.in | 1 + 1 file changed, 1 insertion(+) diff --git a/ark/include/kernels/kernel_template.in b/ark/include/kernels/kernel_template.in index 876e6a1b4..ea1862920 100644 --- a/ark/include/kernels/kernel_template.in +++ b/ark/include/kernels/kernel_template.in @@ -64,5 +64,6 @@ void @NAME@(char *_buf, int *_iter) { if (threadIdx.x == 0 && blockIdx.x == 0) { atomicStoreRelaxed(_iter, 0); } + sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); } } From 8975f9d4a0574f0421e79f6dd49e7443e7244606 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Sat, 29 Jun 2024 04:03:20 +0000 Subject: [PATCH 18/54] Do not use `sys.path` for importing `_ark_core` --- python/ark/__init__.py | 5 +---- python/ark/error.py | 18 +++++++++--------- python/ark/init.py | 2 +- python/ark/model.py | 2 +- python/ark/runtime.py | 2 +- python/ark/tensor.py | 2 +- 6 files changed, 14 insertions(+), 17 deletions(-) diff --git a/python/ark/__init__.py b/python/ark/__init__.py index 3d162c3e4..031afc7ba 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -7,9 +7,7 @@ if os.environ.get("ARK_ROOT", None) is None: os.environ["ARK_ROOT"] = os.path.abspath(os.path.dirname(__file__)) -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -import _ark_core +from . import _ark_core from .model import Model @@ -102,4 +100,3 @@ def set_world_size(world_size): GpuError, RuntimeError, ) - diff --git a/python/ark/error.py b/python/ark/error.py index d3ac3aee8..40f7391ac 100644 --- a/python/ark/error.py +++ b/python/ark/error.py @@ -1,12 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from _ark_core import _InternalError as InternalError -from _ark_core import _InvalidUsageError as InvalidUsageError -from _ark_core import _NotFoundError as NotFoundError -from _ark_core import _ModelError as ModelError -from _ark_core import _SchedulerError as SchedulerError -from _ark_core import _ExecutorError as ExecutorError -from _ark_core import _SystemError as SystemError -from _ark_core import _GpuError as GpuError -from _ark_core import _RuntimeError as RuntimeError +from ._ark_core import _InternalError as InternalError +from ._ark_core import _InvalidUsageError as InvalidUsageError +from ._ark_core import _NotFoundError as NotFoundError +from ._ark_core import _ModelError as ModelError +from ._ark_core import _SchedulerError as SchedulerError +from ._ark_core import _ExecutorError as ExecutorError +from ._ark_core import _SystemError as SystemError +from ._ark_core import _GpuError as GpuError +from ._ark_core import _RuntimeError as RuntimeError diff --git a/python/ark/init.py b/python/ark/init.py index dbf7c1569..32f530791 100644 --- a/python/ark/init.py +++ b/python/ark/init.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import _ark_core +from . import _ark_core from .model import Model from .runtime import _RuntimeState diff --git a/python/ark/model.py b/python/ark/model.py index e6208fc16..87af88f49 100644 --- a/python/ark/model.py +++ b/python/ark/model.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. from typing import NewType -from _ark_core import _Model +from ._ark_core import _Model _ModelState = NewType("_ModelState", None) diff --git a/python/ark/runtime.py b/python/ark/runtime.py index 798eaf9d5..efae6ab3c 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Callable, Dict, List -from _ark_core import _Executor, _DefaultPlanner +from ._ark_core import _Executor, _DefaultPlanner from .model import Model diff --git a/python/ark/tensor.py b/python/ark/tensor.py index eff1bf20e..ac2886960 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -4,7 +4,7 @@ import numpy as np from typing import Callable, List, Union, Type -from _ark_core import _Dims, _Tensor, _NullTensor +from ._ark_core import _Dims, _Tensor, _NullTensor from .data_type import DataType from .runtime import Runtime from .model import Model From 153837ba60497413d70c90fed945eaa037c84a29 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 2 Jul 2024 04:09:10 +0000 Subject: [PATCH 19/54] wip --- ark/api/executor.cpp | 51 +- ark/codegen.cpp | 3 +- ark/include/ark/executor.hpp | 7 +- ark/include/kernels/common/broadcast.h | 4 +- ark/model/model_json.cpp | 11 +- ark/model/model_json.hpp | 2 +- ark/model/model_op.cpp | 5 +- ark/ops/ops_all_reduce_test.cpp | 2 +- ark/ops/ops_arithmetic_test.cpp | 48 +- ark/ops/ops_embedding_test.cpp | 2 +- ark/ops/ops_matmul.cpp | 30 +- ark/ops/ops_test_common.cpp | 10 +- ark/ops/ops_test_common.hpp | 6 +- examples/llama/README.md | 4 +- examples/llama/model_test.py | 88 +- plan_gpu0.json | 2504 ++++++++++++++++++++++++ python/ark/__init__.py | 1 + python/ark/profiler.py | 30 + python/executor_py.cpp | 1 + 19 files changed, 2706 insertions(+), 103 deletions(-) create mode 100644 plan_gpu0.json create mode 100644 python/ark/profiler.py diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index b052040ef..4af9df7c0 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -143,11 +143,13 @@ static size_t tensor_stride_bytes(const Json &tensor) { class Executor::Impl { public: - Impl(int rank, int world_size, int gpu_id, const std::string &name, - const std::string &plan); + Impl(int rank, int world_size, int gpu_id, const std::string &name); ~Impl() = default; + void init(const std::string &plan); + int gpu_id() const { return gpu_id_; } + std::string plan() const { return plan_json_.dump_pretty(); } void compile(); void launch(int64_t max_spin_count); @@ -173,11 +175,13 @@ class Executor::Impl { const int rank_; const int world_size_; int gpu_id_; + std::string name_; bool is_launched_ = false; bool is_recording_ = false; float elapsed_msec_ = -1; + PlanJson plan_json_; std::map buffer_id_to_offset_; size_t total_bytes_; std::shared_ptr codegen_; @@ -199,8 +203,8 @@ class Executor::Impl { }; Executor::Impl::Impl(int rank, int world_size, int gpu_id, - const std::string &name, const std::string &plan) - : rank_(rank), world_size_(world_size), gpu_id_(gpu_id) { + const std::string &name) + : rank_(rank), world_size_(world_size), gpu_id_(gpu_id), name_(name) { if (rank < 0 || rank >= world_size) { ERR(InvalidUsageError, "Invalid rank ", rank, " with world size ", world_size); @@ -211,17 +215,18 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, if (world_size_ > 1) { init_communicator(); } +} - Json plan_json; +void Executor::Impl::init(const std::string &plan) { auto &plan_path = get_env().enforce_plan_path; if (!plan_path.empty()) { LOG(INFO, "Enforce executor plan path: ", plan_path); - plan_json = Json::parse(read_file(plan_path)); + plan_json_ = Json::parse(read_file(plan_path)); } else { - plan_json = Json::parse(plan); + plan_json_ = Json::parse(plan); } - buffer_id_to_offset_ = init_buffers(plan_json); + buffer_id_to_offset_ = init_buffers(plan_json_); std::string buffer_id_to_offset_str; for (const auto &kv : buffer_id_to_offset_) { @@ -230,7 +235,7 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, } codegen_ = - std::make_shared(plan_json, buffer_id_to_offset_, name); + std::make_shared(plan_json_, buffer_id_to_offset_, name_); auto gpu_manager = GpuManager::get_instance(gpu_id_); timer_begin_ = gpu_manager->create_event(); @@ -249,13 +254,13 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, static_cast(gpu_manager->info().smem_block_total); if (world_size_ > 1) { - auto remote_ranks = init_remote_ranks(plan_json); + auto remote_ranks = init_remote_ranks(plan_json_); init_channels(remote_ranks); } kernel_ = std::shared_ptr(new GpuKernel( gpu_id_, codegen_->code(), {threads_per_block, 1, 1}, {num_sm, 1, 1}, - std::max(smem_block_total, size_t(4)), name, + std::max(smem_block_total, size_t(4)), name_, {std::pair{buffer_->ref(), sizeof(buffer_->ref())}, std::pair{flag, sizeof(flag)}})); } @@ -812,13 +817,18 @@ void Executor::Impl::tensor_write(const Tensor tensor, const void *data, Executor::Executor(int rank, int world_size, int gpu_id, const std::string &name, const std::string &plan) - : impl_(std::make_unique(rank, world_size, gpu_id, name, - plan)) {} + : impl_(std::make_unique(rank, world_size, gpu_id, name)) { + if (!plan.empty()) { + impl_->init(plan); + } +} Executor::~Executor() = default; int Executor::gpu_id() const { return impl_->gpu_id(); } +std::string Executor::plan() const { return impl_->plan(); } + void Executor::compile() { impl_->compile(); } void Executor::launch(int64_t max_spin_count) { impl_->launch(max_spin_count); } @@ -852,14 +862,17 @@ void Executor::tensor_write(const Tensor tensor, const void *data, size_t bytes, } DefaultExecutor::DefaultExecutor(const Model &model, int gpu_id, - const std::string &name) + const std::vector& config_rules, + const std::string& name) : Executor( model.rank(), model.world_size(), (gpu_id < 0) ? (model.rank() % get_env().num_ranks_per_host) : gpu_id, - name, - DefaultPlanner(model, (gpu_id < 0) ? (model.rank() % - get_env().num_ranks_per_host) - : gpu_id) - .plan()) {} + name, "") { + DefaultPlanner planner(model, impl_->gpu_id()); + for (const auto &rule : config_rules) { + planner.install_config_rule(rule); + } + impl_->init(planner.plan()); +} } // namespace ark diff --git a/ark/codegen.cpp b/ark/codegen.cpp index cd6206284..09ff28dd3 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -305,7 +305,8 @@ std::string CodeGenerator::Impl::resource_group( n_slots = total_warps / num_warps_per_task; } if (n_slots == 0) { - ERR(SchedulerError, "not enough resources for task group"); + ERR(SchedulerError, "not enough resources for task group: ", + tg.dump()); } size_t task_b = *task_range.begin(); diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index b8cdaf273..2473e1b14 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -5,6 +5,7 @@ #define ARK_EXECUTOR_HPP #include +#include #include #include #include @@ -24,6 +25,9 @@ class Executor { /// Return the GPU ID. int gpu_id() const; + /// Return the plan string. + std::string plan() const; + /// Compile the model. This must be called before `launch()`. void compile(); @@ -68,7 +72,7 @@ class Executor { void tensor_write(const Tensor tensor, const void *data, size_t bytes, bool is_d2d = false) const; - private: + protected: class Impl; std::unique_ptr impl_; }; @@ -78,6 +82,7 @@ class Model; class DefaultExecutor : public Executor { public: DefaultExecutor(const Model &model, int gpu_id = -1, + const std::vector& config_rules = {}, const std::string &name = "DefaultExecutor"); }; diff --git a/ark/include/kernels/common/broadcast.h b/ark/include/kernels/common/broadcast.h index 97b12e004..858938613 100644 --- a/ark/include/kernels/common/broadcast.h +++ b/ark/include/kernels/common/broadcast.h @@ -186,9 +186,9 @@ struct Broadcast2Intrinsic { (BroadcastInput0 && BroadcastInput1) ? OutNelemPerThread : BroadcastInput0 - ? math::gcd::value + ? math::gcd::value : BroadcastInput1 - ? math::gcd::value + ? math::gcd::value : math::gcd::value>::value; diff --git a/ark/model/model_json.cpp b/ark/model/model_json.cpp index 0057ef0aa..97ce71967 100644 --- a/ark/model/model_json.cpp +++ b/ark/model/model_json.cpp @@ -272,7 +272,16 @@ static void verify_format_plan(const Json &json) { } } -PlanJson::PlanJson(const Json &json) : Json(json) { verify_format_plan(*this); } +PlanJson::PlanJson(const Json &json) + : Json((json != nullptr) ? json + : Json{{"Rank", 0}, + {"WorldSize", 1}, + {"NumProcessors", 1}, + {"NumWarpsPerProcessor", 1}, + {"TaskInfos", Json::array()}, + {"ProcessorGroups", Json::array()}}) { + verify_format_plan(*this); +} static std::stringstream &dump_pretty_plan(const Json &json, std::stringstream &ss, int indent, diff --git a/ark/model/model_json.hpp b/ark/model/model_json.hpp index cf5fbbce2..e42640a9a 100644 --- a/ark/model/model_json.hpp +++ b/ark/model/model_json.hpp @@ -18,7 +18,7 @@ class ModelJson : public Json { class PlanJson : public Json { public: - PlanJson(const Json &json); + PlanJson(const Json &json = nullptr); std::string dump_pretty(int indent = 0, int indent_step = 2) const; }; diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index 6cdba5d02..b5a0645c8 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -202,8 +202,11 @@ std::shared_ptr ModelOp::deserialize(const Json &serialized) { } else if (!serialized.contains("Args")) { ERR(InvalidUsageError, "ModelOp deserialization failed: missing Args"); } + // Run `ModelOpT::from_name` before `construct()` to ensure all operators + // are registered. + auto op_type = ModelOpT::from_name(serialized["Type"]); auto ret = model_op_factory()->construct(serialized["Type"]); - ret->type_ = ModelOpT::from_name(serialized["Type"]); + ret->type_ = op_type; ret->name_ = serialized["Name"]; ret->is_virtual_ = serialized["IsVirtual"]; for (const auto &t : serialized["ReadTensors"]) { diff --git a/ark/ops/ops_all_reduce_test.cpp b/ark/ops/ops_all_reduce_test.cpp index 9e2c6f675..54c6426fa 100644 --- a/ark/ops/ops_all_reduce_test.cpp +++ b/ark/ops/ops_all_reduce_test.cpp @@ -94,7 +94,7 @@ void test_all_reduce_internal(ark::DimType nelem) { auto result = ark::op_test("all_reduce", m, {ones}, {output}, baseline_all_reduce, - {ones_vec.data()}, false, gpu_id, NumGpus); + {ones_vec.data()}); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; diff --git a/ark/ops/ops_arithmetic_test.cpp b/ark/ops/ops_arithmetic_test.cpp index 3fdc5ac7e..c7c18b603 100644 --- a/ark/ops/ops_arithmetic_test.cpp +++ b/ark/ops/ops_arithmetic_test.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT license. #include "ops_test_common.hpp" +#include "model/model_json.hpp" template void baseline_add(std::vector &outputs, @@ -142,12 +143,25 @@ ark::unittest::State test_add_fp32() { ark::unittest::State test_add_fp16() { ark::Model m; - ark::Tensor t0 = m.tensor({8192}, ark::FP16); - ark::Tensor t1 = m.tensor({8192}, ark::FP16); + ark::Tensor t0 = m.tensor({32, 2048, 2048}, ark::FP16); + ark::Tensor t1 = m.tensor({32, 2048, 2048}, ark::FP16); ark::Tensor out = m.add(t0, t1); auto result = - ark::op_test("add_fp16", m, {t0, t1}, {out}, baseline_add); + ark::op_test("add_fp16", m, {t0, t1}, {out}, baseline_add, {}, + { + ark::DefaultPlanner::ConfigRule([](const std::string op_str, const std::string) { + auto op = ark::Json::parse(op_str); + ark::Json config; + if (op.at("Type") == "Add") { + config["NumWarps"] = 4; + config["SramBytes"] = 0; + config["Tile"] = {128, 256}; + config["NumTasks"] = 4096; + } + return config.dump(); + }) + }); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; @@ -416,20 +430,20 @@ ark::unittest::State test_div_invalid() { int main() { ark::init(); - UNITTEST(test_add_fp32); + // UNITTEST(test_add_fp32); UNITTEST(test_add_fp16); - UNITTEST(test_add_bf16); - UNITTEST(test_add_overwrite); - UNITTEST(test_add_broadcast); - UNITTEST(test_add_invalid); - UNITTEST(test_sub_fp32); - UNITTEST(test_sub_invalid); - UNITTEST(test_mul_fp32); - UNITTEST(test_mul_fp16); - UNITTEST(test_mul_overwrite); - UNITTEST(test_mul_broadcast); - UNITTEST(test_mul_invalid); - UNITTEST(test_div_fp32); - UNITTEST(test_div_invalid); + // UNITTEST(test_add_bf16); + // UNITTEST(test_add_overwrite); + // UNITTEST(test_add_broadcast); + // UNITTEST(test_add_invalid); + // UNITTEST(test_sub_fp32); + // UNITTEST(test_sub_invalid); + // UNITTEST(test_mul_fp32); + // UNITTEST(test_mul_fp16); + // UNITTEST(test_mul_overwrite); + // UNITTEST(test_mul_broadcast); + // UNITTEST(test_mul_invalid); + // UNITTEST(test_div_fp32); + // UNITTEST(test_div_invalid); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_embedding_test.cpp b/ark/ops/ops_embedding_test.cpp index 822973106..4f9df046a 100644 --- a/ark/ops/ops_embedding_test.cpp +++ b/ark/ops/ops_embedding_test.cpp @@ -80,7 +80,7 @@ ark::unittest::State test_embedding() { } auto result = ark::op_test("embedding_" + type_str, m, {ti, tw}, {to}, baseline_embedding, - {ti_data.data(), tw_data.data()}, true); + {ti_data.data(), tw_data.data()}); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; diff --git a/ark/ops/ops_matmul.cpp b/ark/ops/ops_matmul.cpp index b259f99c8..b4553a4ed 100644 --- a/ark/ops/ops_matmul.cpp +++ b/ark/ops/ops_matmul.cpp @@ -189,45 +189,55 @@ std::vector ModelOpMatmul::impl_args([ } static const Json get_default_config(const ArchRef arch, - const ModelDataType &data_type) { + const ModelDataType &data_type, + const Dims &mnk) { + if (data_type != FP32.ref() && data_type != FP16.ref() && + data_type != BF16.ref()) { + ERR(InvalidUsageError, + "Unsupported data type: ", data_type->type_name()); + } + if (!arch->belongs_to(ARCH_CUDA) && !arch->belongs_to(ARCH_ROCM)) { + ERR(InvalidUsageError, "Unsupported architecture: ", arch->name()); + } + DimType tm = (mnk[0] > mnk[1]) ? 256 : 128; + DimType tn = (mnk[0] > mnk[1]) ? 128 : 256; if (arch->belongs_to(ARCH_CUDA_80) && data_type == FP32.ref()) { return {{"NumWarps", 8}, {"SramBytes", 147456}, - {"TileShapeMNK", {128, 256, 32}}}; + {"TileShapeMNK", {tm, tn, 32}}}; } else if (arch->belongs_to(ARCH_CUDA_80) && data_type == FP16.ref()) { return {{"NumWarps", 8}, {"SramBytes", 147456}, - {"TileShapeMNK", {128, 256, 64}}}; + {"TileShapeMNK", {tm, tn, 64}}}; } else if (arch->belongs_to(ARCH_CUDA_80) && data_type == BF16.ref()) { return {{"NumWarps", 8}, {"SramBytes", 147456}, - {"TileShapeMNK", {128, 256, 64}}}; + {"TileShapeMNK", {tm, tn, 64}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP32.ref()) { return {{"NumWarps", 4}, {"SramBytes", 24672}, - {"TileShapeMNK", {128, 256, 16}}}; + {"TileShapeMNK", {tm, tn, 16}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP16.ref()) { return {{"NumWarps", 4}, {"SramBytes", 24672}, - {"TileShapeMNK", {128, 256, 32}}}; + {"TileShapeMNK", {tm, tn, 32}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == BF16.ref()) { return {{"NumWarps", 4}, {"SramBytes", 24672}, - {"TileShapeMNK", {128, 256, 32}}}; + {"TileShapeMNK", {tm, tn, 32}}}; } - ERR(InvalidUsageError, "Unsupported arch and data type: ", arch->name(), - " and ", data_type->type_name()); + ERR(InternalError, "Unexpected error"); return {}; } Json ModelOpMatmul::default_config(const ArchRef arch) const { auto result = result_tensors_[0]; - Json config = get_default_config(arch, result->data_type()); check_fields_args(args_, {"TransposeInput", "TransposeOther"}); Dims mnk = calc_problem_size(read_tensors_[0]->padded_shape(), read_tensors_[1]->padded_shape(), args_.at("TransposeInput").value(), args_.at("TransposeOther").value()); + Json config = get_default_config(arch, result->data_type(), mnk); size_t tile_x = config.at("TileShapeMNK")[0]; size_t tile_y = config.at("TileShapeMNK")[1]; if (mnk[0] % tile_x != 0 || mnk[1] % tile_y != 0) { diff --git a/ark/ops/ops_test_common.cpp b/ark/ops/ops_test_common.cpp index 50317fba7..ad2c208b6 100644 --- a/ark/ops/ops_test_common.cpp +++ b/ark/ops/ops_test_common.cpp @@ -36,8 +36,9 @@ OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, const std::vector &outputs, OpsTestBaseline baseline, const std::vector &inputs_data, - bool print_on_error, int rank, int world_size) { - DefaultExecutor exe(model); + const std::vector& config_rules, + bool print_on_error) { + DefaultExecutor exe(model, -1, config_rules); exe.compile(); std::vector>> inputs_data_storages; @@ -133,7 +134,7 @@ OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, for (auto t : gt) { gt_ptrs.push_back(t->data()); } - baseline(gt_ptrs, output_shapes, inputs_data_refs, input_shapes, rank); + baseline(gt_ptrs, output_shapes, inputs_data_refs, input_shapes, model.rank()); std::stringstream test_name; test_name << test_name_prefix; @@ -147,6 +148,7 @@ OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, OpsTestResult result; result.test_name = test_name.str(); + result.plan = exe.plan(); // Compare results with the ground truth. for (size_t i = 0; i < outputs.size(); i++) { @@ -187,7 +189,7 @@ OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, GLOG(gpuDeviceSynchronize()); // Throughput test. - if (world_size > 1) { + if (model.world_size() > 1) { // For multi-GPU, we need to make sure that all GPUs run the same // number of iterations. Rather than doing allgather, we just // use a magic number here. diff --git a/ark/ops/ops_test_common.hpp b/ark/ops/ops_test_common.hpp index 01e97dbb1..a32d9b748 100644 --- a/ark/ops/ops_test_common.hpp +++ b/ark/ops/ops_test_common.hpp @@ -10,6 +10,7 @@ #include "ark/model.hpp" #include "ark/model_ref.hpp" +#include "ark/planner.hpp" #include "ark/random.hpp" #include "bfloat16.h" #include "half.h" @@ -133,6 +134,7 @@ TensorCompareResult tensor_compare(T *ground_truth, T *res, Dims shape, struct OpsTestResult { std::string test_name; + std::string plan; int iter; float msec_per_iter; std::vector mse; @@ -170,8 +172,8 @@ OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, const std::vector &outputs, OpsTestBaseline baseline, const std::vector &inputs_data = {}, - bool print_on_error = false, int rank = 0, - int world_size = 1); + const std::vector& config_rules = {}, + bool print_on_error = false); OpsTestGpuMem to_gpu(void *host_ptr, size_t size); diff --git a/examples/llama/README.md b/examples/llama/README.md index 090dd1de3..1fe040ae0 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -29,10 +29,10 @@ Llama2 examples over ARK. 4. Download Llama2 model weights and tokenizer weights. * The model and tokenizer should be compatible with the [official PyTorch implementation](https://github.com/facebookresearch/llama/blob/main/llama). -5. Run the model accuracy test. `--pth_path` is the path to the model weights file (`consolidated.00.pth`). +5. Run the model accuracy test. `--ckpt_dir` is the directory where the model weight files are at (e.g., `consolidated.00.pth`). ```bash - python3 model_test.py --pth_path=/path/to/model/weights.pth + python3 model_test.py --ckpt_dir=/directory/of/model/weights ``` 6. Test text generation. `--pth_path` is the path to the model weights file (`consolidated.00.pth`), `--tok_path` is the path to the tokenizer weights file (`tokenizer.model`), and `--params_path` is the path to the model parameters (`params.json`). diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py index 737d3ec8b..585341640 100644 --- a/examples/llama/model_test.py +++ b/examples/llama/model_test.py @@ -58,30 +58,34 @@ def run_ark( ] output = module(*module_inputs) - runtime = ark.Runtime() - # Prefer num_warps_per_sm = 16 for nvidia and 8 for amd - runtime.launch(num_warps_per_sm=8) + with ark.Runtime() as rt: + rt.launch(plan_path="/mnt/changhohwang/ark/plan_gpu0.json") - # Load model parameters - if state_dict: - module.load_state_dict(state_dict) + # Load model parameters + if state_dict: + print("Loading state_dict") + module.load_state_dict(state_dict) + print("Loading state_dict done") - # Load input data into tensors - tensors = [i for i in module_inputs if isinstance(i, ark.Tensor)] - tensor_data = [i for i in inputs if isinstance(i, np.ndarray)] - for tensor, ndarray in zip(tensors, tensor_data): - tensor.from_numpy(ndarray) + # Load input data into tensors + tensors = [i for i in module_inputs if isinstance(i, ark.Tensor)] + tensor_data = [i for i in inputs if isinstance(i, np.ndarray)] + for tensor, ndarray in zip(tensors, tensor_data): + tensor.from_numpy(ndarray) - start_time = time.time() + start_time = time.time() - # Run the model - runtime.run(iter=iterations) + # Run the model + print("Run:", iterations) - end_time = time.time() + rt.run(iter=iterations) + print("Run done") - if isinstance(output, list) or isinstance(output, tuple): - outputs = [o.to_numpy() for o in output] - outputs = [output.to_numpy()] + end_time = time.time() + + if isinstance(output, list) or isinstance(output, tuple): + outputs = [o.to_numpy() for o in output] + outputs = [output.to_numpy()] return RunResults(outputs=outputs, runtime=end_time - start_time) @@ -160,7 +164,9 @@ def test_module( else: prefix = module_name_prefix + "." if module_name_prefix else "" # Load the state_dict from the given path + print("Loading ckpt:", ckpt_path) state_dict_pt = torch.load(ckpt_path) + print("Loading ckpt done") state_dict_pt = { k[len(prefix) :]: v for k, v in state_dict_pt.items() @@ -182,6 +188,7 @@ def test_module( rank=rank, world_size=world_size, ) + print("Run ARK done") if not test_thru_ark_only: # PyTorch module @@ -195,6 +202,7 @@ def test_module( inputs_pt, iterations=test_thru_iterations if test_thru else 1, ) + print("Run PyTorch done") if test_thru: print( @@ -447,26 +455,26 @@ def test_transformer_block( ) output = module(feature_tensor, 0, freqs_cis_ark_tensor, None) - ark.Model.get_model().create_nodes() - print(ark.Model.get_model().serialize()) - - # test_module( - # module_class_ark=model_ark.TransformerBlock, - # module_args_ark=[ - # 0, - # args, - # ark.DataType.from_numpy(dtype), - # rank, - # world_size, - # ], - # inputs_ark=[feature, 0, freqs_cis_ark, None], - # module_class_pt=model_pt.TransformerBlock, - # module_args_pt=[0, args], - # inputs_pt=[feature.astype(dtype), 0, freqs_cis, None], - # module_name_prefix="layers.0", - # rank=rank, - # world_size=world_size, - # ) + # print(ark.Model.get_model().serialize()) + + test_module( + module_class_ark=model_ark.TransformerBlock, + module_args_ark=[ + 0, + args, + ark.DataType.from_numpy(dtype), + rank, + world_size, + ], + inputs_ark=[feature, 0, freqs_cis_ark, None], + module_class_pt=model_pt.TransformerBlock, + module_args_pt=[0, args], + inputs_pt=[feature.astype(dtype), 0, freqs_cis, None], + module_name_prefix="layers.0", + rank=rank, + world_size=world_size, + test_thru=True, + ) def test_transformer( @@ -570,7 +578,7 @@ def worker( # Configurations args = ModelArgs7B() batch_size = 1 - seq_len = 512 + seq_len = 2048 dtype = np.float16 world_size = ngpus @@ -578,7 +586,7 @@ def worker( args.vocab_size = 32000 # Reduce max_seq_len due to OOM from the PyTorch model - args.max_seq_len = 512 + args.max_seq_len = 2048 # Verify the configurations assert batch_size <= args.max_batch_size diff --git a/plan_gpu0.json b/plan_gpu0.json new file mode 100644 index 000000000..49b6bdd98 --- /dev/null +++ b/plan_gpu0.json @@ -0,0 +1,2504 @@ +{ + "Rank": 0, + "WorldSize": 1, + "NumProcessors": 304, + "NumWarpsPerProcessor": 4, + "TaskInfos": [ + { + "Id": 0, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul", + "IsVirtual": false, + "ReadTensors": [ + {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":0,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":6,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":7,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 1, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rope", + "Name": "rope", + "IsVirtual": false, + "ReadTensors": [ + {"Id":12,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":5,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":15,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":16,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [32,128], + "NumTasks": 2048 + } + } + ] + }, + { + "Id": 2, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose", + "IsVirtual": false, + "ReadTensors": [ + {"Id":16,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":19,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,128], + "NumTasks": 8192 + } + } + ] + }, + { + "Id": 3, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":1,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":8,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":9,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 4, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rope", + "Name": "rope_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":13,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":5,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":17,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [32,128], + "NumTasks": 2048 + } + } + ] + }, + { + "Id": 5, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":23,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":24,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,3,1]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,128], + "NumTasks": 8192 + } + } + ] + }, + { + "Id": 6, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":2,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":10,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":11,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 7, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":14,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":21,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":22,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,128], + "NumTasks": 8192 + } + } + ] + }, + { + "Id": 8, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":24,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":25,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":26,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":false} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 9, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "ScalarMul", + "Name": "mul", + "IsVirtual": false, + "ReadTensors": [ + {"Id":26,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":27,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Factor": {"FLOAT":0.0883883461356163} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 10, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceMax", + "Name": "reduce_max", + "IsVirtual": false, + "ReadTensors": [ + {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":29,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":30,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Axis": {"INT":3}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 65536 + } + } + ] + }, + { + "Id": 11, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Sub", + "Name": "sub", + "IsVirtual": false, + "ReadTensors": [ + {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":30,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 12, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Exp", + "Name": "exp", + "IsVirtual": false, + "ReadTensors": [ + {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 13, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceSum", + "Name": "reduce_sum", + "IsVirtual": false, + "ReadTensors": [ + {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":33,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":34,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Axis": {"INT":3}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 65536 + } + } + ] + }, + { + "Id": 14, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Div", + "Name": "div", + "IsVirtual": false, + "ReadTensors": [ + {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":34,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":35,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 15, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_4", + "IsVirtual": false, + "ReadTensors": [ + {"Id":35,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":22,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":36,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":false} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [256,128,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 16, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":38,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":39,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,128], + "NumTasks": 8192 + } + } + ] + }, + { + "Id": 17, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_5", + "IsVirtual": false, + "ReadTensors": [ + {"Id":40,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":3,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":41,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":42,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 18, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Cast", + "Name": "cast", + "IsVirtual": false, + "ReadTensors": [ + {"Id":52,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":54,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 19, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":56,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":57,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 20, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceMean", + "Name": "reduce_mean", + "IsVirtual": false, + "ReadTensors": [ + {"Id":57,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":58,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":59,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Axis": {"INT":2}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 2048 + } + } + ] + }, + { + "Id": 21, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rsqrt", + "Name": "rsqrt", + "IsVirtual": false, + "ReadTensors": [ + {"Id":59,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":60,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":61,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [64,1], + "NumTasks": 32 + } + } + ] + }, + { + "Id": 22, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":61,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":62,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 23, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":50,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":28,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 24, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Cast", + "Name": "cast_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":65,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 25, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_6", + "IsVirtual": false, + "ReadTensors": [ + {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":43,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":67,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":68,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 26, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rope", + "Name": "rope_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":73,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":53,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":76,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":77,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 27, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_4", + "IsVirtual": false, + "ReadTensors": [ + {"Id":77,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":80,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":81,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,128], + "NumTasks": 8192 + } + } + ] + }, + { + "Id": 28, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_7", + "IsVirtual": false, + "ReadTensors": [ + {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":44,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":69,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":70,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 29, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rope", + "Name": "rope_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":74,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":53,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":78,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":79,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 30, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_6", + "IsVirtual": false, + "ReadTensors": [ + {"Id":79,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":84,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":85,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,3,1]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,8], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 31, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_8", + "IsVirtual": false, + "ReadTensors": [ + {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":45,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":71,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":72,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 32, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_5", + "IsVirtual": false, + "ReadTensors": [ + {"Id":75,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":82,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":83,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,128], + "NumTasks": 8192 + } + } + ] + }, + { + "Id": 33, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_9", + "IsVirtual": false, + "ReadTensors": [ + {"Id":81,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":85,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":86,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":87,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":false} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 34, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "ScalarMul", + "Name": "mul_4", + "IsVirtual": false, + "ReadTensors": [ + {"Id":87,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":88,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Factor": {"FLOAT":0.0883883461356163} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 35, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceMax", + "Name": "reduce_max_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":90,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":91,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Axis": {"INT":3}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 65536 + } + } + ] + }, + { + "Id": 36, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Sub", + "Name": "sub_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":91,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 37, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Exp", + "Name": "exp_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 38, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceSum", + "Name": "reduce_sum_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":94,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":95,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Axis": {"INT":3}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 65536 + } + } + ] + }, + { + "Id": 39, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Div", + "Name": "div_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":95,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":96,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 40, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_10", + "IsVirtual": false, + "ReadTensors": [ + {"Id":96,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":83,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":97,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":98,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":false} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [256,128,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 41, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_7", + "IsVirtual": false, + "ReadTensors": [ + {"Id":98,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":99,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":100,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,128], + "NumTasks": 8192 + } + } + ] + }, + { + "Id": 42, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_11", + "IsVirtual": false, + "ReadTensors": [ + {"Id":101,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":46,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":102,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":103,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 43, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Add", + "Name": "add", + "IsVirtual": false, + "ReadTensors": [ + {"Id":52,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":103,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":104,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 44, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Cast", + "Name": "cast_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":106,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 45, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_5", + "IsVirtual": false, + "ReadTensors": [ + {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":108,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":109,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 46, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceMean", + "Name": "reduce_mean_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":109,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":110,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":111,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "Axis": {"INT":2}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 2048 + } + } + ] + }, + { + "Id": 47, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rsqrt", + "Name": "rsqrt_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":111,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":112,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":113,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [64,1], + "NumTasks": 32 + } + } + ] + }, + { + "Id": 48, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_6", + "IsVirtual": false, + "ReadTensors": [ + {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":113,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":114,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 49, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_7", + "IsVirtual": false, + "ReadTensors": [ + {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":51,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":29,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":116,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 50, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Cast", + "Name": "cast_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":116,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":117,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 51, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_12", + "IsVirtual": false, + "ReadTensors": [ + {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":47,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":25,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":119,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 688 + } + } + ] + }, + { + "Id": 52, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Sigmoid", + "Name": "sigmoid", + "IsVirtual": false, + "ReadTensors": [ + {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":121,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":122,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 688 + } + } + ] + }, + { + "Id": 53, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_8", + "IsVirtual": false, + "ReadTensors": [ + {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":122,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":123,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":124,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 688 + } + } + ] + }, + { + "Id": 54, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_13", + "IsVirtual": false, + "ReadTensors": [ + {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":49,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":27,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":125,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":126,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 688 + } + } + ] + }, + { + "Id": 55, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_9", + "IsVirtual": false, + "ReadTensors": [ + {"Id":124,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":126,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":127,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":128,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 688 + } + } + ] + }, + { + "Id": 56, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_14", + "IsVirtual": false, + "ReadTensors": [ + {"Id":128,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":48,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":129,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":130,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 57, + "NumWarps": 4, + "SramBytes": 0, + "Ops": [ + { + "Type": "Add", + "Name": "add_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[]}}, + {"Id":130,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "WriteTensors": [ + {"Id":131,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":66,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "ResultTensors": [ + {"Id":132,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":66,"Rank":-1,"SendTags":[],"RecvTags":[]}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,256], + "NumTasks": 256 + } + } + ] + } + ], + "ProcessorGroups": [ + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":0,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":1,"TaskRange":[0,2048],"Granularity":4} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":2,"TaskRange":[0,8192],"Granularity":4} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":3,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":4,"TaskRange":[0,2048],"Granularity":4} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":5,"TaskRange":[0,8192],"Granularity":4} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":6,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":7,"TaskRange":[0,8192],"Granularity":4} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":8,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":9,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":10,"TaskRange":[0,65536],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":11,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":12,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":13,"TaskRange":[0,65536],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":14,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":15,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":16,"TaskRange":[0,8192],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":17,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":18,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":19,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":20,"TaskRange":[0,2048],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,32], + "ResourceGroups": [ + { + "ProcessorRange": [0,32], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":21,"TaskRange":[0,32],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":22,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":23,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":24,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":25,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":26,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":27,"TaskRange":[0,8192],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":28,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":29,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":30,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":31,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":32,"TaskRange":[0,8192],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":33,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":34,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":35,"TaskRange":[0,65536],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":36,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":37,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":38,"TaskRange":[0,65536],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":39,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":40,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":41,"TaskRange":[0,8192],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":42,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":43,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":44,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":45,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":46,"TaskRange":[0,2048],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,32], + "ResourceGroups": [ + { + "ProcessorRange": [0,32], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":47,"TaskRange":[0,32],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":48,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":49,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":50,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":51,"TaskRange":[0,688],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":52,"TaskRange":[0,688],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":53,"TaskRange":[0,688],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":54,"TaskRange":[0,688],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":55,"TaskRange":[0,688],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":56,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":57,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + } + ] +} diff --git a/python/ark/__init__.py b/python/ark/__init__.py index 031afc7ba..f2f604be9 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -100,3 +100,4 @@ def set_world_size(world_size): GpuError, RuntimeError, ) +from .profiler import Profiler diff --git a/python/ark/profiler.py b/python/ark/profiler.py new file mode 100644 index 000000000..b959ceb18 --- /dev/null +++ b/python/ark/profiler.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import sys +import time +from .runtime import Runtime + + +class Profiler: + def __init__(self, plan: str): + self.plan = json.loads(plan) + + def run(self): + num_processor_groups = len(self.plan["ProcessorGroups"]) + new_plan = { + "Rank": self.plan["Rank"], "WorldSize": self.plan["WorldSize"], + "NumProcessors": self.plan["NumProcessors"], + "NumWarpsPerProcessor": self.plan["NumWarpsPerProcessor"], + "TaskInfos": self.plan["TaskInfos"], + "ProcessorGroups": [{}]} + for i in range(num_processor_groups): + new_plan["ProcessorGroups"][0] = self.plan["ProcessorGroups"][i] + with Runtime() as rt: + rt.launch(plan=json.dumps(new_plan)) + start_time = time.time() + iter = 1000 + rt.run(iter=iter) + end_time = time.time() + sys.stderr.write(f"Processor group {i} runtime: {(end_time - start_time)/iter:.6f} seconds/iter\n") diff --git a/python/executor_py.cpp b/python/executor_py.cpp index e5ab4f964..a6e5308ee 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -149,6 +149,7 @@ void register_executor(py::module &m) { py::arg("rank"), py::arg("world_size"), py::arg("gpu_id"), py::arg("name"), py::arg("plan")) .def("gpu_id", &ark::Executor::gpu_id) + .def("plan", &ark::Executor::plan) .def("compile", &ark::Executor::compile) .def("launch", &ark::Executor::launch, py::arg("max_spin_count") = -1) .def("run", &ark::Executor::run, py::arg("iter")) From ff8c4b8fc4ff178befa375ffc8ac546806fa6c4b Mon Sep 17 00:00:00 2001 From: Noli Gerawork <86308445+naturalcandy@users.noreply.github.com> Date: Tue, 2 Jul 2024 21:25:07 -0400 Subject: [PATCH 20/54] torch to ark (#217) - Adds Torch to ARK tensor conversion support - New ModelBufferManager class handles external buffer registration and simplifies buffer access during kernel initialization - Adds test cases for ARK to Torch conversion support --------- Co-authored-by: Changho Hwang --- ark/api/executor.cpp | 53 ++++++++++++++++--- ark/api/tensor.cpp | 18 ++++++- ark/codegen.cpp | 36 +++++++++---- ark/codegen.hpp | 4 +- ark/include/ark/tensor.hpp | 2 + ark/model/model_buffer.cpp | 55 ++++++++++++++++++-- ark/model/model_buffer.hpp | 15 ++++++ ark/model_buffer_manager.hpp | 58 +++++++++++++++++++++ python/ark/tensor.py | 26 +++++----- python/tensor_py.cpp | 46 ++++++++++++++++- python/unittest/test_conversion.py | 81 +++++++++++++++++++++++++++++- 11 files changed, 355 insertions(+), 39 deletions(-) create mode 100644 ark/model_buffer_manager.hpp diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 4af9df7c0..0a780bcc0 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include "ark/data_type.hpp" #include "ark/model.hpp" @@ -24,6 +25,7 @@ #include "gpu/gpu_manager.h" #include "logging.h" #include "model/model_buffer.hpp" +#include "model_buffer_manager.hpp" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" #include "utils/utils_net.hpp" @@ -234,8 +236,15 @@ void Executor::Impl::init(const std::string &plan) { std::to_string(kv.first) + ": " + std::to_string(kv.second) + ", "; } - codegen_ = - std::make_shared(plan_json_, buffer_id_to_offset_, name_); + ModelBufferManager &buffer_manager = ModelBufferManager::get_instance(); + + if (!buffer_manager.is_empty()) { + codegen_ = std::make_shared( + plan_json_, buffer_id_to_offset_, name, &buffer_manager); + } else { + codegen_ = std::make_shared(plan_json_, + buffer_id_to_offset_, name); + } auto gpu_manager = GpuManager::get_instance(gpu_id_); timer_begin_ = gpu_manager->create_event(); @@ -367,7 +376,16 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { } continue; } - buffer_id_to_offset[buf_info->buffer->id()] = offset; + if (buf_info->buffer->is_external()) { + if (buf_info->buffer->device_id() != gpu_id_) { + ERR(InvalidUsageError, + "PyTorch tensor and model execution are on different GPUs"); + } + continue; + } else { + buffer_id_to_offset[buf_info->buffer->id()] = offset; + offset += buf_info->bytes; + } for (const auto &tag_info : buf_info->buffer->send_tags()) { remote_rank_to_send_tags_and_offsets[tag_info.first] .first.push_back(tag_info.second); @@ -380,7 +398,6 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { remote_rank_to_recv_tags_and_offsets[tag_info.first] .second.push_back(offset); } - offset += buf_info->bytes; } total_bytes_ = offset; @@ -456,7 +473,11 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 1); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 2); for (int i = 0; i < len; ++i) { - buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = offsets[i]; + if (!buffer_id_to_info[send_tag_to_buffer_id[tags[i]]] + ->buffer->is_external()) { + buffer_id_to_offset[send_tag_to_buffer_id[tags[i]]] = + offsets[i]; + } } } for (auto &kv : remote_rank_to_recv_tag_to_buffer_id) { @@ -472,10 +493,13 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { bootstrap->recv(tags.data(), len * sizeof(int), remote_rank, 4); bootstrap->recv(offsets.data(), len * sizeof(size_t), remote_rank, 5); for (int i = 0; i < len; ++i) { - buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = offsets[i]; + if (!buffer_id_to_info[recv_tag_to_buffer_id[tags[i]]] + ->buffer->is_external()) { + buffer_id_to_offset[recv_tag_to_buffer_id[tags[i]]] = + offsets[i]; + } } } - return buffer_id_to_offset; } @@ -742,6 +766,11 @@ uintptr_t Executor::Impl::tensor_address(const Tensor tensor) const { void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes, bool is_d2d) const { GLOG(gpuSetDevice(gpu_id_)); + if (tensor.ref()->buffer()->is_external()) { + ERR(InvalidUsageError, + "Reading data from a tensor preallocated by PyTorch is not " + "supported. Use PyTorch's native methods."); + } size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); if (bytes != tensor_data_bytes) { @@ -779,6 +808,11 @@ void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes, void Executor::Impl::tensor_write(const Tensor tensor, const void *data, size_t bytes, bool is_d2d) const { GLOG(gpuSetDevice(gpu_id_)); + if (tensor.ref()->buffer()->is_external()) { + ERR(InvalidUsageError, + "Writing data to a tensor preallocated by PyTorch is not " + "supported. Use PyTorch's native methods."); + } size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); if (bytes != tensor_data_bytes) { @@ -843,7 +877,10 @@ float Executor::stop(int64_t max_spin_count) { void Executor::barrier() { impl_->barrier(); } -void Executor::destroy() { impl_.reset(nullptr); } +void Executor::destroy() { + ModelBufferManager::get_instance().clear_buffers(); + impl_.reset(nullptr); +} bool Executor::destroyed() const { return impl_.get() == nullptr; } diff --git a/ark/api/tensor.cpp b/ark/api/tensor.cpp index 4b03c3ac8..4d33bd9f1 100644 --- a/ark/api/tensor.cpp +++ b/ark/api/tensor.cpp @@ -3,11 +3,25 @@ #include "ark/tensor.hpp" +#include "model/model_buffer.hpp" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" namespace ark { +Tensor::Tensor(void* data_ptr, int32_t device_id, + const std::vector& shape, + const DataType& dtype) { + size_t external_data_size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()) * + dtype.bytes(); + auto buffer = + std::make_shared(data_ptr, external_data_size, device_id); + auto tensor = std::make_shared(dtype.ref(), buffer, Dims(shape), + Dims(shape), Dims(), Dims()); + ref_ = tensor; +} + size_t Tensor::id() const { if (ref_) { return ref_->id(); @@ -43,14 +57,14 @@ Dims Tensor::padded_shape() const { return Dims(); } -const DataType &Tensor::data_type() const { +const DataType& Tensor::data_type() const { if (ref_) { return DataType::from_name(ref_->data_type()->type_name()); } return NONE; } -std::ostream &operator<<(std::ostream &os, const Tensor &tensor) { +std::ostream& operator<<(std::ostream& os, const Tensor& tensor) { if (tensor.is_null()) { os << "null"; } else { diff --git a/ark/codegen.cpp b/ark/codegen.cpp index 09ff28dd3..a97e5e45b 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -10,6 +10,7 @@ #include "file_io.h" #include "logging.h" #include "model/model_buffer.hpp" +#include "model_buffer_manager.hpp" #include "model/model_data_type.hpp" #include "model/model_op.hpp" #include "model/model_tensor.hpp" @@ -43,7 +44,7 @@ class CodeGenerator::Impl { public: Impl(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name); + const std::string &name, ModelBufferManager *buffer_manager); ~Impl() = default; private: @@ -64,6 +65,8 @@ class CodeGenerator::Impl { std::string sync_process_range(const Range &ranges, int state_id); + ModelBufferManager *buffer_manager_; + protected: friend class CodeGenerator; @@ -78,14 +81,18 @@ class CodeGenerator::Impl { CodeGenerator::Impl::Impl(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name) - : buffer_id_to_offset_(buffer_id_to_offset), name_(name) { + const std::string &name, + ModelBufferManager *buffer_manager) + : buffer_id_to_offset_(buffer_id_to_offset), + name_(name), + buffer_manager_(buffer_manager) { rank_ = plan.at("Rank"); world_size_ = plan.at("WorldSize"); num_procs_ = plan.at("NumProcessors"); num_warps_per_proc_ = plan.at("NumWarpsPerProcessor"); std::stringstream definitions_ss; + for (auto &task_json : plan.at("TaskInfos")) { definitions_ss << this->def_task(task_json); } @@ -224,11 +231,19 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) { auto &arg = impl_args[i]; if (arg.type_name() == "TENSOR") { auto tns = arg.value(); - size_t buffer_offset = - buffer_id_to_offset_.at(tns->buffer()->id()); - size_t offset = buffer_offset + ModelOffset(tns).value(); - ss << "(" << tns->data_type()->type_str() << "*)&_buf[" - << offset << "]"; + if (tns->buffer()->is_external()) { + void *buf_addr = + ModelBufferManager::get_instance().get_buffer( + tns->buffer()->id()); + ss << "(" << tns->data_type()->type_str() << "*)" + << buf_addr; + } else { + size_t buffer_offset = + buffer_id_to_offset_.at(tns->buffer()->id()); + size_t offset = buffer_offset + ModelOffset(tns).value(); + ss << "(" << tns->data_type()->type_str() << "*)&_buf[" + << offset << "]"; + } } else if (arg.type_name() == "OFFSET") { auto moff = arg.value(); size_t buffer_offset = @@ -431,8 +446,9 @@ std::string CodeGenerator::Impl::sync_process_range(const Range &range, CodeGenerator::CodeGenerator( const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name) - : impl_(std::make_shared(plan, buffer_id_to_offset, name)) {} + const std::string &name, ModelBufferManager *buffer_manager) + : impl_(std::make_shared(plan, buffer_id_to_offset, name, + buffer_manager)) {} std::string CodeGenerator::code() const { return impl_->code_; } diff --git a/ark/codegen.hpp b/ark/codegen.hpp index 4f8307e7e..a2976e644 100644 --- a/ark/codegen.hpp +++ b/ark/codegen.hpp @@ -8,6 +8,7 @@ #include #include +#include "model_buffer_manager.hpp" #include "model/model_json.hpp" namespace ark { @@ -16,7 +17,8 @@ class CodeGenerator { public: CodeGenerator(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name = "ark_kernel"); + const std::string &name = "ark_kernel", + ModelBufferManager *buffer_manager = nullptr); ~CodeGenerator() = default; diff --git a/ark/include/ark/tensor.hpp b/ark/include/ark/tensor.hpp index 747ce5fea..d13748175 100644 --- a/ark/include/ark/tensor.hpp +++ b/ark/include/ark/tensor.hpp @@ -31,6 +31,8 @@ class Tensor { Tensor(ModelTensorRef ref) : ref_(ref) {} Tensor(const Tensor &other) = default; Tensor &operator=(const Tensor &other) = default; + Tensor(void *data_ptr, int32_t device_id, const std::vector &shape, + const DataType &dtype); bool operator==(const Tensor &other) const { return ref_ == other.ref_; } bool operator!=(const Tensor &other) const { return ref_ != other.ref_; } diff --git a/ark/model/model_buffer.cpp b/ark/model/model_buffer.cpp index 4ce91b5e4..ce8f37727 100644 --- a/ark/model/model_buffer.cpp +++ b/ark/model/model_buffer.cpp @@ -4,13 +4,13 @@ #include "model_buffer.hpp" #include "logging.h" +#include "model_buffer_manager.hpp" namespace ark { -ModelBuffer::ModelBuffer(int rank) : rank_(rank) { - static size_t id = 0; - id_ = id++; -} +size_t ModelBuffer::curr_id = 0; + +ModelBuffer::ModelBuffer(int rank) : rank_(rank) { id_ = curr_id++; } ModelBuffer::ModelBuffer(size_t id, int rank, const std::vector &send_tags, @@ -24,6 +24,23 @@ ModelBuffer::ModelBuffer(size_t id, int rank, } } +ModelBuffer::ModelBuffer(void *data, size_t size, int32_t device_id) + : rank_(-1), + external_data_(data), + external_data_size_(size), + device_id_(device_id), + is_external_(true) { + id_ = curr_id++; +} + +ModelBuffer::ModelBuffer(size_t id, void *data, size_t size, int32_t device_id) + : id_(id), + rank_(-1), + external_data_(data), + external_data_size_(size), + device_id_(device_id), + is_external_(true) {} + void ModelBuffer::tag_send(int remote_rank, int tag) { send_tags_.insert(TagInfo{remote_rank, tag}); } @@ -46,6 +63,14 @@ Json ModelBuffer::serialize() const { } j["SendTags"] = send_tags; j["RecvTags"] = recv_tags; + j["IsExternal"] = is_external_; + if (is_external_) { + ModelBufferManager::get_instance().register_buffer(id_, external_data_, + external_data_size_); + j["ExternalDataSize"] = external_data_size_; + j["DeviceId"] = device_id_; + } + // external_data_ptr_ is not included in JSON return j; } @@ -62,6 +87,28 @@ std::shared_ptr ModelBuffer::deserialize(const Json &serialized) { } else if (!serialized.contains("RecvTags")) { ERR(InvalidUsageError, "ModelBuffer deserialization failed: missing RecvTags"); + } else if (!serialized.contains("IsExternal")) { + ERR(InvalidUsageError, + "ModelBuffer deserialization failed: missing IsExternal"); + } + if (serialized["IsExternal"]) { + if (!serialized.contains("ExternalDataSize")) { + ERR(InvalidUsageError, + "ModelBuffer deserialization failed: missing ExternalDataSize"); + } else if (!serialized.contains("DeviceId")) { + ERR(InvalidUsageError, + "ModelBuffer deserialization failed: missing DeviceId"); + } + void *data_ptr = + ModelBufferManager::get_instance().get_buffer(serialized["Id"]); + if (!data_ptr) { + ERR(InvalidUsageError, + "ModelBuffer deserialization failed: external buffer not found " + "in BufferManager"); + } + return std::make_shared(serialized["Id"], data_ptr, + serialized["ExternalDataSize"], + serialized["DeviceId"]); } return std::make_shared(serialized["Id"], serialized["Rank"], serialized["SendTags"], diff --git a/ark/model/model_buffer.hpp b/ark/model/model_buffer.hpp index 7ad3db206..e7f1045b2 100644 --- a/ark/model/model_buffer.hpp +++ b/ark/model/model_buffer.hpp @@ -22,6 +22,10 @@ class ModelBuffer { ModelBuffer(size_t id, int rank, const std::vector &send_tags, const std::vector &recv_tags); + // externally managed buffer + ModelBuffer(void *data, size_t size, int32_t device_id); + ModelBuffer(size_t id, void *data, size_t size, int32_t device_id); + size_t id() const { return id_; } int rank() const { return rank_; } @@ -44,11 +48,22 @@ class ModelBuffer { static std::shared_ptr deserialize(const Json &serialized); + // external buffer management + size_t external_data_size() const { return external_data_size_; } + void *external_data() const { return external_data_; } + int32_t device_id() const { return device_id_; } + bool is_external() const { return is_external_; } + private: + static size_t curr_id; size_t id_; int rank_; std::set send_tags_; std::set recv_tags_; + void *external_data_ = nullptr; + size_t external_data_size_ = 0; + int32_t device_id_; + bool is_external_ = false; }; } // namespace ark diff --git a/ark/model_buffer_manager.hpp b/ark/model_buffer_manager.hpp new file mode 100644 index 000000000..7b705f4c8 --- /dev/null +++ b/ark/model_buffer_manager.hpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_MODEL_BUFFER_MANAGER_HPP_ +#define ARK_MODEL_BUFFER_MANAGER_HPP_ + +#include +#include + +namespace ark { +// Manages externally allocated buffers not in the ARK memory space. +class ModelBufferManager { + public: + static ModelBufferManager& get_instance() { + static ModelBufferManager instance; + return instance; + } + + void register_buffer(size_t id, void* data, size_t size) { + buffers_[id] = std::make_tuple(data, size); + } + + void* get_buffer(size_t id) { + auto it = buffers_.find(id); + if (it != buffers_.end()) { + return std::get<0>(it->second); + } + return nullptr; + } + + size_t get_buffer_size(size_t id) { + auto it = buffers_.find(id); + if (it != buffers_.end()) { + return std::get<1>(it->second); + } + return 0; + } + + const std::unordered_map>& get_buffers() + const { + return buffers_; + } + + void clear_buffers() { buffers_.clear(); } + + bool is_empty() const { return buffers_.empty(); } + + private: + std::unordered_map> + buffers_; // Maps buffer IDs to pointers and sizes. + size_t next_compact_id_ = 0; + ModelBufferManager() {} + ModelBufferManager(const ModelBufferManager&) = delete; + ModelBufferManager& operator=(const ModelBufferManager&) = delete; +}; +} // namespace ark + +#endif // ARK_MODEL_BUFFER_MANAGER_HPP_ diff --git a/python/ark/tensor.py b/python/ark/tensor.py index ac2886960..8f26dc96e 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -167,18 +167,20 @@ def from_numpy(self, ndarray: np.ndarray) -> "Tensor": return self @staticmethod - def from_torch(tensor: torch.Tensor): - return Tensor( - Model.get_model().tensor( - Dims(list(tensor.shape)), - DataType.from_torch(tensor.dtype).ctype(), - Dims(), - Dims(), - Dims(), - "", - ), - lambda: tensor, - ) + def from_torch(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor": + """ + Returns an ARK tensor that shares the same memory with the torch tensor. + """ + if _no_torch: + raise ImportError("torch is not available") + elif not tensor.is_contiguous(): + raise ValueError("Torch tensor must be contiguous.") + elif tensor.device.type == "cpu": + raise ValueError("Torch tensor must be on a device.") + ark_dtype = DataType.from_torch(tensor.dtype) + dl_capsule = torch.utils.dlpack.to_dlpack(tensor) + ark_tensor = _Tensor(dl_capsule, ark_dtype.ctype()) + return Tensor(ark_tensor, runtime_id=runtime_id) def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": """ diff --git a/python/tensor_py.cpp b/python/tensor_py.cpp index fbd909d3d..16eb03421 100644 --- a/python/tensor_py.cpp +++ b/python/tensor_py.cpp @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +#include #include #include #include @@ -9,8 +10,51 @@ namespace py = pybind11; -void register_tensor(py::module &m) { +struct DLTensorMetadata { + void* data_ptr; + int32_t device_id; + DLDeviceType device_type; + int32_t ndim; + DLDataType dtype; + std::vector shape; + std::vector strides; + uint64_t byte_offset; +}; + +static DLTensorMetadata extractDLTensorMetadata(DLManagedTensor* dl_tensor) { + DLTensorMetadata metadata; + metadata.data_ptr = dl_tensor->dl_tensor.data; + metadata.device_id = dl_tensor->dl_tensor.device.device_id; + metadata.device_type = dl_tensor->dl_tensor.device.device_type; + metadata.ndim = dl_tensor->dl_tensor.ndim; + metadata.dtype = dl_tensor->dl_tensor.dtype; + metadata.shape.assign( + dl_tensor->dl_tensor.shape, + dl_tensor->dl_tensor.shape + dl_tensor->dl_tensor.ndim); + if (dl_tensor->dl_tensor.strides != nullptr) { + metadata.strides.assign( + dl_tensor->dl_tensor.strides, + dl_tensor->dl_tensor.strides + dl_tensor->dl_tensor.ndim); + } + metadata.byte_offset = dl_tensor->dl_tensor.byte_offset; + return metadata; +} + +void register_tensor(py::module& m) { py::class_(m, "_Tensor") + .def(py::init([](py::capsule capsule, const ark::DataType& dtype) { + DLManagedTensor* dl_tensor = (DLManagedTensor*)capsule; + if (!dl_tensor) { + throw std::runtime_error( + "Capsule does not contain a DLManagedTensor"); + } + DLTensorMetadata metadata = extractDLTensorMetadata(dl_tensor); + int32_t device_id = metadata.device_id; + void* data_ptr = metadata.data_ptr; + auto shape = metadata.shape; + + return new ark::Tensor(data_ptr, device_id, shape, dtype); + })) .def("id", &ark::Tensor::id) .def("shape", &ark::Tensor::shape, py::return_value_policy::reference) .def("strides", &ark::Tensor::strides, diff --git a/python/unittest/test_conversion.py b/python/unittest/test_conversion.py index 5befa1c34..833b88662 100644 --- a/python/unittest/test_conversion.py +++ b/python/unittest/test_conversion.py @@ -1,6 +1,7 @@ import pytest import numpy as np import ark +from typing import Callable try: import torch @@ -9,6 +10,8 @@ except ImportError: _no_torch = True +# ARK to Torch tests + def initialize_tensor(dimensions, dtype): tensor = ark.tensor(dimensions, dtype) @@ -69,7 +72,7 @@ def check_diff(input_tensor_host, input_view_numpy, value, index): # Test function to check if changes to the torch views are reflected in the original tensors @pytest.mark.parametrize("dtype", [ark.fp16, ark.fp32]) -def test_aliasing(dtype: ark.DataType): +def test_ark_to_torch_aliasing(dtype: ark.DataType): ark.init() dimensions = [4, 4] input_tensor, input_tensor_host = initialize_tensor(dimensions, dtype) @@ -126,3 +129,79 @@ def test_conversion_torch(): torch_tensor = t.to_torch() assert torch.all(torch_tensor == 7) + + +# Torch to ARK tests + +ArkBinOp = Callable[[ark.Tensor, ark.Tensor], ark.Tensor] +TorchBinOp = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] +ArkUnOp = Callable[[ark.Tensor], ark.Tensor] +TorchUnOp = Callable[[torch.Tensor], torch.Tensor] + + +# Verify the accuracy of binary operations involving ARK view tensors +@pytest.mark.parametrize( + "dtype, ark_op, torch_op, tensor_dims", + [(torch.float16, ark.add, torch.add, (2, 3))], +) +def test_bin_op(dtype, ark_op: ArkBinOp, torch_op: TorchBinOp, tensor_dims): + ark.init() + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + expected_output = torch_op(input_tensor, other_tensor).cpu().numpy() + input_ark_view = ark.Tensor.from_torch(input_tensor) + other_ark_view = ark.Tensor.from_torch(other_tensor) + output = ark_op(input_ark_view, other_ark_view) + runtime = ark.Runtime() + runtime.launch() + runtime.run() + output_host = output.to_numpy() + runtime.stop() + runtime.reset() + assert np.allclose(output_host, expected_output) + + +# Verify the accuracy of unary operations involving ARK view tensors +@pytest.mark.parametrize( + "dtype, ark_op, torch_op, tensor_dims", + [(torch.float16, ark.exp, torch.exp, (3, 3))], +) +def test_unary_op(dtype, ark_op: ArkUnOp, torch_op: TorchUnOp, tensor_dims): + ark.init() + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + expected_output = torch_op(input_tensor).cpu().numpy() + input_ark_view = ark.Tensor.from_torch(input_tensor) + output = ark_op(input_ark_view) + runtime = ark.Runtime() + runtime.launch() + runtime.run() + output_host = output.to_numpy() + runtime.stop() + runtime.reset() + assert np.allclose(output_host, expected_output) + + +# Test function to check if changes in torch tensors are reflected in ARK views +@pytest.mark.parametrize("dtype, tensor_dims", [(torch.float16, (64, 64))]) +def test_torch_to_ark_aliasing(dtype, tensor_dims): + ark.init() + # Initialize a PyTorch tensor + input_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + other_tensor = torch.randn(tensor_dims, dtype=dtype, device="cuda:0") + + input_ark_view = ark.Tensor.from_torch(input_tensor) + other_ark_view = ark.Tensor.from_torch(other_tensor) + + output = ark.add(input_ark_view, other_ark_view) + # Perform in place operations + input_tensor += other_tensor + other_tensor += input_tensor + expected_output = (input_tensor + other_tensor).cpu().numpy() + + runtime = ark.Runtime() + runtime.launch() + runtime.run() + output_host = output.to_numpy() + runtime.stop() + runtime.reset() + assert np.allclose(output_host, expected_output) From fe35541e02029b0d9a8da4cbdccf2565cbf516b0 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 3 Jul 2024 06:51:43 +0000 Subject: [PATCH 21/54] wip --- ark/api/executor.cpp | 22 +- ark/api/planner.cpp | 1 + ark/codegen.cpp | 16 +- ark/codegen.hpp | 3 +- ark/model/model_json.cpp | 14 +- ark/model_buffer_manager.hpp | 5 +- cmake/Utils.cmake | 2 +- docs/plan_file.md | 18 + examples/llama/model_test.py | 2 +- examples/tutorial/default_plan.json | 115 +++--- examples/tutorial/model.json | 46 +-- examples/tutorial/plan.json | 63 ++-- examples/tutorial/plan_1_larger_tile.json | 47 +-- examples/tutorial/plan_2_split_k.json | 63 ++-- examples/tutorial/plan_3_overwrite.json | 63 ++-- examples/tutorial/plan_tutorial.py | 4 +- plan_gpu0.json | 415 +++++++++++----------- python/ark/__init__.py | 3 +- python/ark/planner.py | 184 ++++++++++ python/ark/profiler.py | 30 +- python/ark/runtime.py | 52 +-- python/unittest/test_runtime.py | 27 +- 22 files changed, 686 insertions(+), 509 deletions(-) create mode 100644 python/ark/planner.py diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 0a780bcc0..20b162b16 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -228,6 +228,16 @@ void Executor::Impl::init(const std::string &plan) { plan_json_ = Json::parse(plan); } + auto gpu_manager = GpuManager::get_instance(gpu_id_); + + if (!gpu_manager->info().arch->belongs_to( + Arch::from_name(plan_json_.at("Architecture")))) { + LOG(WARN, "Architecture name of the plan `", + plan_json_.at("Architecture").get(), + "` is not compatible with the GPU architecture `", + gpu_manager->info().arch->name(), "`."); + } + buffer_id_to_offset_ = init_buffers(plan_json_); std::string buffer_id_to_offset_str; @@ -236,17 +246,9 @@ void Executor::Impl::init(const std::string &plan) { std::to_string(kv.first) + ": " + std::to_string(kv.second) + ", "; } - ModelBufferManager &buffer_manager = ModelBufferManager::get_instance(); + codegen_ = std::make_shared(plan_json_, buffer_id_to_offset_, + name_); - if (!buffer_manager.is_empty()) { - codegen_ = std::make_shared( - plan_json_, buffer_id_to_offset_, name, &buffer_manager); - } else { - codegen_ = std::make_shared(plan_json_, - buffer_id_to_offset_, name); - } - - auto gpu_manager = GpuManager::get_instance(gpu_id_); timer_begin_ = gpu_manager->create_event(); timer_end_ = gpu_manager->create_event(); buffer_ = gpu_manager->malloc(total_bytes_, 65536); diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index 5c9d09f2e..14e1b7b41 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -119,6 +119,7 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { Json plan; plan["Rank"] = model_.rank(); plan["WorldSize"] = model_.world_size(); + plan["Architecture"] = gpu_info.arch->name(); plan["NumProcessors"] = max_num_processors; plan["NumWarpsPerProcessor"] = max_num_warps; plan["TaskInfos"] = task_infos; diff --git a/ark/codegen.cpp b/ark/codegen.cpp index a97e5e45b..55327329a 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -44,7 +44,7 @@ class CodeGenerator::Impl { public: Impl(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name, ModelBufferManager *buffer_manager); + const std::string &name); ~Impl() = default; private: @@ -65,8 +65,6 @@ class CodeGenerator::Impl { std::string sync_process_range(const Range &ranges, int state_id); - ModelBufferManager *buffer_manager_; - protected: friend class CodeGenerator; @@ -81,11 +79,8 @@ class CodeGenerator::Impl { CodeGenerator::Impl::Impl(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name, - ModelBufferManager *buffer_manager) - : buffer_id_to_offset_(buffer_id_to_offset), - name_(name), - buffer_manager_(buffer_manager) { + const std::string &name) + : buffer_id_to_offset_(buffer_id_to_offset), name_(name) { rank_ = plan.at("Rank"); world_size_ = plan.at("WorldSize"); num_procs_ = plan.at("NumProcessors"); @@ -446,9 +441,8 @@ std::string CodeGenerator::Impl::sync_process_range(const Range &range, CodeGenerator::CodeGenerator( const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name, ModelBufferManager *buffer_manager) - : impl_(std::make_shared(plan, buffer_id_to_offset, name, - buffer_manager)) {} + const std::string &name) + : impl_(std::make_shared(plan, buffer_id_to_offset, name)) {} std::string CodeGenerator::code() const { return impl_->code_; } diff --git a/ark/codegen.hpp b/ark/codegen.hpp index a2976e644..1ed8ec9f2 100644 --- a/ark/codegen.hpp +++ b/ark/codegen.hpp @@ -17,8 +17,7 @@ class CodeGenerator { public: CodeGenerator(const PlanJson &plan, const std::map &buffer_id_to_offset, - const std::string &name = "ark_kernel", - ModelBufferManager *buffer_manager = nullptr); + const std::string &name = "ark_kernel"); ~CodeGenerator() = default; diff --git a/ark/model/model_json.cpp b/ark/model/model_json.cpp index 97ce71967..86eb843e2 100644 --- a/ark/model/model_json.cpp +++ b/ark/model/model_json.cpp @@ -250,9 +250,13 @@ static void verify_format_processor_group(const Json &json) { } static void verify_format_plan(const Json &json) { - const std::vector required_fields = { - "Rank", "WorldSize", "NumProcessors", "NumWarpsPerProcessor", - "TaskInfos", "ProcessorGroups"}; + const std::vector required_fields = {"Rank", + "WorldSize", + "Architecture", + "NumProcessors", + "NumWarpsPerProcessor", + "TaskInfos", + "ProcessorGroups"}; for (const auto &field : required_fields) { if (!json.contains(field)) { ERR(NotFoundError, "PlanJson: " + field + " not found"); @@ -276,6 +280,7 @@ PlanJson::PlanJson(const Json &json) : Json((json != nullptr) ? json : Json{{"Rank", 0}, {"WorldSize", 1}, + {"Architecture", "ANY"}, {"NumProcessors", 1}, {"NumWarpsPerProcessor", 1}, {"TaskInfos", Json::array()}, @@ -292,6 +297,9 @@ static std::stringstream &dump_pretty_plan(const Json &json, dump_pretty_item(json.at("WorldSize"), "WorldSize", ss, indent + indent_step) << ",\n"; + dump_pretty_item(json.at("Architecture"), "Architecture", ss, + indent + indent_step) + << ",\n"; dump_pretty_item(json.at("NumProcessors"), "NumProcessors", ss, indent + indent_step) << ",\n"; diff --git a/ark/model_buffer_manager.hpp b/ark/model_buffer_manager.hpp index 7b705f4c8..4baaec7fe 100644 --- a/ark/model_buffer_manager.hpp +++ b/ark/model_buffer_manager.hpp @@ -46,9 +46,8 @@ class ModelBufferManager { bool is_empty() const { return buffers_.empty(); } private: - std::unordered_map> - buffers_; // Maps buffer IDs to pointers and sizes. - size_t next_compact_id_ = 0; + // Maps buffer IDs to pointers and sizes. + std::unordered_map> buffers_; ModelBufferManager() {} ModelBufferManager(const ModelBufferManager&) = delete; ModelBufferManager& operator=(const ModelBufferManager&) = delete; diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 9bb83fb42..b1fd1b132 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -14,7 +14,7 @@ if(GIT_CLANG_FORMAT) COMMAND ${GIT_CLANG_FORMAT} --style=file --diff || true ) add_custom_target(cpplint-autofix - COMMAND ${GIT_CLANG_FORMAT} --style=file || true + COMMAND ${GIT_CLANG_FORMAT} --style=file --extensions cc,cpp,h,hpp,cu,in,hip || true ) else() message(STATUS "git-clang-format not found.") diff --git a/docs/plan_file.md b/docs/plan_file.md index 90a4537a2..c06ccc35d 100644 --- a/docs/plan_file.md +++ b/docs/plan_file.md @@ -6,6 +6,7 @@ See an example plan file: [Example 1](../examples/tutorial/default_plan.json) - Rank (Int) - WorldSize (Int) + - Architecture (String) - NumProcessors (Int) - NumWarpsPerProcessor (Int) - TaskInfos (Array of TaskInfo) @@ -42,6 +43,23 @@ See an example plan file: [Example 1](../examples/tutorial/default_plan.json) `ProcessorRange`, `WarpRange`, `SramRange`, and `TaskRange` are in the "range" format, i.e., `[Begin, End, Step]` that indicates an arithmetic integer sequence with a common difference of `Step`, starting from `Begin` and ends before `End` (does not include `End`). They alternatively can be in the format `[Begin, End]` that assumes `Step` is 1. +## Architecture + +A name that refers to the hardware architecture where the plan is supposed to run over. The following names are currently supported. + +- `ANY`: compatible with all architectures. + +- NVIDIA Family + - `CUDA`: compatible with all supported NVIDIA architectures. + - `CUDA_70`: compatible with NVIDIA Volta architecture. + - `CUDA_80`: compatible with NVIDIA Ampere architecture. + - `CUDA_90`: compatible with NVIDIA Hopper architecture. + +- AMD Family + - `ROCM`: compatible with all supported AMD architectures. + - `ROCM_90A`: compatible with AMD CDNA 2 (GFX90A) architecture. + - `ROCM_942`: compatible with AMD CDNA 3 (GFX942) architecture. + ## TaskInfo A `TaskInfo` object describes a sequential set of operators. The followings describe each field of `TaskInfo`. diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py index 585341640..71485be45 100644 --- a/examples/llama/model_test.py +++ b/examples/llama/model_test.py @@ -59,7 +59,7 @@ def run_ark( output = module(*module_inputs) with ark.Runtime() as rt: - rt.launch(plan_path="/mnt/changhohwang/ark/plan_gpu0.json") + rt.launch(ark.Plan.from_file("/mnt/changhohwang/ark/plan_gpu0.json")) # Load model parameters if state_dict: diff --git a/examples/tutorial/default_plan.json b/examples/tutorial/default_plan.json index c6b4be243..bb774a5b8 100644 --- a/examples/tutorial/default_plan.json +++ b/examples/tutorial/default_plan.json @@ -1,36 +1,37 @@ { "Rank": 0, "WorldSize": 1, - "NumProcessors": 108, - "NumWarpsPerProcessor": 8, + "Architecture": "ROCM_942", + "NumProcessors": 304, + "NumWarpsPerProcessor": 4, "TaskInfos": [ { "Id": 0, - "NumWarps": 8, - "SramBytes": 147456, + "NumWarps": 4, + "SramBytes": 24672, "Ops": [ { "Type": "Matmul", "Name": "matmul", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":1,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":1,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":4,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, "TransposeOther": {"BOOL":true} }, "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], "NumTasks": 172 } } @@ -46,13 +47,13 @@ "Name": "sigmoid", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":6,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":6,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -74,14 +75,14 @@ "Name": "mul", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":8,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":8,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -95,31 +96,31 @@ }, { "Id": 3, - "NumWarps": 8, - "SramBytes": 147456, + "NumWarps": 4, + "SramBytes": 24672, "Ops": [ { "Type": "Matmul", "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":3,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":3,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":10,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":10,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, "TransposeOther": {"BOOL":true} }, "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], "NumTasks": 172 } } @@ -135,14 +136,14 @@ "Name": "mul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":12,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":12,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -156,31 +157,31 @@ }, { "Id": 5, - "NumWarps": 8, - "SramBytes": 147456, + "NumWarps": 4, + "SramBytes": 24672, "Ops": [ { "Type": "Matmul", "Name": "matmul_2", "IsVirtual": false, "ReadTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":2,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":2,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":14,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":14,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":15,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":15,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, "TransposeOther": {"BOOL":true} }, "Config": { - "NumWarps": 8, - "SramBytes": 147456, - "TileShapeMNK": [128,256,64], + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], "NumTasks": 64 } } @@ -189,12 +190,12 @@ ], "ProcessorGroups": [ { - "ProcessorRange": [0,108], + "ProcessorRange": [0,172], "ResourceGroups": [ { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], + "ProcessorRange": [0,172], + "WarpRange": [0,4], + "SramRange": [0,24672], "TaskGroups": [ {"TaskId":0,"TaskRange":[0,172],"Granularity":1} ] @@ -202,10 +203,10 @@ ] }, { - "ProcessorRange": [0,108], + "ProcessorRange": [0,304], "ResourceGroups": [ { - "ProcessorRange": [0,108], + "ProcessorRange": [0,304], "WarpRange": [0,1], "SramRange": [0,0], "TaskGroups": [ @@ -215,10 +216,10 @@ ] }, { - "ProcessorRange": [0,108], + "ProcessorRange": [0,304], "ResourceGroups": [ { - "ProcessorRange": [0,108], + "ProcessorRange": [0,304], "WarpRange": [0,1], "SramRange": [0,0], "TaskGroups": [ @@ -228,12 +229,12 @@ ] }, { - "ProcessorRange": [0,108], + "ProcessorRange": [0,172], "ResourceGroups": [ { - "ProcessorRange": [0,108], - "WarpRange": [0,8], - "SramRange": [0,147456], + "ProcessorRange": [0,172], + "WarpRange": [0,4], + "SramRange": [0,24672], "TaskGroups": [ {"TaskId":3,"TaskRange":[0,172],"Granularity":1} ] @@ -241,10 +242,10 @@ ] }, { - "ProcessorRange": [0,108], + "ProcessorRange": [0,304], "ResourceGroups": [ { - "ProcessorRange": [0,108], + "ProcessorRange": [0,304], "WarpRange": [0,1], "SramRange": [0,0], "TaskGroups": [ @@ -258,8 +259,8 @@ "ResourceGroups": [ { "ProcessorRange": [0,64], - "WarpRange": [0,8], - "SramRange": [0,147456], + "WarpRange": [0,4], + "SramRange": [0,24672], "TaskGroups": [ {"TaskId":5,"TaskRange":[0,64],"Granularity":1} ] @@ -267,4 +268,4 @@ ] } ] -} +} \ No newline at end of file diff --git a/examples/tutorial/model.json b/examples/tutorial/model.json index 1bc9233a5..a6ba8e8be 100644 --- a/examples/tutorial/model.json +++ b/examples/tutorial/model.json @@ -12,14 +12,14 @@ "Name": "matmul", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":1,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":1,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":4,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -31,13 +31,13 @@ "Name": "sigmoid", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":6,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":6,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {} }, @@ -46,14 +46,14 @@ "Name": "mul", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":5,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":7,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":8,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":8,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {} } @@ -69,14 +69,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":3,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":0,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":3,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":10,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":10,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -95,14 +95,14 @@ "Name": "mul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":9,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":11,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":12,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":12,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {} }, @@ -111,14 +111,14 @@ "Name": "matmul_2", "IsVirtual": false, "ReadTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":2,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":13,"DataType":"FP16","Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":2,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":14,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":14,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":15,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":15,"DataType":"FP16","Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, diff --git a/examples/tutorial/plan.json b/examples/tutorial/plan.json index c0854e505..335c27549 100644 --- a/examples/tutorial/plan.json +++ b/examples/tutorial/plan.json @@ -1,6 +1,7 @@ { "Rank": 0, "WorldSize": 1, + "Architecture": "CUDA_80", "NumProcessors": 108, "NumWarpsPerProcessor": 8, "TaskInfos": [ @@ -14,14 +15,14 @@ "Name": "matmul", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} + {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} ], "WriteTensors": [ - {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -46,13 +47,13 @@ "Name": "sigmoid", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":6,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":6,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -74,14 +75,14 @@ "Name": "mul", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":8,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":8,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -103,14 +104,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} + {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} ], "WriteTensors": [ - {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -135,14 +136,14 @@ "Name": "mul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -164,14 +165,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, - {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} + {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, + {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} ], "WriteTensors": [ - {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -196,14 +197,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, - {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} + {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, + {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} ], "WriteTensors": [ - {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -228,14 +229,14 @@ "Name": "add_1", "IsVirtual": false, "ReadTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "WriteTensors": [ - {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": {}, "Config": { diff --git a/examples/tutorial/plan_1_larger_tile.json b/examples/tutorial/plan_1_larger_tile.json index 3a3f66530..04d2e9d60 100644 --- a/examples/tutorial/plan_1_larger_tile.json +++ b/examples/tutorial/plan_1_larger_tile.json @@ -1,6 +1,7 @@ { "Rank": 0, "WorldSize": 1, + "Architecture": "CUDA_80", "NumProcessors": 108, "NumWarpsPerProcessor": 8, "TaskInfos": [ @@ -14,14 +15,14 @@ "Name": "matmul", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} + {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} ], "WriteTensors": [ - {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -46,13 +47,13 @@ "Name": "sigmoid", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":6,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":6,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -74,14 +75,14 @@ "Name": "mul", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":8,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":8,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -103,14 +104,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} + {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} ], "WriteTensors": [ - {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -135,14 +136,14 @@ "Name": "mul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -164,14 +165,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":2,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008]} + {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":2,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008]} ], "WriteTensors": [ - {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": { "TransposeInput": {"BOOL":false}, diff --git a/examples/tutorial/plan_2_split_k.json b/examples/tutorial/plan_2_split_k.json index 493515d8c..837944171 100644 --- a/examples/tutorial/plan_2_split_k.json +++ b/examples/tutorial/plan_2_split_k.json @@ -1,6 +1,7 @@ { "Rank": 0, "WorldSize": 1, + "Architecture": "CUDA_80", "NumProcessors": 108, "NumWarpsPerProcessor": 8, "TaskInfos": [ @@ -14,14 +15,14 @@ "Name": "matmul", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} + {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} ], "WriteTensors": [ - {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -46,13 +47,13 @@ "Name": "sigmoid", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":6,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":6,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -74,14 +75,14 @@ "Name": "mul", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":7,"DataType":"FP16","Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":8,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":8,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -103,14 +104,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} + {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} ], "WriteTensors": [ - {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -135,14 +136,14 @@ "Name": "mul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":9,"DataType":"FP16","Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -164,14 +165,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, - {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} + {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, + {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} ], "WriteTensors": [ - {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -196,14 +197,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, - {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} + {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, + {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} ], "WriteTensors": [ - {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -228,14 +229,14 @@ "Name": "add_1", "IsVirtual": false, "ReadTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "WriteTensors": [ - {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": {}, "Config": { diff --git a/examples/tutorial/plan_3_overwrite.json b/examples/tutorial/plan_3_overwrite.json index c0854e505..335c27549 100644 --- a/examples/tutorial/plan_3_overwrite.json +++ b/examples/tutorial/plan_3_overwrite.json @@ -1,6 +1,7 @@ { "Rank": 0, "WorldSize": 1, + "Architecture": "CUDA_80", "NumProcessors": 108, "NumWarpsPerProcessor": 8, "TaskInfos": [ @@ -14,14 +15,14 @@ "Name": "matmul", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} + {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":1,"DataType":"FP16","Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} ], "WriteTensors": [ - {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":4,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -46,13 +47,13 @@ "Name": "sigmoid", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":6,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":6,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -74,14 +75,14 @@ "Name": "mul", "IsVirtual": false, "ReadTensors": [ - {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":5,"DataType":"FP16","Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":7,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":8,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":8,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -103,14 +104,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} + {"Id":0,"DataType":"FP16","Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":3,"DataType":"FP16","Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096]} ], "WriteTensors": [ - {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":10,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -135,14 +136,14 @@ "Name": "mul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, - {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":9,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]}, + {"Id":11,"DataType":"FP16","Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "WriteTensors": [ - {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":12,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "ResultTensors": [ - {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} + {"Id":13,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,11008],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,11008]} ], "Args": {}, "Config": { @@ -164,14 +165,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, - {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} + {"Id":16,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,8320],"Strides":[1,512,11008],"Offsets":[0,0,0],"PaddedShape":[1,512,8320]}, + {"Id":17,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[4096,8320],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,8320]} ], "WriteTensors": [ - {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":14,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -196,14 +197,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, - {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} + {"Id":18,"DataType":"FP16","Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,2688],"Strides":[1,512,11008],"Offsets":[0,0,8320],"PaddedShape":[1,512,2688]}, + {"Id":19,"DataType":"FP16","Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[4096,2688],"Strides":[4096,11008],"Offsets":[0,8320],"PaddedShape":[4096,2688]} ], "WriteTensors": [ - {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":20,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -228,14 +229,14 @@ "Name": "add_1", "IsVirtual": false, "ReadTensors": [ - {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, - {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":22,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]}, + {"Id":21,"DataType":"FP16","Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "WriteTensors": [ - {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":23,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "ResultTensors": [ - {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} + {"Id":15,"DataType":"FP16","Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false},"Shape":[1,512,4096],"Strides":[1,512,4096],"Offsets":[0,0,0],"PaddedShape":[1,512,4096]} ], "Args": {}, "Config": { diff --git a/examples/tutorial/plan_tutorial.py b/examples/tutorial/plan_tutorial.py index 056523e15..989f29c5e 100644 --- a/examples/tutorial/plan_tutorial.py +++ b/examples/tutorial/plan_tutorial.py @@ -339,7 +339,7 @@ def main(plan_path: str): plan = planner.plan() with open("default_plan.json", "w") as f: - f.write(plan) + f.write(str(plan)) rt.launch(plan=plan) # Initialize @@ -364,7 +364,7 @@ def main(plan_path: str): print(f"File {plan_path} does not exist. Exiting...") return with ark.Runtime.get_runtime() as rt: - rt.launch(plan_path=plan_path) + rt.launch(plan=ark.Plan.from_file(plan_path)) # Initialize InputModule.initialize() diff --git a/plan_gpu0.json b/plan_gpu0.json index 49b6bdd98..63c1943e3 100644 --- a/plan_gpu0.json +++ b/plan_gpu0.json @@ -1,6 +1,7 @@ { "Rank": 0, "WorldSize": 1, + "Architecture": "ROCM_942", "NumProcessors": 304, "NumWarpsPerProcessor": 4, "TaskInfos": [ @@ -14,14 +15,14 @@ "Name": "matmul", "IsVirtual": false, "ReadTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":0,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":0,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":6,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":6,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":7,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":7,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -46,14 +47,14 @@ "Name": "rope", "IsVirtual": false, "ReadTensors": [ - {"Id":12,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":5,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":12,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":5,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":15,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":15,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":16,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":16,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -75,13 +76,13 @@ "Name": "transpose", "IsVirtual": false, "ReadTensors": [ - {"Id":16,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":16,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":19,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":19,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Permutation": {"DIMS":[0,2,1,3]} @@ -105,14 +106,14 @@ "Name": "matmul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":1,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":1,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":8,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":8,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":9,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -137,14 +138,14 @@ "Name": "rope_1", "IsVirtual": false, "ReadTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":5,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":13,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":5,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":17,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":17,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -166,13 +167,13 @@ "Name": "transpose_2", "IsVirtual": false, "ReadTensors": [ - {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":23,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":23,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":24,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":24,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Permutation": {"DIMS":[0,2,3,1]} @@ -196,14 +197,14 @@ "Name": "matmul_2", "IsVirtual": false, "ReadTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":2,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":2,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":10,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":10,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":11,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":11,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -228,13 +229,13 @@ "Name": "transpose_1", "IsVirtual": false, "ReadTensors": [ - {"Id":14,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":14,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":21,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":21,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":22,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":22,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Permutation": {"DIMS":[0,2,1,3]} @@ -258,14 +259,14 @@ "Name": "matmul_3", "IsVirtual": false, "ReadTensors": [ - {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":24,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":24,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":25,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":25,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":26,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":26,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -290,13 +291,13 @@ "Name": "mul", "IsVirtual": false, "ReadTensors": [ - {"Id":26,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":26,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":27,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":27,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Factor": {"FLOAT":0.0883883461356163} @@ -320,13 +321,13 @@ "Name": "reduce_max", "IsVirtual": false, "ReadTensors": [ - {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":29,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":29,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":30,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":30,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Axis": {"INT":3}, @@ -351,14 +352,14 @@ "Name": "sub", "IsVirtual": false, "ReadTensors": [ - {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":30,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":30,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -380,13 +381,13 @@ "Name": "exp", "IsVirtual": false, "ReadTensors": [ - {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -408,13 +409,13 @@ "Name": "reduce_sum", "IsVirtual": false, "ReadTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":33,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":33,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":34,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":34,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Axis": {"INT":3}, @@ -439,14 +440,14 @@ "Name": "div", "IsVirtual": false, "ReadTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":34,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":34,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":35,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":35,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -468,14 +469,14 @@ "Name": "matmul_4", "IsVirtual": false, "ReadTensors": [ - {"Id":35,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":22,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":35,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":22,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":36,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":36,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -500,13 +501,13 @@ "Name": "transpose_3", "IsVirtual": false, "ReadTensors": [ - {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":38,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":38,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":39,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":39,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Permutation": {"DIMS":[0,2,1,3]} @@ -530,14 +531,14 @@ "Name": "matmul_5", "IsVirtual": false, "ReadTensors": [ - {"Id":40,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":3,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":40,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":3,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":41,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":41,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":42,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":42,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -562,13 +563,13 @@ "Name": "cast", "IsVirtual": false, "ReadTensors": [ - {"Id":52,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":52,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":54,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":54,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -590,14 +591,14 @@ "Name": "mul_1", "IsVirtual": false, "ReadTensors": [ - {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":56,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":56,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":57,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":57,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -619,13 +620,13 @@ "Name": "reduce_mean", "IsVirtual": false, "ReadTensors": [ - {"Id":57,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":57,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":58,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":58,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":59,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":59,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Axis": {"INT":2}, @@ -650,13 +651,13 @@ "Name": "rsqrt", "IsVirtual": false, "ReadTensors": [ - {"Id":59,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":59,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":60,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":60,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":61,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":61,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -678,14 +679,14 @@ "Name": "mul_2", "IsVirtual": false, "ReadTensors": [ - {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":61,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":61,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":62,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":62,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -707,14 +708,14 @@ "Name": "mul_3", "IsVirtual": false, "ReadTensors": [ - {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":50,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":28,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":50,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":28,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -736,13 +737,13 @@ "Name": "cast_1", "IsVirtual": false, "ReadTensors": [ - {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":65,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":65,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -764,14 +765,14 @@ "Name": "matmul_6", "IsVirtual": false, "ReadTensors": [ - {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":43,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":43,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":67,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":67,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":68,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":68,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -796,14 +797,14 @@ "Name": "rope_2", "IsVirtual": false, "ReadTensors": [ - {"Id":73,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":53,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":73,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":53,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":76,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":76,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":77,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":77,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -825,13 +826,13 @@ "Name": "transpose_4", "IsVirtual": false, "ReadTensors": [ - {"Id":77,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":77,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":80,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":80,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":81,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":81,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Permutation": {"DIMS":[0,2,1,3]} @@ -855,14 +856,14 @@ "Name": "matmul_7", "IsVirtual": false, "ReadTensors": [ - {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":44,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":44,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":69,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":69,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":70,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":70,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -887,14 +888,14 @@ "Name": "rope_3", "IsVirtual": false, "ReadTensors": [ - {"Id":74,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":53,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":74,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":53,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":78,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":78,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":79,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":79,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -916,13 +917,13 @@ "Name": "transpose_6", "IsVirtual": false, "ReadTensors": [ - {"Id":79,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":79,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":84,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":84,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":85,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":85,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Permutation": {"DIMS":[0,2,3,1]} @@ -946,14 +947,14 @@ "Name": "matmul_8", "IsVirtual": false, "ReadTensors": [ - {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":45,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":45,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":71,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":71,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":72,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":72,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -978,13 +979,13 @@ "Name": "transpose_5", "IsVirtual": false, "ReadTensors": [ - {"Id":75,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":75,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":82,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":82,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":83,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":83,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Permutation": {"DIMS":[0,2,1,3]} @@ -1008,14 +1009,14 @@ "Name": "matmul_9", "IsVirtual": false, "ReadTensors": [ - {"Id":81,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":85,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":81,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":85,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":86,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":86,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":87,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":87,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -1040,13 +1041,13 @@ "Name": "mul_4", "IsVirtual": false, "ReadTensors": [ - {"Id":87,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":87,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":88,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":88,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Factor": {"FLOAT":0.0883883461356163} @@ -1070,13 +1071,13 @@ "Name": "reduce_max_1", "IsVirtual": false, "ReadTensors": [ - {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":90,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":90,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":91,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":91,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Axis": {"INT":3}, @@ -1101,14 +1102,14 @@ "Name": "sub_1", "IsVirtual": false, "ReadTensors": [ - {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":91,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":91,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1130,13 +1131,13 @@ "Name": "exp_1", "IsVirtual": false, "ReadTensors": [ - {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1158,13 +1159,13 @@ "Name": "reduce_sum_1", "IsVirtual": false, "ReadTensors": [ - {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":94,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":94,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":95,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":95,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Axis": {"INT":3}, @@ -1189,14 +1190,14 @@ "Name": "div_1", "IsVirtual": false, "ReadTensors": [ - {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":95,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":95,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":96,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":96,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1218,14 +1219,14 @@ "Name": "matmul_10", "IsVirtual": false, "ReadTensors": [ - {"Id":96,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":83,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":96,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":83,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":97,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":97,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":98,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":98,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -1250,13 +1251,13 @@ "Name": "transpose_7", "IsVirtual": false, "ReadTensors": [ - {"Id":98,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":98,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":99,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":99,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":100,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":100,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Permutation": {"DIMS":[0,2,1,3]} @@ -1280,14 +1281,14 @@ "Name": "matmul_11", "IsVirtual": false, "ReadTensors": [ - {"Id":101,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":46,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":101,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":46,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":102,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":102,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":103,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":103,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -1312,14 +1313,14 @@ "Name": "add", "IsVirtual": false, "ReadTensors": [ - {"Id":52,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":103,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":52,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":103,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":104,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":104,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1341,13 +1342,13 @@ "Name": "cast_2", "IsVirtual": false, "ReadTensors": [ - {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":106,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":106,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1369,14 +1370,14 @@ "Name": "mul_5", "IsVirtual": false, "ReadTensors": [ - {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":108,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":108,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":109,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":109,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1398,13 +1399,13 @@ "Name": "reduce_mean_1", "IsVirtual": false, "ReadTensors": [ - {"Id":109,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":109,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":110,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":110,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":111,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":111,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "Axis": {"INT":2}, @@ -1429,13 +1430,13 @@ "Name": "rsqrt_1", "IsVirtual": false, "ReadTensors": [ - {"Id":111,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":111,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":112,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":112,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":113,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":113,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1457,14 +1458,14 @@ "Name": "mul_6", "IsVirtual": false, "ReadTensors": [ - {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":113,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":113,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":114,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":114,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1486,14 +1487,14 @@ "Name": "mul_7", "IsVirtual": false, "ReadTensors": [ - {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":51,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":29,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":51,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":29,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":116,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":116,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1515,13 +1516,13 @@ "Name": "cast_3", "IsVirtual": false, "ReadTensors": [ - {"Id":116,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":116,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":117,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":117,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1543,14 +1544,14 @@ "Name": "matmul_12", "IsVirtual": false, "ReadTensors": [ - {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":47,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":25,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":47,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":25,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":119,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":119,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -1575,13 +1576,13 @@ "Name": "sigmoid", "IsVirtual": false, "ReadTensors": [ - {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":121,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":121,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":122,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":122,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1603,14 +1604,14 @@ "Name": "mul_8", "IsVirtual": false, "ReadTensors": [ - {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":122,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":122,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":123,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":123,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":124,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":124,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1632,14 +1633,14 @@ "Name": "matmul_13", "IsVirtual": false, "ReadTensors": [ - {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":49,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":27,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":49,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":27,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":125,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":125,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":126,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":126,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -1664,14 +1665,14 @@ "Name": "mul_9", "IsVirtual": false, "ReadTensors": [ - {"Id":124,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":126,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":124,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":126,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":127,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":127,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":128,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":128,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { @@ -1693,14 +1694,14 @@ "Name": "matmul_14", "IsVirtual": false, "ReadTensors": [ - {"Id":128,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":48,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":128,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":48,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":129,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":129,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":130,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":130,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { "TransposeInput": {"BOOL":false}, @@ -1725,14 +1726,14 @@ "Name": "add_1", "IsVirtual": false, "ReadTensors": [ - {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[]}}, - {"Id":130,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":130,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":131,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":66,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":131,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":66,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":132,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":66,"Rank":-1,"SendTags":[],"RecvTags":[]}} + {"Id":132,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":66,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": {}, "Config": { diff --git a/python/ark/__init__.py b/python/ark/__init__.py index f2f604be9..e96972906 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -37,7 +37,7 @@ def set_world_size(world_size): from .init import init from .tensor import Dims, Tensor, Parameter from .module import Module, RuntimeModule -from .runtime import Runtime, DefaultPlanner +from .runtime import Runtime from .serialize import save, load from .data_type import ( DataType, @@ -100,4 +100,5 @@ def set_world_size(world_size): GpuError, RuntimeError, ) +from .planner import DefaultPlanner, Plan from .profiler import Profiler diff --git a/python/ark/planner.py b/python/ark/planner.py new file mode 100644 index 000000000..8814896d2 --- /dev/null +++ b/python/ark/planner.py @@ -0,0 +1,184 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import copy +import json +from typing import Callable, Dict, List, Any + +from ._ark_core import _DefaultPlanner +from .model import Model + + +def idnt(indent): + return " " * indent + + +def dquote(s): + return '"' + s + '"' + + +def denser_json_obj(obj, key, level, indent, indent_step, ret=""): + if len(obj) == 0: + if key: + return ret + idnt(indent) + dquote(key) + ": {}" + else: + return ret + idnt(indent) + "{}" + ret += idnt(indent) + if key: + ret += dquote(key) + ": {\n" + else: + ret += "{\n" + num_item = len(obj) + for k, v in obj.items(): + is_obj_or_arr = isinstance(v, dict) or isinstance(v, list) + is_num_arr = isinstance(v, list) and v and isinstance(v[0], int) + if level <= 0 or not is_obj_or_arr or is_num_arr: + ret += ( + idnt(indent + indent_step) + + dquote(k) + + ": " + + json.dumps(v, separators=(",", ":")) + ) + elif isinstance(v, dict): + ret += denser_json_obj( + v, k, level - 1, indent + indent_step, indent_step + ) + elif isinstance(v, list): + ret += denser_json_arr( + v, k, level - 1, indent + indent_step, indent_step + ) + num_item -= 1 + if num_item > 0: + ret += ",\n" + else: + ret += "\n" + ret += idnt(indent) + "}" + return ret + + +def denser_json_arr(obj, key, level, indent, indent_step, ret=""): + if len(obj) == 0: + if key: + return ret + idnt(indent) + dquote(key) + ": []" + else: + return ret + idnt(indent) + "[]" + ret += idnt(indent) + if key: + ret += dquote(key) + ": [\n" + else: + ret += "[\n" + num_item = len(obj) + for v in obj: + is_obj_or_arr = isinstance(v, dict) or isinstance(v, list) + is_num_arr = ( + isinstance(v, list) + and v + and (isinstance(v[0], int) or isinstance(v[0], float)) + ) + if level <= 0 or not is_obj_or_arr or is_num_arr: + ret += idnt(indent + indent_step) + json.dumps( + v, separators=(",", ":") + ) + elif isinstance(v, dict): + ret += denser_json_obj( + v, "", level - 1, indent + indent_step, indent_step + ) + elif isinstance(v, list): + ret += denser_json_arr( + v, "", level - 1, indent + indent_step, indent_step + ) + num_item -= 1 + if num_item > 0: + ret += ",\n" + else: + ret += "\n" + ret += idnt(indent) + "]" + return ret + + +def denser_json(obj, level, indent_step=2): + if isinstance(obj, dict): + return denser_json_obj(obj, "", level, 0, indent_step, "") + elif isinstance(obj, list): + return denser_json_arr(obj, "", level, 0, indent_step, "") + return json.dumps(obj, indent=indent_step) + + +class Plan: + def __init__(self, plan: Dict[str, Any]): + if plan is None: + plan = {} + plan["Rank"] = 0 + plan["WorldSize"] = 1 + plan["Architecture"] = "ANY" + plan["NumProcessors"] = 1 + plan["NumWarpsPerProcessor"] = 1 + plan["TaskInfos"] = [] + plan["ProcessorGroups"] = [] + else: + plan = copy.deepcopy(plan) + self.plan = plan + + def __str__(self) -> str: + return denser_json(self.plan, 5) + + @property + def rank(self) -> int: + return self.plan["Rank"] + + @property + def world_size(self) -> int: + return self.plan["WorldSize"] + + @property + def architecture(self) -> str: + return self.plan["Architecture"] + + @property + def num_processors(self) -> int: + return self.plan["NumProcessors"] + + @property + def num_warps_per_processor(self) -> int: + return self.plan["NumWarpsPerProcessor"] + + @property + def task_infos(self) -> List[Dict[str, Any]]: + return self.plan["TaskInfos"] + + @property + def processor_groups(self) -> List[Dict[str, Any]]: + return self.plan["ProcessorGroups"] + + @staticmethod + def from_str(plan_str: str) -> "Plan": + plan = json.loads(plan_str) + return Plan(plan) + + @staticmethod + def from_file(file_path: str) -> "Plan": + with open(file_path, "r") as f: + plan = json.load(f) + return Plan(plan) + + +class DefaultPlanner(_DefaultPlanner): + def __init__(self, device_id: int = 0): + compressed = Model.get_model().compress() + super().__init__(compressed, device_id) + + def install_config_rule(self, rule: Callable[[str, str], str]): + """ + Install a configuration rule. + + Args: + rule: A function that takes an operator description and a target + architecture name and returns a configuration description. + """ + super().install_config_rule(rule) + + def plan(self) -> Plan: + """ + Generate an execution plan. + """ + return Plan.from_str(super().plan(pretty=False)) diff --git a/python/ark/profiler.py b/python/ark/profiler.py index b959ceb18..feb78e0de 100644 --- a/python/ark/profiler.py +++ b/python/ark/profiler.py @@ -1,30 +1,36 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import json import sys import time + from .runtime import Runtime +from .planner import Plan class Profiler: - def __init__(self, plan: str): - self.plan = json.loads(plan) + def __init__(self, plan: Plan): + self.plan = plan def run(self): - num_processor_groups = len(self.plan["ProcessorGroups"]) + num_processor_groups = len(self.plan.processor_groups) new_plan = { - "Rank": self.plan["Rank"], "WorldSize": self.plan["WorldSize"], - "NumProcessors": self.plan["NumProcessors"], - "NumWarpsPerProcessor": self.plan["NumWarpsPerProcessor"], - "TaskInfos": self.plan["TaskInfos"], - "ProcessorGroups": [{}]} + "Rank": self.plan.rank, + "WorldSize": self.plan.world_size, + "Architecture": self.plan.architecture, + "NumProcessors": self.plan.num_processors, + "NumWarpsPerProcessor": self.plan.num_warps_per_processor, + "TaskInfos": self.plan.task_infos, + "ProcessorGroups": [None], + } for i in range(num_processor_groups): - new_plan["ProcessorGroups"][0] = self.plan["ProcessorGroups"][i] + new_plan["ProcessorGroups"][0] = self.plan.processor_groups[i] with Runtime() as rt: - rt.launch(plan=json.dumps(new_plan)) + rt.launch(plan=str(new_plan)) start_time = time.time() iter = 1000 rt.run(iter=iter) end_time = time.time() - sys.stderr.write(f"Processor group {i} runtime: {(end_time - start_time)/iter:.6f} seconds/iter\n") + sys.stderr.write( + f"Processor group {i} runtime: {(end_time - start_time)/iter:.6f} seconds/iter\n" + ) diff --git a/python/ark/runtime.py b/python/ark/runtime.py index efae6ab3c..40bfaaa63 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -3,10 +3,11 @@ import logging from enum import Enum -from typing import Callable, Dict, List +from typing import Dict, List -from ._ark_core import _Executor, _DefaultPlanner +from ._ark_core import _Executor from .model import Model +from .planner import DefaultPlanner, Plan class _RuntimeState: @@ -46,33 +47,9 @@ def print_runtime_states(): print(f"{runtime_id:<12} | {runtime.state:<20}") -class DefaultPlanner(_DefaultPlanner): - def __init__(self, gpu_id: int = 0): - compressed = Model.get_model().compress() - super().__init__(compressed, gpu_id) - - def install_config_rule(self, rule: Callable[[str, str], str]): - """ - Install a configuration rule. - - Args: - rule: A function that takes an operator description and a target - architecture name and returns a configuration description. - """ - super().install_config_rule(rule) - - def plan(self, pretty: bool = True) -> str: - """ - Generate an execution plan. - - Args: - pretty: Whether to generate a pretty plan. - """ - return super().plan(pretty) - - class Executor(_Executor): - pass + def __init__(self, plan: Plan, device_id: int, name: str): + super().__init__(plan.rank, plan.world_size, device_id, name, str(plan)) class Runtime: @@ -155,11 +132,8 @@ def running(self) -> bool: def launch( self, - rank: int = 0, - world_size: int = 1, - gpu_id: int = 0, - plan: str = "", - plan_path: str = "", + plan: Plan = None, + device_id: int = 0, ): """ Create an executor and schedule the ARK model. The scheduler will generate @@ -172,11 +146,7 @@ def launch( ) return if not plan: - if not plan_path: - plan = DefaultPlanner(gpu_id).plan() - else: - with open(plan_path, "r") as f: - plan = f.read() + plan = DefaultPlanner(device_id).plan() # If the RuntimeState is init, we need to create a new executor and # compile the kernels if self.state == Runtime.State.Init: @@ -187,11 +157,9 @@ def launch( ) self.executor.destroy() self.executor = Executor( - rank, - world_size, - gpu_id, - "ArkRuntime", plan, + device_id, + "ArkRuntime", ) self.executor.compile() self.executor.launch() diff --git a/python/unittest/test_runtime.py b/python/unittest/test_runtime.py index fd34bb96b..b075c64ea 100644 --- a/python/unittest/test_runtime.py +++ b/python/unittest/test_runtime.py @@ -2,18 +2,9 @@ # Licensed under the MIT license. import ark -import json -empty_plan = json.dumps( - { - "Rank": 0, - "WorldSize": 1, - "NumProcessors": 1, - "NumWarpsPerProcessor": 1, - "TaskInfos": [], - "ProcessorGroups": [], - } -) + +empty_plan = ark.Plan(None) def test_runtime_relaunch(): @@ -35,7 +26,7 @@ def test_multiple_runtime_launch(): for i in range(num_runtimes): rt = ark.Runtime.get_runtime(i) assert rt.launched() == False - rt.launch(gpu_id=i, plan=empty_plan) + rt.launch(plan=empty_plan, device_id=i) assert rt.launched() == True for i in range(num_runtimes): rt = ark.Runtime.get_runtime(i) @@ -46,9 +37,9 @@ def test_multiple_runtime_launch(): def test_stop_runtime(): ark.init() rt1 = ark.Runtime.get_runtime(1) - rt1.launch(plan=empty_plan, gpu_id=1) + rt1.launch(plan=empty_plan, device_id=1) rt2 = ark.Runtime.get_runtime(2) - rt2.launch(plan=empty_plan, gpu_id=2) + rt2.launch(plan=empty_plan, device_id=2) rt1.stop() rt1.reset() assert rt1.state == ark.Runtime.State.Init @@ -59,9 +50,9 @@ def test_stop_runtime(): def test_reset_runtime(): ark.init() rt1 = ark.Runtime.get_runtime(0) - rt1.launch(plan=empty_plan, gpu_id=1) + rt1.launch(plan=empty_plan, device_id=1) rt2 = ark.Runtime.get_runtime(1) - rt2.launch(plan=empty_plan, gpu_id=2) + rt2.launch(plan=empty_plan, device_id=2) rt1.reset() assert rt1.launched() == False assert rt2.launched() == True @@ -77,7 +68,7 @@ def test_multiple_runtimes_complex(): default_runtime = ark.Runtime.get_runtime() runtime_list.append(default_runtime) for i, rt in enumerate(runtime_list): - rt.launch(plan=empty_plan, gpu_id=i) + rt.launch(plan=empty_plan, device_id=i) assert rt.launched() == True runtime_list[0].stop() assert runtime_list[0].state == ark.Runtime.State.LaunchedNotRunning @@ -87,7 +78,7 @@ def test_multiple_runtimes_complex(): assert runtime_list[1].state == ark.Runtime.State.Init assert runtime_list[0].state == ark.Runtime.State.LaunchedNotRunning assert runtime_list[2].state == ark.Runtime.State.LaunchedNotRunning - runtime_list[1].launch(plan=empty_plan, gpu_id=1) + runtime_list[1].launch(plan=empty_plan, device_id=1) for rt in runtime_list: assert rt.launched() == True ark.Runtime.delete_all_runtimes() From 0cb10b92c601306d537eb3de6259cf73e59b33df Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 3 Jul 2024 07:58:34 +0000 Subject: [PATCH 22/54] fix a reduction perf bug --- ark/include/kernels/reduce.h | 18 +++++++++--------- plan_gpu0.json | 36 ++++++++++++++++++------------------ python/ark/profiler.py | 24 +++++++++++++++--------- 3 files changed, 42 insertions(+), 36 deletions(-) diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index 30c8b7831..3d0b4e008 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -53,7 +53,7 @@ DEVICE bf16 warpReduce(bf16 val) { template DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) { val = warpReduce(val); - if (LanesNum > Arch::ThreadsPerWarp) { + if constexpr (LanesNum > Arch::ThreadsPerWarp) { ReduceSharedStorage *shared = UnitOp::template shared_memory>( smem_per_warp); @@ -351,8 +351,8 @@ struct WwiseReduce { /// @param in Input tensor. /// @param uop_idx Index of the unit operator. template - static DEVICE void runW(DataType *out, DataType *in, int uop_idx, - int smem_per_warp) { + static DEVICE void run(DataType *out, DataType *in, int uop_idx, + int smem_per_warp) { using ShapeChecker = ReduceShapeChecker; constexpr int NelemPerThread = @@ -450,8 +450,8 @@ template ::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeSum, Axis>::run(out, in, uop_idx, + smem_per_warp); } template ::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMean, Axis>::run(out, in, uop_idx, + smem_per_warp); } template ::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMax, Axis>::run(out, in, uop_idx, + smem_per_warp); } } // namespace ark diff --git a/plan_gpu0.json b/plan_gpu0.json index 63c1943e3..99e2da8fa 100644 --- a/plan_gpu0.json +++ b/plan_gpu0.json @@ -314,7 +314,7 @@ { "Id": 10, "NumWarps": 1, - "SramBytes": 256, + "SramBytes": 0, "Ops": [ { "Type": "ReduceMax", @@ -336,7 +336,7 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 65536 } } @@ -402,7 +402,7 @@ { "Id": 13, "NumWarps": 1, - "SramBytes": 256, + "SramBytes": 0, "Ops": [ { "Type": "ReduceSum", @@ -424,7 +424,7 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 65536 } } @@ -613,7 +613,7 @@ { "Id": 20, "NumWarps": 1, - "SramBytes": 256, + "SramBytes": 0, "Ops": [ { "Type": "ReduceMean", @@ -635,7 +635,7 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 2048 } } @@ -1064,7 +1064,7 @@ { "Id": 35, "NumWarps": 1, - "SramBytes": 256, + "SramBytes": 0, "Ops": [ { "Type": "ReduceMax", @@ -1086,7 +1086,7 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 65536 } } @@ -1152,7 +1152,7 @@ { "Id": 38, "NumWarps": 1, - "SramBytes": 256, + "SramBytes": 0, "Ops": [ { "Type": "ReduceSum", @@ -1174,7 +1174,7 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 65536 } } @@ -1392,7 +1392,7 @@ { "Id": 46, "NumWarps": 1, - "SramBytes": 256, + "SramBytes": 0, "Ops": [ { "Type": "ReduceMean", @@ -1414,7 +1414,7 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 2048 } } @@ -1883,7 +1883,7 @@ { "ProcessorRange": [0,304], "WarpRange": [0,4], - "SramRange": [0,256], + "SramRange": [0,0], "TaskGroups": [ {"TaskId":10,"TaskRange":[0,65536],"Granularity":1} ] @@ -1922,7 +1922,7 @@ { "ProcessorRange": [0,304], "WarpRange": [0,4], - "SramRange": [0,256], + "SramRange": [0,0], "TaskGroups": [ {"TaskId":13,"TaskRange":[0,65536],"Granularity":1} ] @@ -2013,7 +2013,7 @@ { "ProcessorRange": [0,304], "WarpRange": [0,4], - "SramRange": [0,256], + "SramRange": [0,0], "TaskGroups": [ {"TaskId":20,"TaskRange":[0,2048],"Granularity":1} ] @@ -2208,7 +2208,7 @@ { "ProcessorRange": [0,304], "WarpRange": [0,4], - "SramRange": [0,256], + "SramRange": [0,0], "TaskGroups": [ {"TaskId":35,"TaskRange":[0,65536],"Granularity":1} ] @@ -2247,7 +2247,7 @@ { "ProcessorRange": [0,304], "WarpRange": [0,4], - "SramRange": [0,256], + "SramRange": [0,0], "TaskGroups": [ {"TaskId":38,"TaskRange":[0,65536],"Granularity":1} ] @@ -2351,7 +2351,7 @@ { "ProcessorRange": [0,304], "WarpRange": [0,4], - "SramRange": [0,256], + "SramRange": [0,0], "TaskGroups": [ {"TaskId":46,"TaskRange":[0,2048],"Granularity":1} ] diff --git a/python/ark/profiler.py b/python/ark/profiler.py index feb78e0de..529a0d506 100644 --- a/python/ark/profiler.py +++ b/python/ark/profiler.py @@ -8,11 +8,22 @@ from .planner import Plan +def timeit(plan: Plan): + with Runtime() as rt: + rt.launch(plan=plan) + start_time = time.time() + iter = 1000 + rt.run(iter=iter) + end_time = time.time() + return (end_time - start_time) / iter + + class Profiler: def __init__(self, plan: Plan): self.plan = plan def run(self): + sys.stderr.write(f"End-to-end: {timeit(self.plan):.6f} seconds/iter\n") num_processor_groups = len(self.plan.processor_groups) new_plan = { "Rank": self.plan.rank, @@ -25,12 +36,7 @@ def run(self): } for i in range(num_processor_groups): new_plan["ProcessorGroups"][0] = self.plan.processor_groups[i] - with Runtime() as rt: - rt.launch(plan=str(new_plan)) - start_time = time.time() - iter = 1000 - rt.run(iter=iter) - end_time = time.time() - sys.stderr.write( - f"Processor group {i} runtime: {(end_time - start_time)/iter:.6f} seconds/iter\n" - ) + lat_per_iter = timeit(Plan(new_plan)) + sys.stderr.write( + f"Processor group {i}: {lat_per_iter:.6f} seconds/iter\n" + ) From 0fde9c5dc486ba1edb20235115575d360558ece9 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 4 Jul 2024 07:17:32 +0000 Subject: [PATCH 23/54] optimize --- ark/include/kernels/common/sync.h | 12 +-- ark/ops/ops_broadcast.cpp | 4 +- examples/llama/model_test.py | 2 +- plan_gpu0.json | 172 ++++++++---------------------- 4 files changed, 51 insertions(+), 139 deletions(-) diff --git a/ark/include/kernels/common/sync.h b/ark/include/kernels/common/sync.h index 85f7639c9..f47625600 100644 --- a/ark/include/kernels/common/sync.h +++ b/ark/include/kernels/common/sync.h @@ -106,25 +106,19 @@ DEVICE void sync_warps() { static_assert(Arch::ThreadsPerWarp == 64, ""); if constexpr (NumWarps == 1) { __builtin_amdgcn_wave_barrier(); - } else if constexpr (NumWarps == 16) { - __syncthreads(); } else { static_assert(ARK_SMEM_RESERVED_BYTES >= sizeof(sync::WarpGroupState), ""); - int lane_id = threadIdx.x & 63; - if (lane_id == 0) { + if ((threadIdx.x & 63) == 0) { constexpr int MaxOldCnt = NumWarps - 1; - int warp_id = threadIdx.x >> 6; - int group_id = warp_id / NumWarps; + int group_id = (threadIdx.x >> 6) / NumWarps; sync::WarpGroupState *state = reinterpret_cast(_ARK_SMEM); unsigned int tmp = state->is_inc_flag[group_id] ^ 1; if (atomicInc(&state->cnt[group_id], MaxOldCnt) == MaxOldCnt) { state->flag[group_id] = tmp; } else { - while (atomicAdd(&state->flag[group_id], 0) != tmp) - __builtin_amdgcn_s_sleep(1); - __asm__ __volatile__("s_wakeup"); + while (atomicAdd(&state->flag[group_id], 0) != tmp); } state->is_inc_flag[group_id] = tmp; } diff --git a/ark/ops/ops_broadcast.cpp b/ark/ops/ops_broadcast.cpp index 3985a0500..f20e8c4dc 100644 --- a/ark/ops/ops_broadcast.cpp +++ b/ark/ops/ops_broadcast.cpp @@ -27,8 +27,8 @@ ModelOpBroadcast1::ModelOpBroadcast1(const std::string &type_name, std::string ModelOpBroadcast1::impl_name(const Json &config) const { check_fields_config(config, {"NumWarps", "Tile"}); int num_warps = config.at("NumWarps"); - auto &tile_shape = config.at("Tile"); - Dims unit_out_dims{tile_shape[0], tile_shape[1]}; + const auto& tile_shape = config.at("Tile").get>(); + Dims unit_out_dims(tile_shape); return function_name_string( pascal_to_snake(type()->type_name()), diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py index 71485be45..053015c04 100644 --- a/examples/llama/model_test.py +++ b/examples/llama/model_test.py @@ -473,7 +473,7 @@ def test_transformer_block( module_name_prefix="layers.0", rank=rank, world_size=world_size, - test_thru=True, + test_thru=False, ) diff --git a/plan_gpu0.json b/plan_gpu0.json index 99e2da8fa..cad05f774 100644 --- a/plan_gpu0.json +++ b/plan_gpu0.json @@ -31,7 +31,7 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 256 } } @@ -39,7 +39,7 @@ }, { "Id": 1, - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, "Ops": [ { @@ -58,17 +58,17 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [32,128], - "NumTasks": 2048 + "Tile": [256,1,128], + "NumTasks": 256 } } ] }, { "Id": 2, - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, "Ops": [ { @@ -88,10 +88,10 @@ "Permutation": {"DIMS":[0,2,1,3]} }, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [8,128], - "NumTasks": 8192 + "Tile": [256,128], + "NumTasks": 256 } } ] @@ -122,7 +122,7 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 256 } } @@ -130,7 +130,7 @@ }, { "Id": 4, - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, "Ops": [ { @@ -149,17 +149,17 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [32,128], - "NumTasks": 2048 + "Tile": [256,1,128], + "NumTasks": 256 } } ] }, { "Id": 5, - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, "Ops": [ { @@ -170,19 +170,19 @@ {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":23,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + {"Id":23,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":24,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + {"Id":24,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { - "Permutation": {"DIMS":[0,2,3,1]} + "Permutation": {"DIMS":[0,2,1,3]} }, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [8,128], - "NumTasks": 8192 + "Tile": [256,128], + "NumTasks": 256 } } ] @@ -213,7 +213,7 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 256 } } @@ -221,7 +221,7 @@ }, { "Id": 7, - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, "Ops": [ { @@ -241,10 +241,10 @@ "Permutation": {"DIMS":[0,2,1,3]} }, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [8,128], - "NumTasks": 8192 + "Tile": [256,128], + "NumTasks": 256 } } ] @@ -260,7 +260,7 @@ "IsVirtual": false, "ReadTensors": [ {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":24,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + {"Id":24,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ {"Id":25,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} @@ -270,12 +270,12 @@ ], "Args": { "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":false} + "TransposeOther": {"BOOL":true} }, "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 4096 } } @@ -305,7 +305,7 @@ "Config": { "NumWarps": 4, "SramBytes": 0, - "Tile": [128,256], + "Tile": [256,128], "NumTasks": 4096 } } @@ -1747,119 +1747,36 @@ } ], "ProcessorGroups": [ - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, { "ProcessorRange": [0,304], "ResourceGroups": [ { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":1,"TaskRange":[0,2048],"Granularity":4} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":2,"TaskRange":[0,8192],"Granularity":4} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], + "ProcessorRange": [0,86], "WarpRange": [0,4], "SramRange": [0,24672], "TaskGroups": [ - {"TaskId":3,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":4,"TaskRange":[0,2048],"Granularity":4} + {"TaskId":0,"TaskRange":[0,256],"Granularity":1}, + {"TaskId":1,"TaskRange":[0,256],"Granularity":1}, + {"TaskId":2,"TaskRange":[0,256],"Granularity":1} ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ + }, { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":5,"TaskRange":[0,8192],"Granularity":4} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], + "ProcessorRange": [86,172], "WarpRange": [0,4], "SramRange": [0,24672], "TaskGroups": [ - {"TaskId":6,"TaskRange":[0,256],"Granularity":1} + {"TaskId":3,"TaskRange":[0,256],"Granularity":1}, + {"TaskId":4,"TaskRange":[0,256],"Granularity":1}, + {"TaskId":5,"TaskRange":[0,256],"Granularity":1} ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":7,"TaskRange":[0,8192],"Granularity":4} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ + }, { - "ProcessorRange": [0,304], + "ProcessorRange": [172,258], "WarpRange": [0,4], "SramRange": [0,24672], "TaskGroups": [ - {"TaskId":8,"TaskRange":[0,4096],"Granularity":1} + {"TaskId":6,"TaskRange":[0,256],"Granularity":1}, + {"TaskId":7,"TaskRange":[0,256],"Granularity":1} ] } ] @@ -1870,8 +1787,9 @@ { "ProcessorRange": [0,304], "WarpRange": [0,4], - "SramRange": [0,0], + "SramRange": [0,24672], "TaskGroups": [ + {"TaskId":8,"TaskRange":[0,4096],"Granularity":1}, {"TaskId":9,"TaskRange":[0,4096],"Granularity":1} ] } From c4be6d1bf7b7fcacdd11dd3efad7b4170461ce41 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Fri, 5 Jul 2024 00:14:05 +0000 Subject: [PATCH 24/54] wip --- ark/codegen.cpp | 6 +- arkprof.py | 4 + examples/llama/model_test.py | 23 +- examples/llama/plan_llama2_7b_b1_s2048.json | 1723 +++++++++++++++++++ python/ark/profiler.py | 12 +- 5 files changed, 1751 insertions(+), 17 deletions(-) create mode 100644 arkprof.py create mode 100644 examples/llama/plan_llama2_7b_b1_s2048.json diff --git a/ark/codegen.cpp b/ark/codegen.cpp index 55327329a..587bcae59 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -298,10 +298,14 @@ std::string CodeGenerator::Impl::resource_group( size_t proc_b = *rg_proc_range.begin(); size_t proc_e = *rg_proc_range.end(); size_t proc_s = rg_proc_range.step(); + std::map task_infos_map; + for (auto &task_info : task_infos) { + task_infos_map[task_info.at("Id").get()] = task_info; + } std::stringstream ss; for (auto &tg : rg_json["TaskGroups"]) { size_t task_id = tg["TaskId"]; - auto &task_info = task_infos[task_id]; + auto &task_info = task_infos_map.at(task_id); Range task_range(tg["TaskRange"][0], tg["TaskRange"][1]); size_t task_gran = tg["Granularity"]; size_t num_warps_per_task = task_info["NumWarps"]; diff --git a/arkprof.py b/arkprof.py new file mode 100644 index 000000000..782bba560 --- /dev/null +++ b/arkprof.py @@ -0,0 +1,4 @@ +import ark +import sys + +ark.Profiler(ark.Plan.from_file(sys.argv[1])).run(iter=1000, profile_processor_groups=False) diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py index 053015c04..19c680854 100644 --- a/examples/llama/model_test.py +++ b/examples/llama/model_test.py @@ -59,7 +59,8 @@ def run_ark( output = module(*module_inputs) with ark.Runtime() as rt: - rt.launch(ark.Plan.from_file("/mnt/changhohwang/ark/plan_gpu0.json")) + plan = ark.Plan.from_file("plan_llama2_7b_b1_s2048.json") + rt.launch(plan) # Load model parameters if state_dict: @@ -438,22 +439,22 @@ def test_transformer_block( low=-1, high=1, size=(batch_size, seq_len, args.dim) ).astype(dtype) - module = model_ark.Attention( - args, ark.DataType.from_numpy(dtype), rank, world_size - ) + # module = model_ark.Attention( + # args, ark.DataType.from_numpy(dtype), rank, world_size + # ) # module_inputs = [ # ark.tensor(list(i.shape), ark.DataType.from_numpy(i.dtype)) # if isinstance(i, np.ndarray) # else i # for i in inputs # ] - feature_tensor = ark.tensor( - list(feature.shape), ark.DataType.from_numpy(feature.dtype) - ) - freqs_cis_ark_tensor = ark.tensor( - list(freqs_cis_ark.shape), ark.DataType.from_numpy(freqs_cis_ark.dtype) - ) - output = module(feature_tensor, 0, freqs_cis_ark_tensor, None) + # feature_tensor = ark.tensor( + # list(feature.shape), ark.DataType.from_numpy(feature.dtype) + # ) + # freqs_cis_ark_tensor = ark.tensor( + # list(freqs_cis_ark.shape), ark.DataType.from_numpy(freqs_cis_ark.dtype) + # ) + # output = module(feature_tensor, 0, freqs_cis_ark_tensor, None) # print(ark.Model.get_model().serialize()) diff --git a/examples/llama/plan_llama2_7b_b1_s2048.json b/examples/llama/plan_llama2_7b_b1_s2048.json new file mode 100644 index 000000000..d0e46d228 --- /dev/null +++ b/examples/llama/plan_llama2_7b_b1_s2048.json @@ -0,0 +1,1723 @@ +{ + "Rank": 0, + "WorldSize": 1, + "Architecture": "ROCM_942", + "NumProcessors": 304, + "NumWarpsPerProcessor": 4, + "TaskInfos": [ + { + "Id": 0, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Cast", + "Name": "cast", + "IsVirtual": false, + "ReadTensors": [ + {"Id":9,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":11,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":12,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 1, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul", + "IsVirtual": false, + "ReadTensors": [ + {"Id":12,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":12,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":13,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":14,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 2, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceMean", + "Name": "reduce_mean", + "IsVirtual": false, + "ReadTensors": [ + {"Id":14,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":15,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":16,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Axis": {"INT":2}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 2048 + } + } + ] + }, + { + "Id": 3, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rsqrt", + "Name": "rsqrt", + "IsVirtual": false, + "ReadTensors": [ + {"Id":16,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":17,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":18,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [64,1], + "NumTasks": 32 + } + } + ] + }, + { + "Id": 4, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":12,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":18,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":19,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":20,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 5, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":20,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":7,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":20,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":21,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 6, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Cast", + "Name": "cast_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":21,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":22,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":23,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 7, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul", + "IsVirtual": false, + "ReadTensors": [ + {"Id":23,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":0,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":24,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":25,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 8, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rope", + "Name": "rope", + "IsVirtual": false, + "ReadTensors": [ + {"Id":30,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":10,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":33,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":34,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 9, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose", + "IsVirtual": false, + "ReadTensors": [ + {"Id":34,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":38,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,8], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 10, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":23,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":1,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":26,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":27,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 11, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rope", + "Name": "rope_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":31,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":10,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":35,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":36,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 12, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":36,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":41,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":42,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,3,1]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,8], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 13, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":23,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":2,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":28,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":29,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 14, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":32,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":39,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":40,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,8], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 15, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":38,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":42,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":43,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":25,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":44,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":25,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":false} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 4096 + } + } + ] + }, + { + "Id": 16, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "ScalarMul", + "Name": "mul_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":44,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":25,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":45,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":46,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Factor": {"FLOAT":0.0883883461356163} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 2097152 + } + } + ] + }, + { + "Id": 17, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceMax", + "Name": "reduce_max", + "IsVirtual": false, + "ReadTensors": [ + {"Id":46,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":47,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":27,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":48,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":27,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Axis": {"INT":3}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 65536 + } + } + ] + }, + { + "Id": 18, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Sub", + "Name": "sub", + "IsVirtual": false, + "ReadTensors": [ + {"Id":46,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":48,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":27,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":46,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":49,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 2097152 + } + } + ] + }, + { + "Id": 19, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Exp", + "Name": "exp", + "IsVirtual": false, + "ReadTensors": [ + {"Id":49,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":49,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":50,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 2097152 + } + } + ] + }, + { + "Id": 20, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceSum", + "Name": "reduce_sum", + "IsVirtual": false, + "ReadTensors": [ + {"Id":50,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":51,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":28,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":52,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":28,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Axis": {"INT":3}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 65536 + } + } + ] + }, + { + "Id": 21, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Div", + "Name": "div", + "IsVirtual": false, + "ReadTensors": [ + {"Id":50,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":52,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":28,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":50,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":53,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 2097152 + } + } + ] + }, + { + "Id": 22, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_4", + "IsVirtual": false, + "ReadTensors": [ + {"Id":53,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":40,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":54,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":29,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":55,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":29,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":false} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [256,128,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 23, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Transpose", + "Name": "transpose_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":55,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":29,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":56,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":57,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Permutation": {"DIMS":[0,2,1,3]} + }, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [8,8], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 24, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_5", + "IsVirtual": false, + "ReadTensors": [ + {"Id":58,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":3,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":59,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":60,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 25, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Add", + "Name": "add", + "IsVirtual": false, + "ReadTensors": [ + {"Id":9,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":60,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":61,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":62,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 26, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Cast", + "Name": "cast_2", + "IsVirtual": false, + "ReadTensors": [ + {"Id":62,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 27, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_4", + "IsVirtual": false, + "ReadTensors": [ + {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":65,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":66,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 28, + "NumWarps": 1, + "SramBytes": 256, + "Ops": [ + { + "Type": "ReduceMean", + "Name": "reduce_mean_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":66,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":67,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":68,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "Axis": {"INT":2}, + "KeepDim": {"BOOL":true} + }, + "Config": { + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 256, + "NumTasks": 2048 + } + } + ] + }, + { + "Id": 29, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Rsqrt", + "Name": "rsqrt_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":68,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":69,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":70,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [64,1], + "NumTasks": 32 + } + } + ] + }, + { + "Id": 30, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_5", + "IsVirtual": false, + "ReadTensors": [ + {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":70,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":71,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":72,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 31, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_6", + "IsVirtual": false, + "ReadTensors": [ + {"Id":72,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":8,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":72,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":73,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 32, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Cast", + "Name": "cast_3", + "IsVirtual": false, + "ReadTensors": [ + {"Id":73,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":74,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":75,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + }, + { + "Id": 33, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_6", + "IsVirtual": false, + "ReadTensors": [ + {"Id":75,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":4,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":76,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":77,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 688 + } + } + ] + }, + { + "Id": 34, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Sigmoid", + "Name": "sigmoid", + "IsVirtual": false, + "ReadTensors": [ + {"Id":77,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":78,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":79,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 352256 + } + } + ] + }, + { + "Id": 35, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_7", + "IsVirtual": false, + "ReadTensors": [ + {"Id":77,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":79,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":80,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":81,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 352256 + } + } + ] + }, + { + "Id": 36, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_7", + "IsVirtual": false, + "ReadTensors": [ + {"Id":75,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":6,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":82,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":83,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 688 + } + } + ] + }, + { + "Id": 37, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Mul", + "Name": "mul_8", + "IsVirtual": false, + "ReadTensors": [ + {"Id":81,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":83,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":84,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":85,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 352256 + } + } + ] + }, + { + "Id": 38, + "NumWarps": 4, + "SramBytes": 24672, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_8", + "IsVirtual": false, + "ReadTensors": [ + {"Id":85,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":5,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":86,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":87,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 24672, + "TileShapeMNK": [128,256,32], + "NumTasks": 256 + } + } + ] + }, + { + "Id": 39, + "NumWarps": 1, + "SramBytes": 0, + "Ops": [ + { + "Type": "Add", + "Name": "add_1", + "IsVirtual": false, + "ReadTensors": [ + {"Id":62,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":87,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":88,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":89,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1,64], + "NumTasks": 131072 + } + } + ] + } + ], + "ProcessorGroups": [ + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":0,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":1,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":2,"TaskRange":[0,2048],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,32], + "ResourceGroups": [ + { + "ProcessorRange": [0,32], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":3,"TaskRange":[0,32],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":4,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":5,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":6,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":7,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":8,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":9,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":10,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":11,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":12,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":13,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":14,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":15,"TaskRange":[0,4096],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":16,"TaskRange":[0,2097152],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":17,"TaskRange":[0,65536],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":18,"TaskRange":[0,2097152],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":19,"TaskRange":[0,2097152],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":20,"TaskRange":[0,65536],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":21,"TaskRange":[0,2097152],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":22,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":23,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":24,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":25,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":26,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":27,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,256], + "TaskGroups": [ + {"TaskId":28,"TaskRange":[0,2048],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,32], + "ResourceGroups": [ + { + "ProcessorRange": [0,32], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":29,"TaskRange":[0,32],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":30,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":31,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":32,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":33,"TaskRange":[0,688],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":34,"TaskRange":[0,352256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":35,"TaskRange":[0,352256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":36,"TaskRange":[0,688],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":37,"TaskRange":[0,352256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,256], + "ResourceGroups": [ + { + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], + "TaskGroups": [ + {"TaskId":38,"TaskRange":[0,256],"Granularity":1} + ] + } + ] + }, + { + "ProcessorRange": [0,304], + "ResourceGroups": [ + { + "ProcessorRange": [0,304], + "WarpRange": [0,1], + "SramRange": [0,0], + "TaskGroups": [ + {"TaskId":39,"TaskRange":[0,131072],"Granularity":1} + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/python/ark/profiler.py b/python/ark/profiler.py index 529a0d506..56233247c 100644 --- a/python/ark/profiler.py +++ b/python/ark/profiler.py @@ -8,11 +8,10 @@ from .planner import Plan -def timeit(plan: Plan): +def timeit(plan: Plan, iter: int): with Runtime() as rt: rt.launch(plan=plan) start_time = time.time() - iter = 1000 rt.run(iter=iter) end_time = time.time() return (end_time - start_time) / iter @@ -22,8 +21,11 @@ class Profiler: def __init__(self, plan: Plan): self.plan = plan - def run(self): - sys.stderr.write(f"End-to-end: {timeit(self.plan):.6f} seconds/iter\n") + def run(self, iter: int = 1000, profile_processor_groups: bool = False): + sys.stderr.write(f"End-to-end: {timeit(self.plan, iter):.6f} seconds/iter\n") + + if not profile_processor_groups: + return num_processor_groups = len(self.plan.processor_groups) new_plan = { "Rank": self.plan.rank, @@ -36,7 +38,7 @@ def run(self): } for i in range(num_processor_groups): new_plan["ProcessorGroups"][0] = self.plan.processor_groups[i] - lat_per_iter = timeit(Plan(new_plan)) + lat_per_iter = timeit(Plan(new_plan), iter) sys.stderr.write( f"Processor group {i}: {lat_per_iter:.6f} seconds/iter\n" ) From cc30912486c24f71617ee2200c7429ea2e610d51 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Fri, 5 Jul 2024 07:12:49 +0000 Subject: [PATCH 25/54] optimization --- examples/llama/plan_llama2_7b_b1_s2048.json | 732 ++++---------------- 1 file changed, 126 insertions(+), 606 deletions(-) diff --git a/examples/llama/plan_llama2_7b_b1_s2048.json b/examples/llama/plan_llama2_7b_b1_s2048.json index d0e46d228..15b0de2d0 100644 --- a/examples/llama/plan_llama2_7b_b1_s2048.json +++ b/examples/llama/plan_llama2_7b_b1_s2048.json @@ -27,17 +27,10 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [1,4096], + "NumTasks": 2048 } - } - ] - }, - { - "Id": 1, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Mul", "Name": "mul", @@ -56,17 +49,10 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [1,4096], + "NumTasks": 2048 } - } - ] - }, - { - "Id": 2, - "NumWarps": 1, - "SramBytes": 256, - "Ops": [ + }, { "Type": "ReduceMean", "Name": "reduce_mean", @@ -87,7 +73,7 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 2048 } } @@ -144,17 +130,10 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [1,4096], + "NumTasks": 2048 } - } - ] - }, - { - "Id": 5, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Mul", "Name": "mul_2", @@ -173,17 +152,10 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [1,4096], + "NumTasks": 2048 } - } - ] - }, - { - "Id": 6, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Cast", "Name": "cast_1", @@ -201,8 +173,8 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [1,4096], + "NumTasks": 2048 } } ] @@ -233,17 +205,10 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 256 } - } - ] - }, - { - "Id": 8, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Rope", "Name": "rope", @@ -260,19 +225,12 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [256,1,128], + "NumTasks": 256 } - } - ] - }, - { - "Id": 9, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Transpose", "Name": "transpose", @@ -290,10 +248,10 @@ "Permutation": {"DIMS":[0,2,1,3]} }, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [8,8], - "NumTasks": 131072 + "Tile": [256,128], + "NumTasks": 256 } } ] @@ -324,17 +282,10 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 256 } - } - ] - }, - { - "Id": 11, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Rope", "Name": "rope_1", @@ -351,19 +302,12 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [256,128], + "NumTasks": 256 } - } - ] - }, - { - "Id": 12, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Transpose", "Name": "transpose_2", @@ -372,19 +316,19 @@ {"Id":36,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ - {"Id":41,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + {"Id":41,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "ResultTensors": [ - {"Id":42,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + {"Id":42,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "Args": { - "Permutation": {"DIMS":[0,2,3,1]} + "Permutation": {"DIMS":[0,2,1,3]} }, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [8,8], - "NumTasks": 131072 + "Tile": [256,128], + "NumTasks": 256 } } ] @@ -415,17 +359,10 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 256 } - } - ] - }, - { - "Id": 14, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Transpose", "Name": "transpose_1", @@ -443,10 +380,10 @@ "Permutation": {"DIMS":[0,2,1,3]} }, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [8,8], - "NumTasks": 131072 + "Tile": [256,128], + "NumTasks": 256 } } ] @@ -462,7 +399,7 @@ "IsVirtual": false, "ReadTensors": [ {"Id":38,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":42,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + {"Id":42,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} ], "WriteTensors": [ {"Id":43,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":25,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} @@ -472,22 +409,15 @@ ], "Args": { "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":false} + "TransposeOther": {"BOOL":true} }, "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 4096 } - } - ] - }, - { - "Id": 16, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "ScalarMul", "Name": "mul_3", @@ -505,10 +435,10 @@ "Factor": {"FLOAT":0.0883883461356163} }, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 2097152 + "Tile": [256,128], + "NumTasks": 4096 } } ] @@ -516,7 +446,7 @@ { "Id": 17, "NumWarps": 1, - "SramBytes": 256, + "SramBytes": 0, "Ops": [ { "Type": "ReduceMax", @@ -538,17 +468,10 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 65536 } - } - ] - }, - { - "Id": 18, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Sub", "Name": "sub", @@ -567,17 +490,10 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 2097152 + "Tile": [1,2048], + "NumTasks": 65536 } - } - ] - }, - { - "Id": 19, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Exp", "Name": "exp", @@ -595,17 +511,10 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 2097152 + "Tile": [1,2048], + "NumTasks": 65536 } - } - ] - }, - { - "Id": 20, - "NumWarps": 1, - "SramBytes": 256, - "Ops": [ + }, { "Type": "ReduceSum", "Name": "reduce_sum", @@ -626,17 +535,10 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 65536 } - } - ] - }, - { - "Id": 21, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Div", "Name": "div", @@ -655,8 +557,8 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 2097152 + "Tile": [1,2048], + "NumTasks": 65536 } } ] @@ -690,14 +592,7 @@ "TileShapeMNK": [256,128,32], "NumTasks": 256 } - } - ] - }, - { - "Id": 23, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Transpose", "Name": "transpose_3", @@ -715,10 +610,10 @@ "Permutation": {"DIMS":[0,2,1,3]} }, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [8,8], - "NumTasks": 131072 + "Tile": [256,1,128], + "NumTasks": 256 } } ] @@ -749,17 +644,10 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 256 } - } - ] - }, - { - "Id": 25, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Add", "Name": "add", @@ -776,19 +664,12 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [256,128], + "NumTasks": 256 } - } - ] - }, - { - "Id": 26, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Cast", "Name": "cast_2", @@ -804,19 +685,12 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [256,128], + "NumTasks": 256 } - } - ] - }, - { - "Id": 27, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Mul", "Name": "mul_4", @@ -833,10 +707,10 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [256,128], + "NumTasks": 256 } } ] @@ -844,7 +718,7 @@ { "Id": 28, "NumWarps": 1, - "SramBytes": 256, + "SramBytes": 0, "Ops": [ { "Type": "ReduceMean", @@ -866,7 +740,7 @@ "Config": { "NumWarps": 1, "ImplType": "WarpWise", - "SramBytes": 256, + "SramBytes": 0, "NumTasks": 2048 } } @@ -923,17 +797,10 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [1,4096], + "NumTasks": 2048 } - } - ] - }, - { - "Id": 31, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Mul", "Name": "mul_6", @@ -952,17 +819,10 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [1,4096], + "NumTasks": 2048 } - } - ] - }, - { - "Id": 32, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Cast", "Name": "cast_3", @@ -980,8 +840,8 @@ "Config": { "NumWarps": 1, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [1,4096], + "NumTasks": 2048 } } ] @@ -1012,17 +872,10 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 688 } - } - ] - }, - { - "Id": 34, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Sigmoid", "Name": "sigmoid", @@ -1038,19 +891,12 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 352256 + "Tile": [256,128], + "NumTasks": 688 } - } - ] - }, - { - "Id": 35, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Mul", "Name": "mul_7", @@ -1067,10 +913,10 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 352256 + "Tile": [256,128], + "NumTasks": 688 } } ] @@ -1101,17 +947,10 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 688 } - } - ] - }, - { - "Id": 37, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Mul", "Name": "mul_8", @@ -1128,10 +967,10 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 352256 + "Tile": [256,128], + "NumTasks": 688 } } ] @@ -1162,17 +1001,10 @@ "Config": { "NumWarps": 4, "SramBytes": 24672, - "TileShapeMNK": [128,256,32], + "TileShapeMNK": [256,128,32], "NumTasks": 256 } - } - ] - }, - { - "Id": 39, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ + }, { "Type": "Add", "Name": "add_1", @@ -1189,10 +1021,10 @@ ], "Args": {}, "Config": { - "NumWarps": 1, + "NumWarps": 4, "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 + "Tile": [256,128], + "NumTasks": 256 } } ] @@ -1204,23 +1036,23 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,1], + "WarpRange": [0,4], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,131072],"Granularity":1} + {"TaskId":0,"TaskRange":[0,2048],"Granularity":4} ] } ] }, { - "ProcessorRange": [0,304], + "ProcessorRange": [0,32], "ResourceGroups": [ { - "ProcessorRange": [0,304], + "ProcessorRange": [0,32], "WarpRange": [0,1], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":1,"TaskRange":[0,131072],"Granularity":1} + {"TaskId":3,"TaskRange":[0,32],"Granularity":1} ] } ] @@ -1230,101 +1062,23 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,0], "TaskGroups": [ - {"TaskId":2,"TaskRange":[0,2048],"Granularity":1} + {"TaskId":4,"TaskRange":[0,2048],"Granularity":4} ] } ] }, { - "ProcessorRange": [0,32], + "ProcessorRange": [0,256], "ResourceGroups": [ { - "ProcessorRange": [0,32], - "WarpRange": [0,1], - "SramRange": [0,0], + "ProcessorRange": [0,256], + "WarpRange": [0,4], + "SramRange": [0,24672], "TaskGroups": [ - {"TaskId":3,"TaskRange":[0,32],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":4,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":5,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":6,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":7,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":8,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":9,"TaskRange":[0,131072],"Granularity":1} + {"TaskId":7,"TaskRange":[0,256],"Granularity":1} ] } ] @@ -1342,32 +1096,6 @@ } ] }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":11,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":12,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, { "ProcessorRange": [0,256], "ResourceGroups": [ @@ -1381,19 +1109,6 @@ } ] }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":14,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, { "ProcessorRange": [0,304], "ResourceGroups": [ @@ -1412,75 +1127,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":16,"TaskRange":[0,2097152],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,256], - "TaskGroups": [ - {"TaskId":17,"TaskRange":[0,65536],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":18,"TaskRange":[0,2097152],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":19,"TaskRange":[0,2097152],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,256], - "TaskGroups": [ - {"TaskId":20,"TaskRange":[0,65536],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], + "WarpRange": [0,4], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":21,"TaskRange":[0,2097152],"Granularity":1} + {"TaskId":17,"TaskRange":[0,65536],"Granularity":4} ] } ] @@ -1498,19 +1148,6 @@ } ] }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":23,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, { "ProcessorRange": [0,256], "ResourceGroups": [ @@ -1529,49 +1166,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":25,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":26,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], + "WarpRange": [0,4], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":27,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,256], - "TaskGroups": [ - {"TaskId":28,"TaskRange":[0,2048],"Granularity":1} + {"TaskId":28,"TaskRange":[0,2048],"Granularity":4} ] } ] @@ -1594,36 +1192,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":30,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":31,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], + "WarpRange": [0,4], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":32,"TaskRange":[0,131072],"Granularity":1} + {"TaskId":30,"TaskRange":[0,2048],"Granularity":4} ] } ] @@ -1641,32 +1213,6 @@ } ] }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":34,"TaskRange":[0,352256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":35,"TaskRange":[0,352256],"Granularity":1} - ] - } - ] - }, { "ProcessorRange": [0,304], "ResourceGroups": [ @@ -1680,19 +1226,6 @@ } ] }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":37,"TaskRange":[0,352256],"Granularity":1} - ] - } - ] - }, { "ProcessorRange": [0,256], "ResourceGroups": [ @@ -1705,19 +1238,6 @@ ] } ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,1], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":39,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] } ] } \ No newline at end of file From 34a87d867669aae49b2a29056aadfed694d97b33 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 8 Jul 2024 02:10:40 +0000 Subject: [PATCH 26/54] optimize --- examples/llama/plan_llama2_7b_b1_s2048.json | 97 ++++++++++++++++----- 1 file changed, 76 insertions(+), 21 deletions(-) diff --git a/examples/llama/plan_llama2_7b_b1_s2048.json b/examples/llama/plan_llama2_7b_b1_s2048.json index 15b0de2d0..d5c9fe552 100644 --- a/examples/llama/plan_llama2_7b_b1_s2048.json +++ b/examples/llama/plan_llama2_7b_b1_s2048.json @@ -3,7 +3,7 @@ "WorldSize": 1, "Architecture": "ROCM_942", "NumProcessors": 304, - "NumWarpsPerProcessor": 4, + "NumWarpsPerProcessor": 8, "TaskInfos": [ { "Id": 0, @@ -948,7 +948,7 @@ "NumWarps": 4, "SramBytes": 24672, "TileShapeMNK": [256,128,32], - "NumTasks": 688 + "NumTasks": 602 } }, { @@ -970,7 +970,61 @@ "NumWarps": 4, "SramBytes": 0, "Tile": [256,128], - "NumTasks": 688 + "NumTasks": 602 + } + } + ] + }, + { + "Id": 37, + "NumWarps": 4, + "SramBytes": 16480, + "Ops": [ + { + "Type": "Matmul", + "Name": "matmul_7", + "IsVirtual": false, + "ReadTensors": [ + {"Id":102,"DataType":"FP16","Shape":[1,1792,4096],"Strides":[1,2048,4096],"Offsets":[0,256,0],"PaddedShape":[1,1792,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":6,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":101,"DataType":"FP16","Shape":[1,1792,11008],"Strides":[1,2048,11008],"Offsets":[0,256,0],"PaddedShape":[1,1792,11008],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":100,"DataType":"FP16","Shape":[1,1792,11008],"Strides":[1,2048,11008],"Offsets":[0,256,0],"PaddedShape":[1,1792,11008],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": { + "TransposeInput": {"BOOL":false}, + "TransposeOther": {"BOOL":true} + }, + "Config": { + "NumWarps": 4, + "SramBytes": 16480, + "TileShapeMNK": [128,128,32], + "NumTasks": 172 + } + }, + { + "Type": "Mul", + "Name": "mul_8", + "IsVirtual": false, + "ReadTensors": [ + {"Id":81,"DataType":"FP16","Shape":[1,1792,11008],"Strides":[1,2048,11008],"Offsets":[0,256,0],"PaddedShape":[1,1792,11008],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, + {"Id":83,"DataType":"FP16","Shape":[1,1792,11008],"Strides":[1,2048,11008],"Offsets":[0,256,0],"PaddedShape":[1,1792,11008],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "WriteTensors": [ + {"Id":84,"DataType":"FP16","Shape":[1,1792,11008],"Strides":[1,2048,11008],"Offsets":[0,256,0],"PaddedShape":[1,1792,11008],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "ResultTensors": [ + {"Id":85,"DataType":"FP16","Shape":[1,1792,11008],"Strides":[1,2048,11008],"Offsets":[0,256,0],"PaddedShape":[1,1792,11008],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} + ], + "Args": {}, + "Config": { + "NumWarps": 4, + "SramBytes": 0, + "Tile": [128,128], + "NumTasks": 172 } } ] @@ -1036,10 +1090,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,4], + "WarpRange": [0,8], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,2048],"Granularity":4} + {"TaskId":0,"TaskRange":[0,2048],"Granularity":7} ] } ] @@ -1062,10 +1116,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,4], + "WarpRange": [0,8], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":4,"TaskRange":[0,2048],"Granularity":4} + {"TaskId":4,"TaskRange":[0,2048],"Granularity":7} ] } ] @@ -1114,10 +1168,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,24672], + "WarpRange": [0,8], + "SramRange": [0,49344], "TaskGroups": [ - {"TaskId":15,"TaskRange":[0,4096],"Granularity":1} + {"TaskId":15,"TaskRange":[0,4096],"Granularity":2} ] } ] @@ -1127,10 +1181,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,4], + "WarpRange": [0,8], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":17,"TaskRange":[0,65536],"Granularity":4} + {"TaskId":17,"TaskRange":[0,65536],"Granularity":8} ] } ] @@ -1166,10 +1220,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,4], + "WarpRange": [0,8], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":28,"TaskRange":[0,2048],"Granularity":4} + {"TaskId":28,"TaskRange":[0,2048],"Granularity":7} ] } ] @@ -1192,10 +1246,10 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,4], + "WarpRange": [0,8], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":30,"TaskRange":[0,2048],"Granularity":4} + {"TaskId":30,"TaskRange":[0,2048],"Granularity":7} ] } ] @@ -1205,8 +1259,8 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,24672], + "WarpRange": [0,8], + "SramRange": [0,49344], "TaskGroups": [ {"TaskId":33,"TaskRange":[0,688],"Granularity":1} ] @@ -1218,10 +1272,11 @@ "ResourceGroups": [ { "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,24672], + "WarpRange": [0,8], + "SramRange": [0,49344], "TaskGroups": [ - {"TaskId":36,"TaskRange":[0,688],"Granularity":1} + {"TaskId":36,"TaskRange":[0,602],"Granularity":2}, + {"TaskId":37,"TaskRange":[0,172],"Granularity":1} ] } ] From 866112de65a6fd5d3c3d89d80cdc53ff27c8c36a Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 9 Jul 2024 01:07:21 +0000 Subject: [PATCH 27/54] optimize --- ark/include/kernels/common/sync.h | 3 + ark/include/kernels/reduce.h | 41 +++++++-- examples/llama/plan_llama2_7b_b1_s2048.json | 94 +-------------------- 3 files changed, 36 insertions(+), 102 deletions(-) diff --git a/ark/include/kernels/common/sync.h b/ark/include/kernels/common/sync.h index f47625600..456a32eb7 100644 --- a/ark/include/kernels/common/sync.h +++ b/ark/include/kernels/common/sync.h @@ -106,6 +106,9 @@ DEVICE void sync_warps() { static_assert(Arch::ThreadsPerWarp == 64, ""); if constexpr (NumWarps == 1) { __builtin_amdgcn_wave_barrier(); + } else if constexpr (NumWarps == ARK_WARPS_PER_BLOCK) { + // asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier " ::); + __syncthreads(); } else { static_assert(ARK_SMEM_RESERVED_BYTES >= sizeof(sync::WarpGroupState), ""); diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index 3d0b4e008..2dd79d2c3 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -355,8 +355,15 @@ struct WwiseReduce { int smem_per_warp) { using ShapeChecker = ReduceShapeChecker; + constexpr int InConsecBytes = sizeof(DataType) * InShape::W; constexpr int NelemPerThread = - DefaultNelemPerThread::value; + (InConsecBytes % 16 == 0) + ? 16 / sizeof(DataType) + : (InConsecBytes % 8 == 0) + ? 8 / sizeof(DataType) + : (InConsecBytes % 4 == 0) + ? 4 / sizeof(DataType) + : (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) : 1; constexpr int NonReduceDimLength = UnitOutDims::N * UnitOutDims::C * UnitOutDims::H; @@ -397,22 +404,38 @@ struct WwiseReduce { &in[idx_in]); } - DataType finalSum; - ReduceType::template identity<1>(&finalSum); + static_assert(math::is_pow2::value, + "NelemPerThread must be power of 2"); + if constexpr (NelemPerThread > 8) { #pragma unroll - for (int i = 0; i < NelemPerThread; ++i) { - ReduceType::template reduce<1>(&finalSum, &finalSum, &reduced[i]); + for (int i = 8; i < NelemPerThread; i += 8) { + ReduceType::template reduce<8>(&reduced[0], &reduced[0], &reduced[i]); + } + ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + } else if constexpr (NelemPerThread == 8) { + ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + } else if constexpr (NelemPerThread == 4) { + ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + } else if constexpr (NelemPerThread == 2) { + ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); } - UnitOp::sync_threads(); + if constexpr (InShape::W % ThreadsPerRow != 0) { + UnitOp::sync_threads(); + } // final reduction on shared memory using warp shuffle. - finalSum = warpsReduce( - finalSum, tid, smem_per_warp); + reduced[0] = warpsReduce( + reduced[0], tid, smem_per_warp); // write the result to output. if (tid % ThreadsPerRow == 0) { - ReduceType::template postReduce<1>(&out[idx_out], &finalSum, + ReduceType::template postReduce<1>(&out[idx_out], &reduced[0], InShape::W); } diff --git a/examples/llama/plan_llama2_7b_b1_s2048.json b/examples/llama/plan_llama2_7b_b1_s2048.json index d5c9fe552..b0bc757dc 100644 --- a/examples/llama/plan_llama2_7b_b1_s2048.json +++ b/examples/llama/plan_llama2_7b_b1_s2048.json @@ -230,29 +230,6 @@ "Tile": [256,1,128], "NumTasks": 256 } - }, - { - "Type": "Transpose", - "Name": "transpose", - "IsVirtual": false, - "ReadTensors": [ - {"Id":34,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":38,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,128], - "NumTasks": 256 - } } ] }, @@ -307,29 +284,6 @@ "Tile": [256,128], "NumTasks": 256 } - }, - { - "Type": "Transpose", - "Name": "transpose_2", - "IsVirtual": false, - "ReadTensors": [ - {"Id":36,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":41,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":42,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,128], - "NumTasks": 256 - } } ] }, @@ -362,29 +316,6 @@ "TileShapeMNK": [256,128,32], "NumTasks": 256 } - }, - { - "Type": "Transpose", - "Name": "transpose_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":39,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":40,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,128], - "NumTasks": 256 - } } ] }, @@ -592,29 +523,6 @@ "TileShapeMNK": [256,128,32], "NumTasks": 256 } - }, - { - "Type": "Transpose", - "Name": "transpose_3", - "IsVirtual": false, - "ReadTensors": [ - {"Id":55,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":29,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":56,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":57,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,1,128], - "NumTasks": 256 - } } ] }, @@ -1184,7 +1092,7 @@ "WarpRange": [0,8], "SramRange": [0,0], "TaskGroups": [ - {"TaskId":17,"TaskRange":[0,65536],"Granularity":8} + {"TaskId":17,"TaskRange":[0,65536],"Granularity":1} ] } ] From 68e787ae377c282c9d117e6650eb112a34c54a9c Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 9 Jul 2024 20:51:44 +0000 Subject: [PATCH 28/54] fix bf16 matmul --- ark/ops/ops_matmul.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ark/ops/ops_matmul.cpp b/ark/ops/ops_matmul.cpp index b4553a4ed..a24b95d72 100644 --- a/ark/ops/ops_matmul.cpp +++ b/ark/ops/ops_matmul.cpp @@ -223,7 +223,7 @@ static const Json get_default_config(const ArchRef arch, {"TileShapeMNK", {tm, tn, 32}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == BF16.ref()) { return {{"NumWarps", 4}, - {"SramBytes", 24672}, + {"SramBytes", 24624}, {"TileShapeMNK", {tm, tn, 32}}}; } ERR(InternalError, "Unexpected error"); From b18bdb2e66d30c34b21657e15bb6cf491f108544 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 10 Jul 2024 23:44:07 +0000 Subject: [PATCH 29/54] Enhance executor interfaces --- ark/api/executor.cpp | 295 +++++++++++++++++++---------- ark/gpu/gpu_event.cpp | 11 +- ark/gpu/gpu_event.h | 4 +- ark/gpu/gpu_kernel.cpp | 2 +- ark/gpu/gpu_kernel.h | 2 +- ark/gpu/gpu_manager.cpp | 18 +- ark/gpu/gpu_manager.h | 4 +- ark/include/ark/executor.hpp | 46 +++-- ark/model/model_json.cpp | 11 +- ark/model/model_json.hpp | 2 +- ark/model/model_op.cpp | 5 +- ark/ops/ops_all_reduce_test.cpp | 7 +- ark/ops/ops_communication_test.cpp | 8 +- ark/ops/ops_embedding_test.cpp | 6 +- ark/ops/ops_test_common.cpp | 20 +- ark/ops/ops_test_common.hpp | 15 +- cmake/Utils.cmake | 2 +- python/ark/runtime.py | 4 +- python/ark/tensor.py | 10 +- python/executor_py.cpp | 59 +++++- 20 files changed, 344 insertions(+), 187 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 14625161f..2f50a4280 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -140,10 +140,17 @@ static size_t tensor_stride_bytes(const Json &tensor) { class Executor::Impl { public: - Impl(int rank, int world_size, int gpu_id, const std::string &name, - const std::string &plan); + Impl(int device_id, Stream stream, const std::string &name); ~Impl() = default; + void init(const PlanJson& plan); + + int device_id() const { return device_id_; } + + Stream stream() const { return reinterpret_cast(stream_raw_); } + + std::string plan() const { return plan_json_.dump_pretty(); } + void compile(); void launch(int64_t max_spin_count); void run(int iter); @@ -151,9 +158,12 @@ class Executor::Impl { float stop(int64_t max_spin_count); void barrier(); - void tensor_read(const Tensor tensor, void *data, size_t bytes) const; - void tensor_write(const Tensor tensor, const void *data, - size_t bytes) const; + uintptr_t tensor_address(const Tensor tensor) const; + + void tensor_read(const Tensor tensor, void *data, size_t bytes, + Stream stream, bool is_d2d) const; + void tensor_write(const Tensor tensor, const void *data, size_t bytes, + Stream stream, bool is_d2d) const; private: void init_communicator(); @@ -162,14 +172,18 @@ class Executor::Impl { void init_channels(const std::set &remote_ranks); protected: - const int rank_; - const int world_size_; - int gpu_id_; + int device_id_; + std::string name_; + gpuStream stream_raw_; + + int rank_; + int world_size_; bool is_launched_ = false; bool is_recording_ = false; float elapsed_msec_ = -1; + PlanJson plan_json_; std::map buffer_id_to_offset_; size_t total_bytes_; std::shared_ptr codegen_; @@ -177,8 +191,7 @@ class Executor::Impl { std::shared_ptr timer_end_; std::shared_ptr buffer_; std::shared_ptr flag_; - std::shared_ptr main_stream_; - std::shared_ptr copy_stream_; + std::shared_ptr stream_; std::shared_ptr kernel_; // For communication @@ -190,30 +203,35 @@ class Executor::Impl { rank_to_sm_channels_; }; -Executor::Impl::Impl(int rank, int world_size, int gpu_id, - const std::string &name, const std::string &plan) - : rank_(rank), world_size_(world_size), gpu_id_(gpu_id) { - if (rank < 0 || rank >= world_size) { - ERR(InvalidUsageError, "Invalid rank ", rank, " with world size ", - world_size); +Executor::Impl::Impl(int device_id, Stream stream, const std::string &name) + : device_id_(device_id), name_(name) { + if (device_id < 0) { + ERR(InvalidUsageError, "Invalid device ID ", device_id); } - if (gpu_id < 0) { - ERR(InvalidUsageError, "Invalid GPU ID ", gpu_id); + if (stream) { + stream_raw_ = reinterpret_cast(stream); + } else { + stream_ = GpuManager::get_instance(device_id_)->create_stream(); + stream_raw_ = stream_->get(); + } +} + +void Executor::Impl::init(const PlanJson &plan_json) { + plan_json_ = plan_json; + rank_ = plan_json_["Rank"].get(); + world_size_ = plan_json_["WorldSize"].get(); + + if (rank_ < 0 || rank_ >= world_size_) { + ERR(InvalidUsageError, "Invalid rank ", rank_, " with world size ", + world_size_); } if (world_size_ > 1) { init_communicator(); } - Json plan_json; - auto &plan_path = get_env().enforce_plan_path; - if (!plan_path.empty()) { - LOG(INFO, "Enforce executor plan path: ", plan_path); - plan_json = Json::parse(read_file(plan_path)); - } else { - plan_json = Json::parse(plan); - } + auto gpu_manager = GpuManager::get_instance(device_id_); - buffer_id_to_offset_ = init_buffers(plan_json); + buffer_id_to_offset_ = init_buffers(plan_json_); std::string buffer_id_to_offset_str; for (const auto &kv : buffer_id_to_offset_) { @@ -221,17 +239,14 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, std::to_string(kv.first) + ": " + std::to_string(kv.second) + ", "; } - codegen_ = - std::make_shared(plan_json, buffer_id_to_offset_, name); + codegen_ = std::make_shared(plan_json_, buffer_id_to_offset_, + name_); - auto gpu_manager = GpuManager::get_instance(gpu_id_); timer_begin_ = gpu_manager->create_event(); timer_end_ = gpu_manager->create_event(); buffer_ = gpu_manager->malloc(total_bytes_, 65536); flag_ = gpu_manager->malloc_host( sizeof(int), gpuHostAllocMapped | gpuHostAllocWriteCombined); - main_stream_ = gpu_manager->create_stream(); - copy_stream_ = gpu_manager->create_stream(); int threads_per_block = static_cast( codegen_->num_warps_per_proc() * gpu_manager->info().threads_per_warp); @@ -241,13 +256,13 @@ Executor::Impl::Impl(int rank, int world_size, int gpu_id, static_cast(gpu_manager->info().smem_block_total); if (world_size_ > 1) { - auto remote_ranks = init_remote_ranks(plan_json); + auto remote_ranks = init_remote_ranks(plan_json_); init_channels(remote_ranks); } kernel_ = std::shared_ptr(new GpuKernel( - gpu_id_, codegen_->code(), {threads_per_block, 1, 1}, {num_sm, 1, 1}, - std::max(smem_block_total, size_t(4)), name, + device_id_, codegen_->code(), {threads_per_block, 1, 1}, {num_sm, 1, 1}, + std::max(smem_block_total, size_t(4)), name_, {std::pair{buffer_->ref(), sizeof(buffer_->ref())}, std::pair{flag, sizeof(flag)}})); } @@ -509,7 +524,7 @@ void Executor::Impl::init_channels(const std::set &remote_ranks) { mscclpp::TransportFlags all_transports = mscclpp::Transport::CudaIpc | mscclpp::Transport::Ethernet; if (!get_env().disable_ib) { - all_transports |= IBs[gpu_id_]; + all_transports |= IBs[device_id_]; } mscclpp::RegisteredMemory regmem = comm_->registerMemory(buffer_->ref(), buffer_->bytes(), all_transports); @@ -530,12 +545,12 @@ void Executor::Impl::init_channels(const std::set &remote_ranks) { if (remote_node == this_node) { add_connection(remote_rank, mscclpp::Transport::CudaIpc); if (!get_env().disable_ib) { - add_connection(remote_rank, IBs[gpu_id_]); + add_connection(remote_rank, IBs[device_id_]); } } else { add_connection(remote_rank, get_env().disable_ib ? mscclpp::Transport::Ethernet - : IBs[gpu_id_]); + : IBs[device_id_]); } comm_->sendMemoryOnSetup(regmem, remote_rank, 0); rank_to_remote_regmem_future[remote_rank] = @@ -623,22 +638,22 @@ void Executor::Impl::launch(int64_t max_spin_count) { sm_handles[i] = it2->second[0]->deviceHandle(); } } - GLOG(gpuSetDevice(gpu_id_)); + GLOG(gpuSetDevice(device_id_)); GLOG(gpuMemcpyAsync( proxy_chan_addr, proxy_handles.data(), proxy_handles.size() * sizeof(mscclpp::SimpleProxyChannel::DeviceHandle), - gpuMemcpyHostToDevice, copy_stream_->get())); + gpuMemcpyHostToDevice, stream_raw_)); GLOG(gpuMemcpyAsync( proxy_secondary_chan_addr, proxy_secondary_handles.data(), proxy_secondary_handles.size() * sizeof(mscclpp::SimpleProxyChannel::DeviceHandle), - gpuMemcpyHostToDevice, copy_stream_->get())); + gpuMemcpyHostToDevice, stream_raw_)); GLOG(gpuMemcpyAsync( sm_chan_addr, sm_handles.data(), sm_handles.size() * sizeof(mscclpp::SmChannel::DeviceHandle), - gpuMemcpyHostToDevice, copy_stream_->get())); - copy_stream_->sync(); + gpuMemcpyHostToDevice, stream_raw_)); + GLOG(gpuStreamSynchronize(stream_raw_)); } elapsed_msec_ = -1; @@ -648,7 +663,7 @@ void Executor::Impl::launch(int64_t max_spin_count) { LOG(WARN, "Ignore launching twice."); return; } - timer_begin_->record(main_stream_); + timer_begin_->record(stream_raw_); if (world_size_ > 1) { proxy_service_->startProxy(); @@ -656,8 +671,8 @@ void Executor::Impl::launch(int64_t max_spin_count) { // Initialize loop flags. atomicStoreRelaxed(flag_->ref(), 0); - kernel_->launch(main_stream_); - timer_end_->record(main_stream_); + kernel_->launch(stream_raw_); + timer_end_->record(stream_raw_); is_recording_ = true; is_launched_ = true; } @@ -677,7 +692,7 @@ void Executor::Impl::wait(int64_t max_spin_count) { continue; } // Check if the kernel encountered an error. - gpuError res = main_stream_->query(); + gpuError res = gpuStreamQuery(stream_raw_); if (res == gpuSuccess) { if (atomicLoadRelaxed(flag_->ref()) > 0) { LOG(WARN, "Stream is finished but the loop flag is still set."); @@ -699,7 +714,7 @@ void Executor::Impl::wait(int64_t max_spin_count) { float Executor::Impl::stop(int64_t max_spin_count) { this->wait(max_spin_count); atomicStoreRelaxed(flag_->ref(), -1); - main_stream_->sync(); + GLOG(gpuStreamSynchronize(stream_raw_)); if (is_recording_) { elapsed_msec_ = timer_end_->elapsed_msec(*timer_begin_); is_recording_ = false; @@ -717,71 +732,140 @@ void Executor::Impl::barrier() { } } -void Executor::Impl::tensor_read(const Tensor tensor, void *data, - size_t bytes) const { - GLOG(gpuSetDevice(gpu_id_)); +uintptr_t Executor::Impl::tensor_address(const Tensor tensor) const { + size_t buffer_id = tensor.ref()->buffer()->id(); + if (buffer_id_to_offset_.find(buffer_id) == buffer_id_to_offset_.end()) { + ERR(NotFoundError, "Invalid buffer ID: ", buffer_id); + } + size_t offset = buffer_id_to_offset_.at(buffer_id); + return reinterpret_cast(buffer_->ref(offset)); +} + +void Executor::Impl::tensor_read(const Tensor tensor, void *data, size_t bytes, + Stream stream, bool is_d2d) const { + GLOG(gpuSetDevice(device_id_)); + std::shared_ptr copy_stream; + gpuStream copy_stream_raw; + if (stream) { + copy_stream_raw = reinterpret_cast(stream); + if ((stream == stream_raw_) && is_launched_) { + LOG(WARN, + "Reading from a tensor in the same stream of the kernel " + "may cause a deadlock."); + } + } else { + copy_stream = GpuManager::get_instance(device_id_)->create_stream(); + copy_stream_raw = copy_stream->get(); + } size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); - if (bytes < tensor_data_bytes) { - ERR(InvalidUsageError, "Data buffer (", bytes, - ") is smaller than the tensor data (", tensor_data_bytes, ")."); + if (bytes != tensor_data_bytes) { + ERR(InvalidUsageError, "Destination bytes (", bytes, + ") mismatches the tensor data bytes (", tensor_data_bytes, ")."); } - size_t tensor_bytes = - tensor.strides().nelems() * tensor.data_type().bytes(); - void *src = - buffer_->ref(buffer_id_to_offset_.at(tensor.ref()->buffer()->id())); + auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyDeviceToHost; + void *src = reinterpret_cast(tensor_address(tensor)); if (tensor.strides() == tensor.shape()) { - GLOG(gpuMemcpyAsync(data, src, bytes, gpuMemcpyDeviceToHost, - copy_stream_->get())); - copy_stream_->sync(); + GLOG(gpuMemcpyAsync(data, src, bytes, kind, copy_stream_raw)); } else { + size_t tensor_bytes = + tensor.strides().nelems() * tensor.data_type().bytes(); std::vector tensor_host(tensor_bytes); GLOG(gpuMemcpyAsync(tensor_host.data(), src, tensor_bytes, - gpuMemcpyDeviceToHost, copy_stream_->get())); - copy_stream_->sync(); - tensor_to_data(tensor_host.data(), static_cast(data), - tensor.shape(), tensor.strides(), tensor.offsets(), + gpuMemcpyDeviceToHost, copy_stream_raw)); + GLOG(gpuStreamSynchronize(copy_stream_raw)); + if (!is_d2d) { + tensor_to_data(tensor_host.data(), static_cast(data), + tensor.shape(), tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); + return; + } + // TODO: convert data layout on the device directly + std::vector data_host(bytes); + tensor_to_data(tensor_host.data(), data_host.data(), tensor.shape(), + tensor.strides(), tensor.offsets(), tensor.data_type().bytes()); + GLOG(gpuMemcpyAsync(data, data_host.data(), bytes, + gpuMemcpyHostToDevice, copy_stream_raw)); } + GLOG(gpuStreamSynchronize(copy_stream_raw)); } void Executor::Impl::tensor_write(const Tensor tensor, const void *data, - size_t bytes) const { - GLOG(gpuSetDevice(gpu_id_)); + size_t bytes, Stream stream, + bool is_d2d) const { + GLOG(gpuSetDevice(device_id_)); + std::shared_ptr copy_stream; + gpuStream copy_stream_raw; + if (stream) { + copy_stream_raw = reinterpret_cast(stream); + if ((stream == stream_raw_) && is_launched_) { + LOG(WARN, + "Writing to a tensor in the same stream of the kernel " + "may cause a deadlock."); + } + } else { + copy_stream = GpuManager::get_instance(device_id_)->create_stream(); + copy_stream_raw = copy_stream->get(); + } size_t tensor_data_bytes = tensor.shape().nelems() * tensor.data_type().bytes(); - if (bytes < tensor_data_bytes) { - ERR(InvalidUsageError, "Data buffer (", bytes, - ") is smaller than the tensor data (", tensor_data_bytes, ")."); + if (bytes != tensor_data_bytes) { + ERR(InvalidUsageError, "Source bytes (", bytes, + ") mismatches the tensor data bytes (", tensor_data_bytes, ")."); } size_t tensor_bytes = tensor.strides().nelems() * tensor.data_type().bytes(); - void *dst = - buffer_->ref(buffer_id_to_offset_.at(tensor.ref()->buffer()->id())); + auto kind = (is_d2d) ? gpuMemcpyDeviceToDevice : gpuMemcpyHostToDevice; + void *dst = reinterpret_cast(tensor_address(tensor)); if (tensor.strides() == tensor.shape()) { - GLOG(gpuMemcpyAsync(dst, data, tensor_bytes, gpuMemcpyHostToDevice, - copy_stream_->get())); + GLOG(gpuMemcpyAsync(dst, data, tensor_bytes, kind, copy_stream_raw)); } else { std::vector tensor_host(tensor_bytes); - GLOG(gpuMemcpyAsync(tensor_host.data(), dst, tensor_bytes, - gpuMemcpyDeviceToHost, copy_stream_->get())); - copy_stream_->sync(); - data_to_tensor(tensor_host.data(), static_cast(data), - tensor.shape(), tensor.strides(), tensor.offsets(), - tensor.data_type().bytes()); + if (!is_d2d) { + GLOG(gpuMemcpyAsync(tensor_host.data(), dst, tensor_bytes, + gpuMemcpyDeviceToHost, copy_stream_raw)); + GLOG(gpuStreamSynchronize(copy_stream_raw)); + data_to_tensor(tensor_host.data(), + static_cast(data), tensor.shape(), + tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); + } else { + // TODO: convert data layout on the device directly + std::vector tmp(bytes); + GLOG(gpuMemcpyAsync(tmp.data(), data, bytes, gpuMemcpyDeviceToHost, + copy_stream_raw)); + GLOG(gpuStreamSynchronize(copy_stream_raw)); + data_to_tensor(tensor_host.data(), tmp.data(), tensor.shape(), + tensor.strides(), tensor.offsets(), + tensor.data_type().bytes()); + } GLOG(gpuMemcpyAsync(dst, tensor_host.data(), tensor_bytes, - gpuMemcpyHostToDevice, copy_stream_->get())); + gpuMemcpyHostToDevice, copy_stream_raw)); } - copy_stream_->sync(); + GLOG(gpuStreamSynchronize(copy_stream_raw)); } -Executor::Executor(int rank, int world_size, int gpu_id, - const std::string &name, const std::string &plan) - : impl_(std::make_unique(rank, world_size, gpu_id, name, - plan)) {} +Executor::Executor(int device_id, Stream stream, const std::string &name, + const std::string &plan) + : impl_(std::make_unique(device_id, stream, name)) { + auto &plan_path = get_env().enforce_plan_path; + if (!plan_path.empty()) { + LOG(INFO, "Enforce executor plan path: ", plan_path); + impl_->init(Json::parse(read_file(plan_path))); + } else if (!plan.empty()) { + impl_->init(Json::parse(plan)); + } +} Executor::~Executor() = default; +int Executor::device_id() const { return impl_->device_id(); } + +Stream Executor::stream() const { return impl_->stream(); } + +std::string Executor::plan() const { return impl_->plan(); } + void Executor::compile() { impl_->compile(); } void Executor::launch(int64_t max_spin_count) { impl_->launch(max_spin_count); } @@ -800,25 +884,32 @@ void Executor::destroy() { impl_.reset(nullptr); } bool Executor::destroyed() const { return impl_.get() == nullptr; } -void Executor::tensor_read(const Tensor tensor, void *data, - size_t bytes) const { - impl_->tensor_read(tensor, data, bytes); +uintptr_t Executor::tensor_address(const Tensor tensor) const { + return impl_->tensor_address(tensor); } -void Executor::tensor_write(const Tensor tensor, const void *data, - size_t bytes) const { - impl_->tensor_write(tensor, data, bytes); +void Executor::tensor_read(const Tensor tensor, void *data, size_t bytes, + Stream stream, bool is_d2d) const { + impl_->tensor_read(tensor, data, bytes, stream, is_d2d); } -DefaultExecutor::DefaultExecutor(const Model &model, int gpu_id, - const std::string &name) - : Executor( - model.rank(), model.world_size(), - (gpu_id < 0) ? (model.rank() % get_env().num_ranks_per_host) : gpu_id, - name, - DefaultPlanner(model, (gpu_id < 0) ? (model.rank() % - get_env().num_ranks_per_host) - : gpu_id) - .plan()) {} +void Executor::tensor_write(const Tensor tensor, const void *data, size_t bytes, + Stream stream, bool is_d2d) const { + impl_->tensor_write(tensor, data, bytes, stream, is_d2d); +} + +DefaultExecutor::DefaultExecutor( + const Model &model, int device_id, Stream stream, + const std::vector &config_rules, + const std::string &name) + : Executor((device_id < 0) ? (model.rank() % get_env().num_ranks_per_host) + : device_id, + stream, name, "") { + DefaultPlanner planner(model, impl_->device_id()); + for (const auto &rule : config_rules) { + planner.install_config_rule(rule); + } + impl_->init(Json::parse(planner.plan())); +} } // namespace ark diff --git a/ark/gpu/gpu_event.cpp b/ark/gpu/gpu_event.cpp index 93ec3fd52..cbc45d9a6 100644 --- a/ark/gpu/gpu_event.cpp +++ b/ark/gpu/gpu_event.cpp @@ -3,7 +3,6 @@ #include "gpu/gpu_event.h" -#include "gpu/gpu.h" #include "gpu/gpu_logging.h" #include "gpu/gpu_manager.h" @@ -15,7 +14,7 @@ class GpuEvent::Impl { Impl(const Impl&) = delete; Impl& operator=(const Impl&) = delete; - void record(std::shared_ptr stream); + void record(gpuStream stream); float elapsed_msec(const GpuEvent& other) const; private: @@ -32,8 +31,8 @@ GpuEvent::Impl::Impl(bool disable_timing) { GpuEvent::Impl::~Impl() { GLOG(gpuEventDestroy(event_)); } -void GpuEvent::Impl::record(std::shared_ptr stream) { - GLOG(gpuEventRecord(event_, stream->get())); +void GpuEvent::Impl::record(gpuStream stream) { + GLOG(gpuEventRecord(event_, stream)); } float GpuEvent::Impl::elapsed_msec(const GpuEvent& other) const { @@ -45,9 +44,7 @@ float GpuEvent::Impl::elapsed_msec(const GpuEvent& other) const { GpuEvent::GpuEvent(bool disable_timing) : pimpl_(std::make_shared(disable_timing)) {} -void GpuEvent::record(std::shared_ptr stream) { - pimpl_->record(stream); -} +void GpuEvent::record(gpuStream stream) { pimpl_->record(stream); } float GpuEvent::elapsed_msec(const GpuEvent& other) const { return pimpl_->elapsed_msec(other); diff --git a/ark/gpu/gpu_event.h b/ark/gpu/gpu_event.h index 4599ecaa4..081f0203b 100644 --- a/ark/gpu/gpu_event.h +++ b/ark/gpu/gpu_event.h @@ -6,6 +6,8 @@ #include +#include "gpu/gpu.h" + namespace ark { class GpuStream; @@ -17,7 +19,7 @@ class GpuEvent { GpuEvent(const GpuEvent &) = delete; GpuEvent &operator=(const GpuEvent &) = delete; - void record(std::shared_ptr stream); + void record(gpuStream stream); float elapsed_msec(const GpuEvent &other) const; protected: diff --git a/ark/gpu/gpu_kernel.cpp b/ark/gpu/gpu_kernel.cpp index 44ff43a1d..46f467f51 100644 --- a/ark/gpu/gpu_kernel.cpp +++ b/ark/gpu/gpu_kernel.cpp @@ -68,7 +68,7 @@ void GpuKernel::compile() { dynamic_smem_size_bytes)); } -void GpuKernel::launch(std::shared_ptr stream) { +void GpuKernel::launch(gpuStream stream) { if (!this->is_compiled()) { ERR(InvalidUsageError, "Kernel is not compiled yet."); } diff --git a/ark/gpu/gpu_kernel.h b/ark/gpu/gpu_kernel.h index c3b60aec4..b3be79071 100644 --- a/ark/gpu/gpu_kernel.h +++ b/ark/gpu/gpu_kernel.h @@ -27,7 +27,7 @@ class GpuKernel { const std::string& kernel_name, std::initializer_list> args = {}); void compile(); - void launch(std::shared_ptr stream); + void launch(gpuStream stream); gpuDeviceptr get_global(const std::string& name, bool ignore_not_found = false) const; diff --git a/ark/gpu/gpu_manager.cpp b/ark/gpu/gpu_manager.cpp index 3a6d0a066..fc841fa32 100644 --- a/ark/gpu/gpu_manager.cpp +++ b/ark/gpu/gpu_manager.cpp @@ -20,11 +20,10 @@ class GpuManager::Impl { int gpu_id_; GpuManager::Info info_; - std::shared_ptr main_stream_; void launch(gpuFunction kernel, const std::array &grid_dim, const std::array &block_dim, int smem_bytes, - std::shared_ptr stream, void **params, void **extra); + gpuStream stream, void **params, void **extra); }; GpuManager::Impl::Impl(int gpu_id) : gpu_id_(gpu_id) { @@ -76,11 +75,11 @@ GpuManager::Impl::Impl(int gpu_id) : gpu_id_(gpu_id) { void GpuManager::Impl::launch(gpuFunction kernel, const std::array &grid_dim, const std::array &block_dim, - int smem_bytes, std::shared_ptr stream, - void **params, void **extra) { + int smem_bytes, gpuStream stream, void **params, + void **extra) { GLOG_DRV(gpuModuleLaunchKernel( kernel, grid_dim[0], grid_dim[1], grid_dim[2], block_dim[0], - block_dim[1], block_dim[2], smem_bytes, stream->get(), params, extra)); + block_dim[1], block_dim[2], smem_bytes, stream, params, extra)); } std::shared_ptr GpuManager::get_instance(int gpu_id) { @@ -102,9 +101,7 @@ std::shared_ptr GpuManager::get_instance(int gpu_id) { } } -GpuManager::GpuManager(int gpu_id) : pimpl_(std::make_shared(gpu_id)) { - this->pimpl_->main_stream_ = std::shared_ptr(new GpuStream()); -} +GpuManager::GpuManager(int gpu_id) : pimpl_(std::make_shared(gpu_id)) {} std::shared_ptr GpuManager::malloc(size_t bytes, size_t align, bool expose) { @@ -126,8 +123,6 @@ std::shared_ptr GpuManager::create_stream() const { return std::shared_ptr(new GpuStream()); } -int GpuManager::get_gpu_id() const { return pimpl_->gpu_id_; } - const GpuManager::Info &GpuManager::info() const { return pimpl_->info_; } void GpuManager::set_current() const { GLOG(gpuSetDevice(pimpl_->gpu_id_)); } @@ -135,8 +130,7 @@ void GpuManager::set_current() const { GLOG(gpuSetDevice(pimpl_->gpu_id_)); } void GpuManager::launch(gpuFunction function, const std::array &grid_dim, const std::array &block_dim, int smem_bytes, - std::shared_ptr stream, void **params, - void **extra) const { + gpuStream stream, void **params, void **extra) const { this->set_current(); pimpl_->launch(function, grid_dim, block_dim, smem_bytes, stream, params, extra); diff --git a/ark/gpu/gpu_manager.h b/ark/gpu/gpu_manager.h index 05014ac47..93a48cf7b 100644 --- a/ark/gpu/gpu_manager.h +++ b/ark/gpu/gpu_manager.h @@ -30,11 +30,9 @@ class GpuManager { std::shared_ptr create_event(bool disable_timing = false) const; std::shared_ptr create_stream() const; - int get_gpu_id() const; void launch(gpuFunction function, const std::array &grid_dim, const std::array &block_dim, int smem_bytes, - std::shared_ptr stream, void **params, - void **extra) const; + gpuStream stream, void **params, void **extra) const; struct Info; const Info &info() const; diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 4682af7d0..75dc81c17 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -5,6 +5,7 @@ #define ARK_EXECUTOR_HPP #include +#include #include #include #include @@ -12,15 +13,27 @@ namespace ark { +using Stream = void *; + /// Convenience class for executing a model. class Executor { public: /// Constructor. - Executor(int rank, int world_size, int gpu_id, const std::string &name, + Executor(int device_id, Stream stream, const std::string &name, const std::string &plan); + /// Destructor. ~Executor(); + /// Return the device ID. + int device_id() const; + + /// Return the stream of the executor. + Stream stream() const; + + /// Return the plan string. + std::string plan() const; + /// Compile the model. This must be called before `launch()`. void compile(); @@ -39,30 +52,39 @@ class Executor { /// again. float stop(int64_t max_spin_count = -1); + /// Barrier for all rank executors. void barrier(); + /// Destroy the executor. void destroy(); + /// Return whether the executor is destroyed. bool destroyed() const; + /// Return the raw virtual address of the tensor. + uintptr_t tensor_address(const Tensor tensor) const; + template - void tensor_read(const Tensor tensor, std::vector &data) const { + void tensor_read(const Tensor tensor, std::vector &data, + Stream stream = nullptr) const { tensor_read(tensor, reinterpret_cast(data.data()), - data.size() * sizeof(T)); + data.size() * sizeof(T), stream); } template - void tensor_write(const Tensor tensor, const std::vector &data) const { + void tensor_write(const Tensor tensor, const std::vector &data, + Stream stream = nullptr) const { tensor_write(tensor, reinterpret_cast(data.data()), - data.size() * sizeof(T)); + data.size() * sizeof(T), stream); } - void tensor_read(const Tensor tensor, void *data, size_t bytes) const; + void tensor_read(const Tensor tensor, void *data, size_t bytes, + Stream stream = nullptr, bool is_d2d = false) const; - void tensor_write(const Tensor tensor, const void *data, - size_t bytes) const; + void tensor_write(const Tensor tensor, const void *data, size_t bytes, + Stream stream = nullptr, bool is_d2d = false) const; - private: + protected: class Impl; std::unique_ptr impl_; }; @@ -71,8 +93,10 @@ class Model; class DefaultExecutor : public Executor { public: - DefaultExecutor(const Model &model, int gpu_id = -1, - const std::string &name = "DefaultExecutor"); + DefaultExecutor( + const Model &model, int device_id = -1, Stream stream = nullptr, + const std::vector &config_rules = {}, + const std::string &name = "DefaultExecutor"); }; } // namespace ark diff --git a/ark/model/model_json.cpp b/ark/model/model_json.cpp index 0057ef0aa..97ce71967 100644 --- a/ark/model/model_json.cpp +++ b/ark/model/model_json.cpp @@ -272,7 +272,16 @@ static void verify_format_plan(const Json &json) { } } -PlanJson::PlanJson(const Json &json) : Json(json) { verify_format_plan(*this); } +PlanJson::PlanJson(const Json &json) + : Json((json != nullptr) ? json + : Json{{"Rank", 0}, + {"WorldSize", 1}, + {"NumProcessors", 1}, + {"NumWarpsPerProcessor", 1}, + {"TaskInfos", Json::array()}, + {"ProcessorGroups", Json::array()}}) { + verify_format_plan(*this); +} static std::stringstream &dump_pretty_plan(const Json &json, std::stringstream &ss, int indent, diff --git a/ark/model/model_json.hpp b/ark/model/model_json.hpp index cf5fbbce2..e42640a9a 100644 --- a/ark/model/model_json.hpp +++ b/ark/model/model_json.hpp @@ -18,7 +18,7 @@ class ModelJson : public Json { class PlanJson : public Json { public: - PlanJson(const Json &json); + PlanJson(const Json &json = nullptr); std::string dump_pretty(int indent = 0, int indent_step = 2) const; }; diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index 6cdba5d02..b5a0645c8 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -202,8 +202,11 @@ std::shared_ptr ModelOp::deserialize(const Json &serialized) { } else if (!serialized.contains("Args")) { ERR(InvalidUsageError, "ModelOp deserialization failed: missing Args"); } + // Run `ModelOpT::from_name` before `construct()` to ensure all operators + // are registered. + auto op_type = ModelOpT::from_name(serialized["Type"]); auto ret = model_op_factory()->construct(serialized["Type"]); - ret->type_ = ModelOpT::from_name(serialized["Type"]); + ret->type_ = op_type; ret->name_ = serialized["Name"]; ret->is_virtual_ = serialized["IsVirtual"]; for (const auto &t : serialized["ReadTensors"]) { diff --git a/ark/ops/ops_all_reduce_test.cpp b/ark/ops/ops_all_reduce_test.cpp index 9e2c6f675..030146680 100644 --- a/ark/ops/ops_all_reduce_test.cpp +++ b/ark/ops/ops_all_reduce_test.cpp @@ -91,10 +91,9 @@ void test_all_reduce_internal(ark::DimType nelem) { std::vector ones_vec(ones.shape().nelems(), ark::half_t(1.0f)); - auto result = - ark::op_test("all_reduce", m, {ones}, {output}, - baseline_all_reduce, - {ones_vec.data()}, false, gpu_id, NumGpus); + auto result = ark::op_test( + "all_reduce", m, {ones}, {output}, + baseline_all_reduce, {ones_vec.data()}); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; diff --git a/ark/ops/ops_communication_test.cpp b/ark/ops/ops_communication_test.cpp index 2b63642e6..f01de9789 100644 --- a/ark/ops/ops_communication_test.cpp +++ b/ark/ops/ops_communication_test.cpp @@ -229,9 +229,7 @@ ark::unittest::State test_communication_send_recv_bidir_sm() { ark::Tensor tns2 = model.identity(tns2_data, {tns}); tns2 = model.recv(tns2_data, remote_gpu_id, tag); - ark::DefaultPlanner planner(model, gpu_id); - planner.install_config_rule(config_rule); - ark::Executor exe(gpu_id, 2, gpu_id, "Executor", planner.plan()); + ark::DefaultExecutor exe(model, gpu_id, nullptr, {config_rule}); exe.compile(); std::vector data(1024); @@ -275,9 +273,7 @@ ark::unittest::State test_communication_send_recv_bidir_sm() { ark::Tensor sum = model.add(tns2, tns_data); - ark::DefaultPlanner planner(model, gpu_id); - planner.install_config_rule(config_rule); - ark::Executor exe(gpu_id, 2, gpu_id, "Executor", planner.plan()); + ark::DefaultExecutor exe(model, gpu_id, nullptr, {config_rule}); exe.compile(); std::vector data(1024); diff --git a/ark/ops/ops_embedding_test.cpp b/ark/ops/ops_embedding_test.cpp index 822973106..8cc95abd2 100644 --- a/ark/ops/ops_embedding_test.cpp +++ b/ark/ops/ops_embedding_test.cpp @@ -78,9 +78,9 @@ ark::unittest::State test_embedding() { } else if (std::is_same::value) { type_str = "bf16"; } - auto result = ark::op_test("embedding_" + type_str, m, {ti, tw}, {to}, - baseline_embedding, - {ti_data.data(), tw_data.data()}, true); + auto result = + ark::op_test("embedding_" + type_str, m, {ti, tw}, {to}, + baseline_embedding, {ti_data.data(), tw_data.data()}); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; diff --git a/ark/ops/ops_test_common.cpp b/ark/ops/ops_test_common.cpp index 50317fba7..60ffc9dc2 100644 --- a/ark/ops/ops_test_common.cpp +++ b/ark/ops/ops_test_common.cpp @@ -31,13 +31,13 @@ std::ostream &operator<<(std::ostream &os, const OpsTestResult &result) { return os; } -OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, - const std::vector &inputs, - const std::vector &outputs, - OpsTestBaseline baseline, - const std::vector &inputs_data, - bool print_on_error, int rank, int world_size) { - DefaultExecutor exe(model); +OpsTestResult op_test( + const std::string &test_name_prefix, const Model &model, + const std::vector &inputs, const std::vector &outputs, + OpsTestBaseline baseline, const std::vector &inputs_data, + const std::vector &config_rules, + bool print_on_error) { + DefaultExecutor exe(model, -1, nullptr, config_rules); exe.compile(); std::vector>> inputs_data_storages; @@ -133,7 +133,8 @@ OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, for (auto t : gt) { gt_ptrs.push_back(t->data()); } - baseline(gt_ptrs, output_shapes, inputs_data_refs, input_shapes, rank); + baseline(gt_ptrs, output_shapes, inputs_data_refs, input_shapes, + model.rank()); std::stringstream test_name; test_name << test_name_prefix; @@ -147,6 +148,7 @@ OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, OpsTestResult result; result.test_name = test_name.str(); + result.plan = exe.plan(); // Compare results with the ground truth. for (size_t i = 0; i < outputs.size(); i++) { @@ -187,7 +189,7 @@ OpsTestResult op_test(const std::string &test_name_prefix, const Model &model, GLOG(gpuDeviceSynchronize()); // Throughput test. - if (world_size > 1) { + if (model.world_size() > 1) { // For multi-GPU, we need to make sure that all GPUs run the same // number of iterations. Rather than doing allgather, we just // use a magic number here. diff --git a/ark/ops/ops_test_common.hpp b/ark/ops/ops_test_common.hpp index 01e97dbb1..c5d640f3b 100644 --- a/ark/ops/ops_test_common.hpp +++ b/ark/ops/ops_test_common.hpp @@ -10,6 +10,7 @@ #include "ark/model.hpp" #include "ark/model_ref.hpp" +#include "ark/planner.hpp" #include "ark/random.hpp" #include "bfloat16.h" #include "half.h" @@ -133,6 +134,7 @@ TensorCompareResult tensor_compare(T *ground_truth, T *res, Dims shape, struct OpsTestResult { std::string test_name; + std::string plan; int iter; float msec_per_iter; std::vector mse; @@ -165,13 +167,12 @@ using OpsTestBaseline = std::function &inputs, - const std::vector &outputs, - OpsTestBaseline baseline, - const std::vector &inputs_data = {}, - bool print_on_error = false, int rank = 0, - int world_size = 1); +OpsTestResult op_test( + const std::string &test_name_prefix, const Model &model, + const std::vector &inputs, const std::vector &outputs, + OpsTestBaseline baseline, const std::vector &inputs_data = {}, + const std::vector &config_rules = {}, + bool print_on_error = false); OpsTestGpuMem to_gpu(void *host_ptr, size_t size); diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 9bb83fb42..855cb824b 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -14,7 +14,7 @@ if(GIT_CLANG_FORMAT) COMMAND ${GIT_CLANG_FORMAT} --style=file --diff || true ) add_custom_target(cpplint-autofix - COMMAND ${GIT_CLANG_FORMAT} --style=file || true + COMMAND ${GIT_CLANG_FORMAT} --style=file --force --extensions cc,cpp,h,hpp,cu,in,hip || true ) else() message(STATUS "git-clang-format not found.") diff --git a/python/ark/runtime.py b/python/ark/runtime.py index 7480ce7da..33db1fb5c 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -106,6 +106,7 @@ def launch( gpu_id: int = 0, plan: str = "", plan_path: str = "", + stream: int = 0, ): """ Create an executor and schedule the ARK model. The scheduler will generate @@ -130,9 +131,8 @@ def launch( _RuntimeState.executor.destroy() _RuntimeState.executor = Executor( - rank, - world_size, gpu_id, + stream, "ArkRuntime", plan, ) diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 316d18566..d69f2aabc 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -48,7 +48,9 @@ def dtype(self) -> DataType: """ return DataType.from_ctype(self._tensor.data_type()) - def to_numpy(self, ndarray: np.ndarray = None) -> np.ndarray: + def to_numpy( + self, ndarray: np.ndarray = None, stream: int = 0 + ) -> np.ndarray: """ Copy a tensor from device to host. If `ndarray` is None, a new numpy array will be created. If the tensor is not allocated, @@ -68,10 +70,10 @@ def to_numpy(self, ndarray: np.ndarray = None) -> np.ndarray: raise ValueError("ndarray dtype does not match the tensor") elif ndarray.nbytes != self.nelems() * self.dtype().element_size(): raise ValueError("ndarray size does not match the tensor") - rt.executor.tensor_read(self._tensor, ndarray) + rt.executor.tensor_read(self._tensor, ndarray, stream) return ndarray - def from_numpy(self, ndarray: np.ndarray) -> "Tensor": + def from_numpy(self, ndarray: np.ndarray, stream: int = 0) -> "Tensor": """ Copies the tensor from a host numpy array to the device. """ @@ -86,7 +88,7 @@ def from_numpy(self, ndarray: np.ndarray) -> "Tensor": ndarray = np.ascontiguousarray(ndarray) if ndarray.nbytes != self.nelems() * self.dtype().element_size(): raise ValueError("ndarray size does not match the tensor") - rt.executor.tensor_write(self._tensor, ndarray) + rt.executor.tensor_write(self._tensor, ndarray, stream) return self diff --git a/python/executor_py.cpp b/python/executor_py.cpp index dc2840329..979cb2952 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -11,25 +11,48 @@ namespace py = pybind11; static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, - py::buffer host_buffer) { + py::buffer host_buffer, uintptr_t stream) { py::buffer_info info = host_buffer.request(); exe->tensor_write(tensor, reinterpret_cast(info.ptr), - info.size * info.itemsize); + info.size * info.itemsize, + reinterpret_cast(stream), false); +} + +static void tensor_write(ark::Executor *exe, const ark::Tensor &tensor, + size_t address, size_t bytes, uintptr_t stream, + bool is_d2d) { + exe->tensor_write(tensor, reinterpret_cast(address), bytes, + reinterpret_cast(stream), is_d2d); } static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, - py::buffer host_buffer) { + py::buffer host_buffer, uintptr_t stream) { py::buffer_info info = host_buffer.request(); exe->tensor_read(tensor, reinterpret_cast(info.ptr), - info.size * info.itemsize); + info.size * info.itemsize, + reinterpret_cast(stream), false); +} + +static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, + size_t address, size_t bytes, uintptr_t stream, + bool is_d2d) { + exe->tensor_read(tensor, reinterpret_cast(address), bytes, + reinterpret_cast(stream), is_d2d); } void register_executor(py::module &m) { py::class_(m, "_Executor") - .def( - py::init(), - py::arg("rank"), py::arg("world_size"), py::arg("gpu_id"), - py::arg("name"), py::arg("plan")) + .def(py::init([](int device_id, uintptr_t stream, + const std::string &name, const std::string &plan) { + return new ark::Executor( + device_id, reinterpret_cast(stream), name, plan); + })) + .def("device_id", &ark::Executor::device_id) + .def("stream", + [](ark::Executor *self) { + return reinterpret_cast(self->stream()); + }) + .def("plan", &ark::Executor::plan) .def("compile", &ark::Executor::compile) .def("launch", &ark::Executor::launch, py::arg("max_spin_count") = -1) .def("run", &ark::Executor::run, py::arg("iter")) @@ -38,6 +61,22 @@ void register_executor(py::module &m) { .def("barrier", &ark::Executor::barrier) .def("destroy", &ark::Executor::destroy) .def("destroyed", &ark::Executor::destroyed) - .def("tensor_read", &tensor_read, py::arg("tensor"), py::arg("data")) - .def("tensor_write", &tensor_write, py::arg("tensor"), py::arg("data")); + .def("tensor_read", + py::overload_cast(&tensor_read), + py::arg("tensor"), py::arg("data"), py::arg("stream")) + .def("tensor_read", + py::overload_cast(&tensor_read), + py::arg("tensor"), py::arg("address"), py::arg("bytes"), + py::arg("stream"), py::arg("is_d2d")) + .def("tensor_write", + py::overload_cast(&tensor_write), + py::arg("tensor"), py::arg("data"), py::arg("stream")) + .def("tensor_write", + py::overload_cast(&tensor_write), + py::arg("tensor"), py::arg("address"), py::arg("bytes"), + py::arg("stream"), py::arg("is_d2d")); } From 215469044ae49a4a453f576b2a396a5c96992aec Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 10 Jul 2024 23:53:32 +0000 Subject: [PATCH 30/54] Update lint workflow --- .github/workflows/lint.yml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 758eaf564..a918dcede 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,11 +13,8 @@ jobs: - name: Check out Git repository uses: actions/checkout@v4 - - name: Install ClangFormat - run: sudo apt-get install -y clang-format - - - name: Run clang-format - run: clang-format -style=file -Werror --dry-run `find ark python examples -type f -name *.h -o -name *.hpp -o -name *.c -o -name *.cc -o -name *.cpp -o -name *.cu` + - name: Run git-clang-format + run: git-clang-format --style=file --diff - name: Set up Python uses: actions/setup-python@v4 From 705f9f86d8bf8b70005a03fd875e8cc080c99af1 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 11 Jul 2024 00:02:45 +0000 Subject: [PATCH 31/54] Optimize operators --- ark/include/kernels/common/broadcast.h | 4 +- ark/include/kernels/common/sync.h | 12 ++---- ark/include/kernels/reduce.h | 59 ++++++++++++++++++-------- ark/ops/ops_broadcast.cpp | 3 +- ark/ops/ops_matmul.cpp | 32 +++++++++----- 5 files changed, 69 insertions(+), 41 deletions(-) diff --git a/ark/include/kernels/common/broadcast.h b/ark/include/kernels/common/broadcast.h index 97b12e004..858938613 100644 --- a/ark/include/kernels/common/broadcast.h +++ b/ark/include/kernels/common/broadcast.h @@ -186,9 +186,9 @@ struct Broadcast2Intrinsic { (BroadcastInput0 && BroadcastInput1) ? OutNelemPerThread : BroadcastInput0 - ? math::gcd::value + ? math::gcd::value : BroadcastInput1 - ? math::gcd::value + ? math::gcd::value : math::gcd::value>::value; diff --git a/ark/include/kernels/common/sync.h b/ark/include/kernels/common/sync.h index 85f7639c9..cf22e357d 100644 --- a/ark/include/kernels/common/sync.h +++ b/ark/include/kernels/common/sync.h @@ -106,25 +106,21 @@ DEVICE void sync_warps() { static_assert(Arch::ThreadsPerWarp == 64, ""); if constexpr (NumWarps == 1) { __builtin_amdgcn_wave_barrier(); - } else if constexpr (NumWarps == 16) { + } else if constexpr (NumWarps == ARK_WARPS_PER_BLOCK) { __syncthreads(); } else { static_assert(ARK_SMEM_RESERVED_BYTES >= sizeof(sync::WarpGroupState), ""); - int lane_id = threadIdx.x & 63; - if (lane_id == 0) { + if ((threadIdx.x & 63) == 0) { constexpr int MaxOldCnt = NumWarps - 1; - int warp_id = threadIdx.x >> 6; - int group_id = warp_id / NumWarps; + int group_id = (threadIdx.x >> 6) / NumWarps; sync::WarpGroupState *state = reinterpret_cast(_ARK_SMEM); unsigned int tmp = state->is_inc_flag[group_id] ^ 1; if (atomicInc(&state->cnt[group_id], MaxOldCnt) == MaxOldCnt) { state->flag[group_id] = tmp; } else { - while (atomicAdd(&state->flag[group_id], 0) != tmp) - __builtin_amdgcn_s_sleep(1); - __asm__ __volatile__("s_wakeup"); + while (atomicAdd(&state->flag[group_id], 0) != tmp); } state->is_inc_flag[group_id] = tmp; } diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index 30c8b7831..2dd79d2c3 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -53,7 +53,7 @@ DEVICE bf16 warpReduce(bf16 val) { template DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) { val = warpReduce(val); - if (LanesNum > Arch::ThreadsPerWarp) { + if constexpr (LanesNum > Arch::ThreadsPerWarp) { ReduceSharedStorage *shared = UnitOp::template shared_memory>( smem_per_warp); @@ -351,12 +351,19 @@ struct WwiseReduce { /// @param in Input tensor. /// @param uop_idx Index of the unit operator. template - static DEVICE void runW(DataType *out, DataType *in, int uop_idx, - int smem_per_warp) { + static DEVICE void run(DataType *out, DataType *in, int uop_idx, + int smem_per_warp) { using ShapeChecker = ReduceShapeChecker; + constexpr int InConsecBytes = sizeof(DataType) * InShape::W; constexpr int NelemPerThread = - DefaultNelemPerThread::value; + (InConsecBytes % 16 == 0) + ? 16 / sizeof(DataType) + : (InConsecBytes % 8 == 0) + ? 8 / sizeof(DataType) + : (InConsecBytes % 4 == 0) + ? 4 / sizeof(DataType) + : (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) : 1; constexpr int NonReduceDimLength = UnitOutDims::N * UnitOutDims::C * UnitOutDims::H; @@ -397,22 +404,38 @@ struct WwiseReduce { &in[idx_in]); } - DataType finalSum; - ReduceType::template identity<1>(&finalSum); + static_assert(math::is_pow2::value, + "NelemPerThread must be power of 2"); + if constexpr (NelemPerThread > 8) { #pragma unroll - for (int i = 0; i < NelemPerThread; ++i) { - ReduceType::template reduce<1>(&finalSum, &finalSum, &reduced[i]); + for (int i = 8; i < NelemPerThread; i += 8) { + ReduceType::template reduce<8>(&reduced[0], &reduced[0], &reduced[i]); + } + ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + } else if constexpr (NelemPerThread == 8) { + ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + } else if constexpr (NelemPerThread == 4) { + ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + } else if constexpr (NelemPerThread == 2) { + ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); } - UnitOp::sync_threads(); + if constexpr (InShape::W % ThreadsPerRow != 0) { + UnitOp::sync_threads(); + } // final reduction on shared memory using warp shuffle. - finalSum = warpsReduce( - finalSum, tid, smem_per_warp); + reduced[0] = warpsReduce( + reduced[0], tid, smem_per_warp); // write the result to output. if (tid % ThreadsPerRow == 0) { - ReduceType::template postReduce<1>(&out[idx_out], &finalSum, + ReduceType::template postReduce<1>(&out[idx_out], &reduced[0], InShape::W); } @@ -450,8 +473,8 @@ template ::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeSum, Axis>::run(out, in, uop_idx, + smem_per_warp); } template ::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMean, Axis>::run(out, in, uop_idx, + smem_per_warp); } template ::runW(out, in, uop_idx, - smem_per_warp); + SmemBytes, ReduceTypeMax, Axis>::run(out, in, uop_idx, + smem_per_warp); } } // namespace ark diff --git a/ark/ops/ops_broadcast.cpp b/ark/ops/ops_broadcast.cpp index 3985a0500..e5559fc32 100644 --- a/ark/ops/ops_broadcast.cpp +++ b/ark/ops/ops_broadcast.cpp @@ -27,8 +27,7 @@ ModelOpBroadcast1::ModelOpBroadcast1(const std::string &type_name, std::string ModelOpBroadcast1::impl_name(const Json &config) const { check_fields_config(config, {"NumWarps", "Tile"}); int num_warps = config.at("NumWarps"); - auto &tile_shape = config.at("Tile"); - Dims unit_out_dims{tile_shape[0], tile_shape[1]}; + Dims unit_out_dims(config.at("Tile").get>()); return function_name_string( pascal_to_snake(type()->type_name()), diff --git a/ark/ops/ops_matmul.cpp b/ark/ops/ops_matmul.cpp index b259f99c8..a24b95d72 100644 --- a/ark/ops/ops_matmul.cpp +++ b/ark/ops/ops_matmul.cpp @@ -189,45 +189,55 @@ std::vector ModelOpMatmul::impl_args([ } static const Json get_default_config(const ArchRef arch, - const ModelDataType &data_type) { + const ModelDataType &data_type, + const Dims &mnk) { + if (data_type != FP32.ref() && data_type != FP16.ref() && + data_type != BF16.ref()) { + ERR(InvalidUsageError, + "Unsupported data type: ", data_type->type_name()); + } + if (!arch->belongs_to(ARCH_CUDA) && !arch->belongs_to(ARCH_ROCM)) { + ERR(InvalidUsageError, "Unsupported architecture: ", arch->name()); + } + DimType tm = (mnk[0] > mnk[1]) ? 256 : 128; + DimType tn = (mnk[0] > mnk[1]) ? 128 : 256; if (arch->belongs_to(ARCH_CUDA_80) && data_type == FP32.ref()) { return {{"NumWarps", 8}, {"SramBytes", 147456}, - {"TileShapeMNK", {128, 256, 32}}}; + {"TileShapeMNK", {tm, tn, 32}}}; } else if (arch->belongs_to(ARCH_CUDA_80) && data_type == FP16.ref()) { return {{"NumWarps", 8}, {"SramBytes", 147456}, - {"TileShapeMNK", {128, 256, 64}}}; + {"TileShapeMNK", {tm, tn, 64}}}; } else if (arch->belongs_to(ARCH_CUDA_80) && data_type == BF16.ref()) { return {{"NumWarps", 8}, {"SramBytes", 147456}, - {"TileShapeMNK", {128, 256, 64}}}; + {"TileShapeMNK", {tm, tn, 64}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP32.ref()) { return {{"NumWarps", 4}, {"SramBytes", 24672}, - {"TileShapeMNK", {128, 256, 16}}}; + {"TileShapeMNK", {tm, tn, 16}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP16.ref()) { return {{"NumWarps", 4}, {"SramBytes", 24672}, - {"TileShapeMNK", {128, 256, 32}}}; + {"TileShapeMNK", {tm, tn, 32}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == BF16.ref()) { return {{"NumWarps", 4}, - {"SramBytes", 24672}, - {"TileShapeMNK", {128, 256, 32}}}; + {"SramBytes", 24624}, + {"TileShapeMNK", {tm, tn, 32}}}; } - ERR(InvalidUsageError, "Unsupported arch and data type: ", arch->name(), - " and ", data_type->type_name()); + ERR(InternalError, "Unexpected error"); return {}; } Json ModelOpMatmul::default_config(const ArchRef arch) const { auto result = result_tensors_[0]; - Json config = get_default_config(arch, result->data_type()); check_fields_args(args_, {"TransposeInput", "TransposeOther"}); Dims mnk = calc_problem_size(read_tensors_[0]->padded_shape(), read_tensors_[1]->padded_shape(), args_.at("TransposeInput").value(), args_.at("TransposeOther").value()); + Json config = get_default_config(arch, result->data_type(), mnk); size_t tile_x = config.at("TileShapeMNK")[0]; size_t tile_y = config.at("TileShapeMNK")[1]; if (mnk[0] % tile_x != 0 || mnk[1] % tile_y != 0) { From a3114e45eea5d8c7929915e7ca1b1f9cc6ef1591 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 11 Jul 2024 00:04:40 +0000 Subject: [PATCH 32/54] fix --- ark/error.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ark/error.hpp b/ark/error.hpp index e08acd975..5ad21824b 100644 --- a/ark/error.hpp +++ b/ark/error.hpp @@ -20,6 +20,7 @@ class BaseError : public std::runtime_error { _name(const std::string &msg) : BaseError(msg) {} \ }; +REGISTER_ERROR_TYPE(InternalError) REGISTER_ERROR_TYPE(InvalidUsageError) REGISTER_ERROR_TYPE(NotFoundError) REGISTER_ERROR_TYPE(ModelError) From 6116424e2a692a3cec2eb749565f1ae03637e5e6 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 11 Jul 2024 00:28:47 +0000 Subject: [PATCH 33/54] delete an unused file --- plan_gpu0.json | 2423 ------------------------------------------------ 1 file changed, 2423 deletions(-) delete mode 100644 plan_gpu0.json diff --git a/plan_gpu0.json b/plan_gpu0.json deleted file mode 100644 index cad05f774..000000000 --- a/plan_gpu0.json +++ /dev/null @@ -1,2423 +0,0 @@ -{ - "Rank": 0, - "WorldSize": 1, - "Architecture": "ROCM_942", - "NumProcessors": 304, - "NumWarpsPerProcessor": 4, - "TaskInfos": [ - { - "Id": 0, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":0,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":0,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":6,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":7,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [256,128,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 1, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Rope", - "Name": "rope", - "IsVirtual": false, - "ReadTensors": [ - {"Id":12,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":6,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":5,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":15,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":16,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,1,128], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 2, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Transpose", - "Name": "transpose", - "IsVirtual": false, - "ReadTensors": [ - {"Id":16,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":9,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":19,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,128], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 3, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":1,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":1,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":8,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":9,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [256,128,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 4, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Rope", - "Name": "rope_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":13,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":7,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":5,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":5,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":17,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,1,128], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 5, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Transpose", - "Name": "transpose_2", - "IsVirtual": false, - "ReadTensors": [ - {"Id":18,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":10,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":23,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":24,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,128], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 6, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_2", - "IsVirtual": false, - "ReadTensors": [ - {"Id":4,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":4,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":2,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":2,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":10,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":11,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [256,128,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 7, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Transpose", - "Name": "transpose_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":14,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":8,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":21,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":22,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,128], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 8, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_3", - "IsVirtual": false, - "ReadTensors": [ - {"Id":20,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":11,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":24,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":13,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":25,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":26,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [256,128,32], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 9, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "ScalarMul", - "Name": "mul", - "IsVirtual": false, - "ReadTensors": [ - {"Id":26,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":14,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":27,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Factor": {"FLOAT":0.0883883461356163} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [256,128], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 10, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "ReduceMax", - "Name": "reduce_max", - "IsVirtual": false, - "ReadTensors": [ - {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":29,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":30,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Axis": {"INT":3}, - "KeepDim": {"BOOL":true} - }, - "Config": { - "NumWarps": 1, - "ImplType": "WarpWise", - "SramBytes": 0, - "NumTasks": 65536 - } - } - ] - }, - { - "Id": 11, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Sub", - "Name": "sub", - "IsVirtual": false, - "ReadTensors": [ - {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":30,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":16,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":28,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 12, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Exp", - "Name": "exp", - "IsVirtual": false, - "ReadTensors": [ - {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":31,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 13, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "ReduceSum", - "Name": "reduce_sum", - "IsVirtual": false, - "ReadTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":33,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":34,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Axis": {"INT":3}, - "KeepDim": {"BOOL":true} - }, - "Config": { - "NumWarps": 1, - "ImplType": "WarpWise", - "SramBytes": 0, - "NumTasks": 65536 - } - } - ] - }, - { - "Id": 14, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Div", - "Name": "div", - "IsVirtual": false, - "ReadTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":34,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":17,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":32,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":35,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 15, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_4", - "IsVirtual": false, - "ReadTensors": [ - {"Id":35,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":15,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":22,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":12,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":36,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":false} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [256,128,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 16, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Transpose", - "Name": "transpose_3", - "IsVirtual": false, - "ReadTensors": [ - {"Id":37,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":18,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":38,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":39,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [8,128], - "NumTasks": 8192 - } - } - ] - }, - { - "Id": 17, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_5", - "IsVirtual": false, - "ReadTensors": [ - {"Id":40,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":19,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":3,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":3,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":41,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":42,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":20,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 18, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Cast", - "Name": "cast", - "IsVirtual": false, - "ReadTensors": [ - {"Id":52,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":54,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 19, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":56,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":57,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 20, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "ReduceMean", - "Name": "reduce_mean", - "IsVirtual": false, - "ReadTensors": [ - {"Id":57,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":33,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":58,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":59,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Axis": {"INT":2}, - "KeepDim": {"BOOL":true} - }, - "Config": { - "NumWarps": 1, - "ImplType": "WarpWise", - "SramBytes": 0, - "NumTasks": 2048 - } - } - ] - }, - { - "Id": 21, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Rsqrt", - "Name": "rsqrt", - "IsVirtual": false, - "ReadTensors": [ - {"Id":59,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":34,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":60,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":61,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [64,1], - "NumTasks": 32 - } - } - ] - }, - { - "Id": 22, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_2", - "IsVirtual": false, - "ReadTensors": [ - {"Id":55,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":32,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":61,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":35,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":62,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 23, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_3", - "IsVirtual": false, - "ReadTensors": [ - {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":50,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":28,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":63,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 24, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Cast", - "Name": "cast_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":64,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":36,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":65,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 25, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_6", - "IsVirtual": false, - "ReadTensors": [ - {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":43,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":21,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":67,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":68,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 26, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Rope", - "Name": "rope_2", - "IsVirtual": false, - "ReadTensors": [ - {"Id":73,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":38,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":53,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":76,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":77,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 - } - } - ] - }, - { - "Id": 27, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Transpose", - "Name": "transpose_4", - "IsVirtual": false, - "ReadTensors": [ - {"Id":77,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":41,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":80,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":81,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [8,128], - "NumTasks": 8192 - } - } - ] - }, - { - "Id": 28, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_7", - "IsVirtual": false, - "ReadTensors": [ - {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":44,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":22,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":69,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":70,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 29, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Rope", - "Name": "rope_3", - "IsVirtual": false, - "ReadTensors": [ - {"Id":74,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":39,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":53,"DataType":"FP16","Shape":[1,2048,1,128],"Strides":[1,2048,1,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,1,128],"Buffer":{"Id":31,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":78,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":79,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1,64], - "NumTasks": 131072 - } - } - ] - }, - { - "Id": 30, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Transpose", - "Name": "transpose_6", - "IsVirtual": false, - "ReadTensors": [ - {"Id":79,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":42,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":84,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":85,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,3,1]} - }, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [8,8], - "NumTasks": 131072 - } - } - ] - }, - { - "Id": 31, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_8", - "IsVirtual": false, - "ReadTensors": [ - {"Id":66,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":37,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":45,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":23,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":71,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":72,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 32, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Transpose", - "Name": "transpose_5", - "IsVirtual": false, - "ReadTensors": [ - {"Id":75,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":40,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":82,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":83,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [8,128], - "NumTasks": 8192 - } - } - ] - }, - { - "Id": 33, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_9", - "IsVirtual": false, - "ReadTensors": [ - {"Id":81,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":43,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":85,"DataType":"FP16","Shape":[1,32,128,2048],"Strides":[1,32,128,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,128,2048],"Buffer":{"Id":45,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":86,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":87,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":false} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 34, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "ScalarMul", - "Name": "mul_4", - "IsVirtual": false, - "ReadTensors": [ - {"Id":87,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":46,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":88,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Factor": {"FLOAT":0.0883883461356163} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 35, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "ReduceMax", - "Name": "reduce_max_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":90,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":91,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Axis": {"INT":3}, - "KeepDim": {"BOOL":true} - }, - "Config": { - "NumWarps": 1, - "ImplType": "WarpWise", - "SramBytes": 0, - "NumTasks": 65536 - } - } - ] - }, - { - "Id": 36, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Sub", - "Name": "sub_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":91,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":48,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":89,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 37, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Exp", - "Name": "exp_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":92,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 38, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "ReduceSum", - "Name": "reduce_sum_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":94,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":95,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Axis": {"INT":3}, - "KeepDim": {"BOOL":true} - }, - "Config": { - "NumWarps": 1, - "ImplType": "WarpWise", - "SramBytes": 0, - "NumTasks": 65536 - } - } - ] - }, - { - "Id": 39, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Div", - "Name": "div_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":95,"DataType":"FP16","Shape":[1,32,2048,1],"Strides":[1,32,2048,1],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,1],"Buffer":{"Id":49,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":93,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":96,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 4096 - } - } - ] - }, - { - "Id": 40, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_10", - "IsVirtual": false, - "ReadTensors": [ - {"Id":96,"DataType":"FP16","Shape":[1,32,2048,2048],"Strides":[1,32,2048,2048],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,2048],"Buffer":{"Id":47,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":83,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":44,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":97,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":98,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":false} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [256,128,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 41, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Transpose", - "Name": "transpose_7", - "IsVirtual": false, - "ReadTensors": [ - {"Id":98,"DataType":"FP16","Shape":[1,32,2048,128],"Strides":[1,32,2048,128],"Offsets":[0,0,0,0],"PaddedShape":[1,32,2048,128],"Buffer":{"Id":50,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":99,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":100,"DataType":"FP16","Shape":[1,2048,32,128],"Strides":[1,2048,32,128],"Offsets":[0,0,0,0],"PaddedShape":[1,2048,32,128],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Permutation": {"DIMS":[0,2,1,3]} - }, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [8,128], - "NumTasks": 8192 - } - } - ] - }, - { - "Id": 42, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_11", - "IsVirtual": false, - "ReadTensors": [ - {"Id":101,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":51,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":46,"DataType":"FP16","Shape":[4096,4096],"Strides":[4096,4096],"Offsets":[0,0],"PaddedShape":[4096,4096],"Buffer":{"Id":24,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":102,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":103,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 43, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Add", - "Name": "add", - "IsVirtual": false, - "ReadTensors": [ - {"Id":52,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":30,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":103,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":52,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":104,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 44, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Cast", - "Name": "cast_2", - "IsVirtual": false, - "ReadTensors": [ - {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":106,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 45, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_5", - "IsVirtual": false, - "ReadTensors": [ - {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":108,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":109,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 46, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "ReduceMean", - "Name": "reduce_mean_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":109,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":55,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":110,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":111,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "Axis": {"INT":2}, - "KeepDim": {"BOOL":true} - }, - "Config": { - "NumWarps": 1, - "ImplType": "WarpWise", - "SramBytes": 0, - "NumTasks": 2048 - } - } - ] - }, - { - "Id": 47, - "NumWarps": 1, - "SramBytes": 0, - "Ops": [ - { - "Type": "Rsqrt", - "Name": "rsqrt_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":111,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":56,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":112,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":113,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 1, - "SramBytes": 0, - "Tile": [64,1], - "NumTasks": 32 - } - } - ] - }, - { - "Id": 48, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_6", - "IsVirtual": false, - "ReadTensors": [ - {"Id":107,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":54,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":113,"DataType":"FP32","Shape":[1,2048,1],"Strides":[1,2048,1],"Offsets":[0,0,0],"PaddedShape":[1,2048,1],"Buffer":{"Id":57,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":114,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 49, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_7", - "IsVirtual": false, - "ReadTensors": [ - {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":51,"DataType":"FP32","Shape":[1,1,4096],"Strides":[1,1,4096],"Offsets":[0,0,0],"PaddedShape":[1,1,4096],"Buffer":{"Id":29,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":115,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":116,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 50, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Cast", - "Name": "cast_3", - "IsVirtual": false, - "ReadTensors": [ - {"Id":116,"DataType":"FP32","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":58,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":117,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 51, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_12", - "IsVirtual": false, - "ReadTensors": [ - {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":47,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":25,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":119,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 688 - } - } - ] - }, - { - "Id": 52, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Sigmoid", - "Name": "sigmoid", - "IsVirtual": false, - "ReadTensors": [ - {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":121,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":122,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 688 - } - } - ] - }, - { - "Id": 53, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_8", - "IsVirtual": false, - "ReadTensors": [ - {"Id":120,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":60,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":122,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":61,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":123,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":124,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 688 - } - } - ] - }, - { - "Id": 54, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_13", - "IsVirtual": false, - "ReadTensors": [ - {"Id":118,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":59,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":49,"DataType":"FP16","Shape":[11008,4096],"Strides":[11008,4096],"Offsets":[0,0],"PaddedShape":[11008,4096],"Buffer":{"Id":27,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":125,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":126,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 688 - } - } - ] - }, - { - "Id": 55, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Mul", - "Name": "mul_9", - "IsVirtual": false, - "ReadTensors": [ - {"Id":124,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":62,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":126,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":63,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":127,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":128,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 688 - } - } - ] - }, - { - "Id": 56, - "NumWarps": 4, - "SramBytes": 24672, - "Ops": [ - { - "Type": "Matmul", - "Name": "matmul_14", - "IsVirtual": false, - "ReadTensors": [ - {"Id":128,"DataType":"FP16","Shape":[1,2048,11008],"Strides":[1,2048,11008],"Offsets":[0,0,0],"PaddedShape":[1,2048,11008],"Buffer":{"Id":64,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":48,"DataType":"FP16","Shape":[4096,11008],"Strides":[4096,11008],"Offsets":[0,0],"PaddedShape":[4096,11008],"Buffer":{"Id":26,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":129,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":130,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": { - "TransposeInput": {"BOOL":false}, - "TransposeOther": {"BOOL":true} - }, - "Config": { - "NumWarps": 4, - "SramBytes": 24672, - "TileShapeMNK": [128,256,32], - "NumTasks": 256 - } - } - ] - }, - { - "Id": 57, - "NumWarps": 4, - "SramBytes": 0, - "Ops": [ - { - "Type": "Add", - "Name": "add_1", - "IsVirtual": false, - "ReadTensors": [ - {"Id":105,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":53,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}}, - {"Id":130,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":65,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "WriteTensors": [ - {"Id":131,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":66,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "ResultTensors": [ - {"Id":132,"DataType":"FP16","Shape":[1,2048,4096],"Strides":[1,2048,4096],"Offsets":[0,0,0],"PaddedShape":[1,2048,4096],"Buffer":{"Id":66,"Rank":-1,"SendTags":[],"RecvTags":[],"IsExternal":false}} - ], - "Args": {}, - "Config": { - "NumWarps": 4, - "SramBytes": 0, - "Tile": [128,256], - "NumTasks": 256 - } - } - ] - } - ], - "ProcessorGroups": [ - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,86], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":0,"TaskRange":[0,256],"Granularity":1}, - {"TaskId":1,"TaskRange":[0,256],"Granularity":1}, - {"TaskId":2,"TaskRange":[0,256],"Granularity":1} - ] - }, - { - "ProcessorRange": [86,172], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":3,"TaskRange":[0,256],"Granularity":1}, - {"TaskId":4,"TaskRange":[0,256],"Granularity":1}, - {"TaskId":5,"TaskRange":[0,256],"Granularity":1} - ] - }, - { - "ProcessorRange": [172,258], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":6,"TaskRange":[0,256],"Granularity":1}, - {"TaskId":7,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":8,"TaskRange":[0,4096],"Granularity":1}, - {"TaskId":9,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":10,"TaskRange":[0,65536],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":11,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":12,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":13,"TaskRange":[0,65536],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":14,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":15,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":16,"TaskRange":[0,8192],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":17,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":18,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":19,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":20,"TaskRange":[0,2048],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,32], - "ResourceGroups": [ - { - "ProcessorRange": [0,32], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":21,"TaskRange":[0,32],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":22,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":23,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":24,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":25,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":26,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":27,"TaskRange":[0,8192],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":28,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":29,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":30,"TaskRange":[0,131072],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":31,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":32,"TaskRange":[0,8192],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":33,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":34,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":35,"TaskRange":[0,65536],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":36,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":37,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":38,"TaskRange":[0,65536],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":39,"TaskRange":[0,4096],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":40,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":41,"TaskRange":[0,8192],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":42,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":43,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":44,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":45,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":46,"TaskRange":[0,2048],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,32], - "ResourceGroups": [ - { - "ProcessorRange": [0,32], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":47,"TaskRange":[0,32],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":48,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":49,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":50,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":51,"TaskRange":[0,688],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":52,"TaskRange":[0,688],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":53,"TaskRange":[0,688],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":54,"TaskRange":[0,688],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":55,"TaskRange":[0,688],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,256], - "ResourceGroups": [ - { - "ProcessorRange": [0,256], - "WarpRange": [0,4], - "SramRange": [0,24672], - "TaskGroups": [ - {"TaskId":56,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - }, - { - "ProcessorRange": [0,304], - "ResourceGroups": [ - { - "ProcessorRange": [0,304], - "WarpRange": [0,4], - "SramRange": [0,0], - "TaskGroups": [ - {"TaskId":57,"TaskRange":[0,256],"Granularity":1} - ] - } - ] - } - ] -} From 67e3b2601f00997d6debe8f9dd3e7c633ceee08b Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Thu, 11 Jul 2024 01:44:53 +0000 Subject: [PATCH 34/54] update test --- ark/ops/ops_scalar_test.cpp | 43 +++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/ark/ops/ops_scalar_test.cpp b/ark/ops/ops_scalar_test.cpp index 9e9e635b8..6ae0022f0 100644 --- a/ark/ops/ops_scalar_test.cpp +++ b/ark/ops/ops_scalar_test.cpp @@ -263,31 +263,28 @@ ark::unittest::State test_scalar_mul_fp16_offset() { { ark::Model m; ark::Tensor buf = m.tensor({1024}, ark::FP16); - ark::Tensor tns = m.refer(buf, {2}, {1024}, {3}); - ark::Tensor out = m.mul(tns, 2, tns); - - ark::DefaultExecutor exe(m); - exe.compile(); + ark::Tensor tns = m.refer(buf, {2}, {1024}, {6}); + ark::Tensor doubled = m.mul(tns, 2, tns); + ark::Tensor out = m.identity(buf, {doubled}); std::vector data(1024, ark::half_t(2)); - exe.tensor_write(buf, data); - - exe.launch(); - exe.run(1); - exe.stop(); - - data.clear(); - data.resize(1024); - - exe.tensor_read(buf, data); - - for (size_t i = 0; i < data.size(); ++i) { - if (i == 3 || i == 4) { - UNITTEST_EQ(data[i], 4); - } else { - UNITTEST_EQ(data[i], 2); - } - } + auto result = ark::op_test( + "scalar_mul_fp16_offset", m, {buf}, {out}, + [](std::vector &outputs, const std::vector &, + const std::vector &, const std::vector &, + int) { + ark::half_t *out = static_cast(outputs[0]); + for (size_t i = 0; i < 1024; ++i) { + if (i == 6 || i == 7) { + out[i] = 4; + } else { + out[i] = 2; + } + } + }, + {data.data()}); + UNITTEST_LOG(result); + UNITTEST_EQ(result.max_diff[0], 0.0f); } return ark::unittest::SUCCESS; } From e1f178bd3c7bbb0023e1ffc3eceee72564116d10 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Fri, 12 Jul 2024 04:37:51 +0000 Subject: [PATCH 35/54] fix merge & updates --- ark/api/executor.cpp | 3 +-- python/ark/runtime.py | 8 ++++---- python/ark/tensor.py | 17 ++++++++++------- python/executor_py.cpp | 2 +- python/unittest/unittest_common.py | 22 ++++++++++++++++++++++ 5 files changed, 38 insertions(+), 14 deletions(-) create mode 100644 python/unittest/unittest_common.py diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 1af298e89..ad6cb8550 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -233,7 +233,6 @@ void Executor::Impl::init(const PlanJson &plan_json) { if (world_size_ > 1) { init_communicator(); } -} auto gpu_manager = GpuManager::get_instance(device_id_); @@ -384,7 +383,7 @@ std::map Executor::Impl::init_buffers(const Json &plan_json) { continue; } if (buf_info->buffer->is_external()) { - if (buf_info->buffer->device_id() != gpu_id_) { + if (buf_info->buffer->device_id() != device_id_) { ERR(InvalidUsageError, "PyTorch tensor and model execution are on different GPUs"); } diff --git a/python/ark/runtime.py b/python/ark/runtime.py index 93acb6bf8..1e56fe1ca 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -142,7 +142,7 @@ def launch( initialized. The executor will compile the cuda kernels and launch the ARK runtime. """ if self.launched(): - logging.warn( + logging.warning( f"Runtime {self.runtime_id} is already launched, skip launching" ) return @@ -153,7 +153,7 @@ def launch( if self.state == Runtime.State.Init: if self.executor is not None: if not self.executor.destroyed(): - logging.warn( + logging.warning( f"Runtime {self.runtime_id}, has already been launched. Destroying the old executor" ) self.executor.destroy() @@ -184,7 +184,7 @@ def wait(self): Wait for the kernel to finish. """ if self.state != Runtime.State.Running: - logging.warn( + logging.warning( f"ARK runtime {self.runtime_id} is not running, skip waiting" ) return @@ -197,7 +197,7 @@ def stop(self) -> float: Once this is called, we need to call `launch()` again to run the model again. """ if not self.launched(): - logging.warn( + logging.warning( f"ARK runtime {self.runtime_id} is never launched, skip stopping" ) return diff --git a/python/ark/tensor.py b/python/ark/tensor.py index e377cf852..335020769 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -103,7 +103,7 @@ def to_numpy( return ndarray def to_torch( - self, tensor: torch.Tensor = None, runtime_id: int = -1 + self, tensor: torch.Tensor = None, stream: int = 0 ) -> torch.Tensor: """ """ if _no_torch: @@ -116,21 +116,24 @@ def to_torch( ) torch_type = self.dtype().to_torch() if tensor is None: - dev_name = f"cuda:{rt.executor.gpu_id()}" + dev_name = f"cuda:{rt.executor.device_id()}" tensor = torch.zeros( self.shape(), dtype=torch_type, device=torch.device(dev_name) ) - elif tensor.shape != self.shape(): - raise ValueError("torch tensor shape does not match the tensor") + elif list(tensor.shape) != self.shape(): + raise ValueError(f"torch tensor shape {list(tensor.shape)} " + f"does not match the tensor {self.shape()}") elif tensor.dtype != torch_type: - raise ValueError("torch tensor dtype does not match the tensor") + raise ValueError(f"torch tensor dtype {tensor.dtype} " + f"does not match the tensor {torch_type}") elif not tensor.is_contiguous(): raise ValueError("torch tensor is not contiguous in memory") elif tensor.numel() != self.nelems(): - raise ValueError("torch tensor size does not match the tensor") + raise ValueError(f"torch tensor size {tensor.numel()} " + f"does not match the tensor {self.nelems()}") tensor_bytes = self.nelems() * self.dtype().element_size() rt.executor.tensor_read( - self._tensor, tensor.data_ptr(), tensor_bytes, True + self._tensor, tensor.data_ptr(), tensor_bytes, stream, True ) return tensor diff --git a/python/executor_py.cpp b/python/executor_py.cpp index fffbb2c30..8455fa585 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -93,7 +93,7 @@ static DLManagedTensor *to_dlpack(ark::Executor &exe, tensor.offsets().is_no_dim() ? 0 : tensor.offsets().vector()[0]; dl_tensor.byte_offset = offset_in_elements * tensor.data_type().bytes(); dl_tensor.device.device_type = get_device_type(); - dl_tensor.device.device_id = static_cast(exe.gpu_id()); + dl_tensor.device.device_id = static_cast(exe.device_id()); dl_tensor.ndim = static_cast(tensor.shape().ndims()); dl_tensor.dtype = get_dl_dtype(tensor.data_type()); diff --git a/python/unittest/unittest_common.py b/python/unittest/unittest_common.py new file mode 100644 index 000000000..9548410b5 --- /dev/null +++ b/python/unittest/unittest_common.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest +import ark + + +def pytest_ark(need_torch: bool = False): + """ + Decorator for ARK unit tests. + """ + def decorator(test_func): + if need_torch: + try: + import torch + except ImportError: + return pytest.mark.skip(reason="torch is not installed")(test_func) + def wrapper(*args, **kwargs): + ark.init() + test_func(*args, **kwargs) + return wrapper + return decorator From ce1959ecb5fb064b4e653b3cad7cf3dcba63a9d7 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Fri, 12 Jul 2024 06:49:30 +0000 Subject: [PATCH 36/54] Add `loop_mode` argument --- ark/api/executor.cpp | 116 ++++++++++++++------- ark/api/planner.cpp | 2 +- ark/codegen.cpp | 2 +- ark/gpu/{gpu.h => gpu.hpp} | 7 +- ark/gpu/gpu_compile.cpp | 4 +- ark/gpu/{gpu_compile.h => gpu_compile.hpp} | 6 +- ark/gpu/gpu_event.cpp | 6 +- ark/gpu/{gpu_event.h => gpu_event.hpp} | 8 +- ark/gpu/gpu_kernel.cpp | 33 ++---- ark/gpu/{gpu_kernel.h => gpu_kernel.hpp} | 19 ++-- ark/gpu/gpu_kernel_test.cpp | 8 +- ark/gpu/{gpu_logging.h => gpu_logging.hpp} | 8 +- ark/gpu/gpu_manager.cpp | 4 +- ark/gpu/{gpu_manager.h => gpu_manager.hpp} | 14 +-- ark/gpu/gpu_memory.cpp | 8 +- ark/gpu/{gpu_memory.h => gpu_memory.hpp} | 10 +- ark/gpu/gpu_stream.cpp | 6 +- ark/gpu/{gpu_stream.h => gpu_stream.hpp} | 8 +- ark/include/ark/executor.hpp | 4 +- ark/include/kernels/kernel_template.in | 17 ++- ark/ops/ops_matmul_test.cpp | 2 +- ark/ops/ops_test_common.cpp | 2 +- python/ark/runtime.py | 4 +- python/executor_py.cpp | 8 +- 24 files changed, 173 insertions(+), 133 deletions(-) rename ark/gpu/{gpu.h => gpu.hpp} (98%) rename ark/gpu/{gpu_compile.h => gpu_compile.hpp} (78%) rename ark/gpu/{gpu_event.h => gpu_event.hpp} (84%) rename ark/gpu/{gpu_kernel.h => gpu_kernel.hpp} (68%) rename ark/gpu/{gpu_logging.h => gpu_logging.hpp} (92%) rename ark/gpu/{gpu_manager.h => gpu_manager.hpp} (88%) rename ark/gpu/{gpu_memory.h => gpu_memory.hpp} (87%) rename ark/gpu/{gpu_stream.h => gpu_stream.hpp} (79%) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 2f50a4280..91c8e39de 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -14,11 +14,11 @@ #include "codegen.hpp" #include "env.h" #include "file_io.h" -#include "gpu/gpu.h" -#include "gpu/gpu_event.h" -#include "gpu/gpu_kernel.h" -#include "gpu/gpu_logging.h" -#include "gpu/gpu_manager.h" +#include "gpu/gpu.hpp" +#include "gpu/gpu_event.hpp" +#include "gpu/gpu_kernel.hpp" +#include "gpu/gpu_logging.hpp" +#include "gpu/gpu_manager.hpp" #include "logging.h" #include "model/model_buffer.hpp" #include "model/model_data_type.hpp" @@ -140,7 +140,7 @@ static size_t tensor_stride_bytes(const Json &tensor) { class Executor::Impl { public: - Impl(int device_id, Stream stream, const std::string &name); + Impl(int device_id, Stream stream, const std::string &name, bool loop_mode); ~Impl() = default; void init(const PlanJson& plan); @@ -174,6 +174,8 @@ class Executor::Impl { protected: int device_id_; std::string name_; + bool loop_mode_; + gpuStream stream_raw_; int rank_; @@ -203,8 +205,9 @@ class Executor::Impl { rank_to_sm_channels_; }; -Executor::Impl::Impl(int device_id, Stream stream, const std::string &name) - : device_id_(device_id), name_(name) { +Executor::Impl::Impl(int device_id, Stream stream, const std::string &name, + bool loop_mode) + : device_id_(device_id), name_(name), loop_mode_(loop_mode) { if (device_id < 0) { ERR(InvalidUsageError, "Invalid device ID ", device_id); } @@ -251,7 +254,6 @@ void Executor::Impl::init(const PlanJson &plan_json) { int threads_per_block = static_cast( codegen_->num_warps_per_proc() * gpu_manager->info().threads_per_warp); int num_sm = static_cast(codegen_->num_procs()); - int *flag = flag_->ref(); size_t smem_block_total = static_cast(gpu_manager->info().smem_block_total); @@ -260,11 +262,19 @@ void Executor::Impl::init(const PlanJson &plan_json) { init_channels(remote_ranks); } + std::string kernel_name; + if (loop_mode_) { + kernel_name = "ark_loop_kernel"; + } else { + kernel_name = "ark_kernel"; + } + if (!name_.empty()) { + kernel_name += "_" + name_; + } + kernel_ = std::shared_ptr(new GpuKernel( device_id_, codegen_->code(), {threads_per_block, 1, 1}, {num_sm, 1, 1}, - std::max(smem_block_total, size_t(4)), name_, - {std::pair{buffer_->ref(), sizeof(buffer_->ref())}, - std::pair{flag, sizeof(flag)}})); + std::max(smem_block_total, size_t(4)), kernel_name)); } void Executor::Impl::init_communicator() { @@ -669,51 +679,76 @@ void Executor::Impl::launch(int64_t max_spin_count) { proxy_service_->startProxy(); } - // Initialize loop flags. - atomicStoreRelaxed(flag_->ref(), 0); - kernel_->launch(stream_raw_); - timer_end_->record(stream_raw_); + if (loop_mode_) { + // Initialize loop flags. + atomicStoreRelaxed(flag_->ref(), 0); + void *buf_ptr = buffer_->ref(); + void *flag_ptr = flag_->ref(); + std::vector args = {&buf_ptr, &flag_ptr}; + kernel_->launch(stream_raw_, args); + } is_recording_ = true; is_launched_ = true; } void Executor::Impl::run(int iter) { - if (iter > 0) { + if (iter <= 0) return; + if (loop_mode_) { while (atomicLoadRelaxed(flag_->ref()) > 0) { } atomicStoreRelaxed(flag_->ref(), iter); + } else { + void *buf_ptr = buffer_->ref(); + int i = 0; + std::vector args = {&buf_ptr, reinterpret_cast(&i)}; + for (; i < iter; i++) { + kernel_->launch(stream_raw_, args); + } } } void Executor::Impl::wait(int64_t max_spin_count) { int64_t cnt = max_spin_count; - while (atomicLoadRelaxed(flag_->ref()) > 0) { - if (cnt-- > 0) { - continue; - } - // Check if the kernel encountered an error. - gpuError res = gpuStreamQuery(stream_raw_); - if (res == gpuSuccess) { - if (atomicLoadRelaxed(flag_->ref()) > 0) { - LOG(WARN, "Stream is finished but the loop flag is still set."); - break; + if (loop_mode_) { + while (atomicLoadRelaxed(flag_->ref()) > 0) { + if (cnt-- > 0) { + continue; + } + // Check if the kernel encountered an error. + gpuError res = gpuStreamQuery(stream_raw_); + if (res == gpuSuccess) { + if (atomicLoadRelaxed(flag_->ref()) > 0) { + LOG(WARN, + "Stream is finished but the loop flag is still set."); + break; + } else { + LOG(WARN, + "wait() is delayed by a stream query. Regarding " + "timing measurements may be inaccurate."); + break; + } + } else if (res == gpuErrorNotReady) { + cnt = max_spin_count; } else { - LOG(WARN, - "wait() is delayed by a stream query. Regarding " - "timing measurements may be inaccurate."); - break; + GLOG(res); } - } else if (res == gpuErrorNotReady) { - cnt = max_spin_count; - } else { - GLOG(res); } + } else { + if (max_spin_count >= 0) { + LOG(WARN, "max_spin_count is ignored in non-loop mode."); + } + GLOG(gpuStreamSynchronize(stream_raw_)); } } float Executor::Impl::stop(int64_t max_spin_count) { this->wait(max_spin_count); - atomicStoreRelaxed(flag_->ref(), -1); + if (is_recording_) { + timer_end_->record(stream_raw_); + } + if (loop_mode_) { + atomicStoreRelaxed(flag_->ref(), -1); + } GLOG(gpuStreamSynchronize(stream_raw_)); if (is_recording_) { elapsed_msec_ = timer_end_->elapsed_msec(*timer_begin_); @@ -847,8 +882,9 @@ void Executor::Impl::tensor_write(const Tensor tensor, const void *data, } Executor::Executor(int device_id, Stream stream, const std::string &name, - const std::string &plan) - : impl_(std::make_unique(device_id, stream, name)) { + const std::string &plan, bool loop_mode) + : impl_(std::make_unique(device_id, stream, name, + loop_mode)) { auto &plan_path = get_env().enforce_plan_path; if (!plan_path.empty()) { LOG(INFO, "Enforce executor plan path: ", plan_path); @@ -901,10 +937,10 @@ void Executor::tensor_write(const Tensor tensor, const void *data, size_t bytes, DefaultExecutor::DefaultExecutor( const Model &model, int device_id, Stream stream, const std::vector &config_rules, - const std::string &name) + const std::string &name, bool loop_mode) : Executor((device_id < 0) ? (model.rank() % get_env().num_ranks_per_host) : device_id, - stream, name, "") { + stream, name, "", loop_mode) { DefaultPlanner planner(model, impl_->device_id()); for (const auto &rule : config_rules) { planner.install_config_rule(rule); diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index 5c9d09f2e..d7fdbf807 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -6,7 +6,7 @@ #include "ark/model.hpp" #include "env.h" #include "file_io.h" -#include "gpu/gpu_manager.h" +#include "gpu/gpu_manager.hpp" #include "model/model_json.hpp" #include "model/model_node.hpp" #include "model/model_op.hpp" diff --git a/ark/codegen.cpp b/ark/codegen.cpp index cd6206284..02a5d9ad9 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -174,7 +174,7 @@ CodeGenerator::Impl::Impl(const PlanJson &plan, {"@NUM_WARPS_PER_BLOCK@", std::to_string(num_warps_per_proc_)}, {"@DEFINITIONS@", definitions_ss.str()}, {"@BODY@", body_ss.str()}, - {"@NAME@", name_}, + {"@NAME@", (name_.empty() ? "" : "_" + name_)}, }; code_ = replace(template_code, replacements); } diff --git a/ark/gpu/gpu.h b/ark/gpu/gpu.hpp similarity index 98% rename from ark/gpu/gpu.h rename to ark/gpu/gpu.hpp index 2f1eba3ba..531d6c7ee 100644 --- a/ark/gpu/gpu.h +++ b/ark/gpu/gpu.hpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_GPU_H_ -#define ARK_GPU_H_ +#ifndef ARK_GPU_HPP_ +#define ARK_GPU_HPP_ #include @@ -125,6 +125,7 @@ ARK_GPU_DEFINE_CONSTANT_ALIAS(gpuPointerAttributeSyncMemops, // runtime API ARK_GPU_DEFINE_FUNC_ALIAS(gpuGetErrorString, cudaGetErrorString, hipGetErrorString); +ARK_GPU_DEFINE_FUNC_ALIAS(gpuGetLastError, cudaGetLastError, hipGetLastError); ARK_GPU_DEFINE_FUNC_ALIAS(gpuDeviceGetAttribute, cudaDeviceGetAttribute, hipDeviceGetAttribute); ARK_GPU_DEFINE_FUNC_ALIAS(gpuDeviceSynchronize, cudaDeviceSynchronize, @@ -183,4 +184,4 @@ ARK_GPU_DEFINE_FUNC_ALIAS(gpuPointerSetAttribute, cuPointerSetAttribute, } // namespace ark -#endif // ARK_GPU_H_ +#endif // ARK_GPU_HPP_ diff --git a/ark/gpu/gpu_compile.cpp b/ark/gpu/gpu_compile.cpp index b1c078af4..11e172f07 100644 --- a/ark/gpu/gpu_compile.cpp +++ b/ark/gpu/gpu_compile.cpp @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_compile.h" +#include "gpu/gpu_compile.hpp" #include #include @@ -22,7 +22,7 @@ #include "cpu_timer.h" #include "env.h" #include "file_io.h" -#include "gpu/gpu_logging.h" +#include "gpu/gpu_logging.hpp" #include "utils/utils_string.hpp" #define ARK_DEBUG_KERNEL 0 diff --git a/ark/gpu/gpu_compile.h b/ark/gpu/gpu_compile.hpp similarity index 78% rename from ark/gpu/gpu_compile.h rename to ark/gpu/gpu_compile.hpp index 58048e78c..8b9e1a9fd 100644 --- a/ark/gpu/gpu_compile.h +++ b/ark/gpu/gpu_compile.hpp @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_GPU_COMPILE_H_ -#define ARK_GPU_COMPILE_H_ +#ifndef ARK_GPU_COMPILE_HPP_ +#define ARK_GPU_COMPILE_HPP_ #include #include @@ -16,4 +16,4 @@ const std::string gpu_compile(const std::vector &codes, } // namespace ark -#endif // ARK_GPU_COMPILE_H_ +#endif // ARK_GPU_COMPILE_HPP_ diff --git a/ark/gpu/gpu_event.cpp b/ark/gpu/gpu_event.cpp index cbc45d9a6..06779b91a 100644 --- a/ark/gpu/gpu_event.cpp +++ b/ark/gpu/gpu_event.cpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_event.h" +#include "gpu/gpu_event.hpp" -#include "gpu/gpu_logging.h" -#include "gpu/gpu_manager.h" +#include "gpu/gpu_logging.hpp" +#include "gpu/gpu_manager.hpp" namespace ark { class GpuEvent::Impl { diff --git a/ark/gpu/gpu_event.h b/ark/gpu/gpu_event.hpp similarity index 84% rename from ark/gpu/gpu_event.h rename to ark/gpu/gpu_event.hpp index 081f0203b..bd2a7c952 100644 --- a/ark/gpu/gpu_event.h +++ b/ark/gpu/gpu_event.hpp @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_GPU_EVENT_H_ -#define ARK_GPU_EVENT_H_ +#ifndef ARK_GPU_EVENT_HPP_ +#define ARK_GPU_EVENT_HPP_ #include -#include "gpu/gpu.h" +#include "gpu/gpu.hpp" namespace ark { @@ -33,4 +33,4 @@ class GpuEvent { }; } // namespace ark -#endif // ARK_GPU_EVENT_H_ +#endif // ARK_GPU_EVENT_HPP_ diff --git a/ark/gpu/gpu_kernel.cpp b/ark/gpu/gpu_kernel.cpp index 46f467f51..d4412f80e 100644 --- a/ark/gpu/gpu_kernel.cpp +++ b/ark/gpu/gpu_kernel.cpp @@ -1,50 +1,38 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu_kernel.h" +#include "gpu_kernel.hpp" #include #include -#include "gpu.h" -#include "gpu_compile.h" -#include "gpu_logging.h" -#include "gpu_manager.h" +#include "gpu.hpp" +#include "gpu_compile.hpp" +#include "gpu_logging.hpp" +#include "gpu_manager.hpp" namespace ark { GpuKernel::GpuKernel(int gpu_id, const std::string& code, const std::array& block_dim, const std::array& grid_dim, size_t smem_bytes, - const std::string& kernel_name, - std::initializer_list> args) { - this->init(gpu_id, code, block_dim, grid_dim, smem_bytes, kernel_name, - args); + const std::string& kernel_name) { + this->init(gpu_id, code, block_dim, grid_dim, smem_bytes, kernel_name); } void GpuKernel::init(int gpu_id, const std::string& code, const std::array& block_dim, const std::array& grid_dim, size_t smem_bytes, - const std::string& kernel_name, - std::initializer_list> args) { + const std::string& kernel_name) { gpu_manager_ = GpuManager::get_instance(gpu_id); code_ = code; block_dim_ = block_dim; grid_dim_ = grid_dim; smem_bytes_ = smem_bytes; kernel_name_ = kernel_name; - params_ptr_.resize(args.size()); - args_.resize(args.size()); if (kernel_name_.size() == 0) { ERR(InvalidUsageError, "Invalid kernel name: ", kernel_name_); } - size_t idx = 0; - for (auto& pair : args) { - args_[idx].reset(new uint8_t[pair.second]); - std::memcpy(args_[idx].get(), &(pair.first), pair.second); - params_ptr_[idx] = static_cast(args_[idx].get()); - idx++; - } } void GpuKernel::compile() { @@ -68,12 +56,13 @@ void GpuKernel::compile() { dynamic_smem_size_bytes)); } -void GpuKernel::launch(gpuStream stream) { +void GpuKernel::launch(gpuStream stream, std::vector& args) { if (!this->is_compiled()) { ERR(InvalidUsageError, "Kernel is not compiled yet."); } gpu_manager_->launch(function_, grid_dim_, block_dim_, smem_bytes_, stream, - params_ptr_.data(), nullptr); + args.data(), nullptr); + GLOG(gpuGetLastError()); } gpuDeviceptr GpuKernel::get_global(const std::string& name, diff --git a/ark/gpu/gpu_kernel.h b/ark/gpu/gpu_kernel.hpp similarity index 68% rename from ark/gpu/gpu_kernel.h rename to ark/gpu/gpu_kernel.hpp index b3be79071..5308cfead 100644 --- a/ark/gpu/gpu_kernel.h +++ b/ark/gpu/gpu_kernel.hpp @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_GPU_KERNEL_H_ -#define ARK_GPU_KERNEL_H_ +#ifndef ARK_GPU_KERNEL_HPP_ +#define ARK_GPU_KERNEL_HPP_ #include #include +#include -#include "gpu_stream.h" +#include "gpu_stream.hpp" namespace ark { @@ -18,16 +19,14 @@ class GpuKernel { GpuKernel(int gpu_id, const std::string& codes, const std::array& block_dim, const std::array& grid_dim, size_t smem_bytes, - const std::string& kernel_name, - std::initializer_list> args = {}); + const std::string& kernel_name); void init(int gpu_id, const std::string& codes, const std::array& block_dim, const std::array& grid_dim, size_t smem_bytes, - const std::string& kernel_name, - std::initializer_list> args = {}); + const std::string& kernel_name); void compile(); - void launch(gpuStream stream); + void launch(gpuStream stream, std::vector& args); gpuDeviceptr get_global(const std::string& name, bool ignore_not_found = false) const; @@ -43,10 +42,8 @@ class GpuKernel { std::string bin_; gpuModule module_; gpuFunction function_ = nullptr; - std::vector params_ptr_; - std::vector> args_; }; } // namespace ark -#endif // ARK_GPU_KERNEL_H_ +#endif // ARK_GPU_KERNEL_HPP_ diff --git a/ark/gpu/gpu_kernel_test.cpp b/ark/gpu/gpu_kernel_test.cpp index 870ad7ab9..342ef9656 100644 --- a/ark/gpu/gpu_kernel_test.cpp +++ b/ark/gpu/gpu_kernel_test.cpp @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_kernel.h" +#include "gpu/gpu_kernel.hpp" #include "unittest/unittest_utils.h" @@ -9,7 +9,13 @@ const std::string void_kernel = "extern \"C\" __global__ void kernel() {}"; ark::unittest::State test_gpu_kernel() { ark::GpuKernel kernel(0, void_kernel, {1, 1, 1}, {1, 1, 1}, 0, "kernel"); + UNITTEST_TRUE(!kernel.is_compiled()); kernel.compile(); + UNITTEST_TRUE(kernel.is_compiled()); + std::vector args; + for (int i = 0; i < 10; i++) { + kernel.launch(nullptr, args); + } return ark::unittest::SUCCESS; } diff --git a/ark/gpu/gpu_logging.h b/ark/gpu/gpu_logging.hpp similarity index 92% rename from ark/gpu/gpu_logging.h rename to ark/gpu/gpu_logging.hpp index b14435b8b..5e35cc003 100644 --- a/ark/gpu/gpu_logging.h +++ b/ark/gpu/gpu_logging.hpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_GPU_LOGGING_H_ -#define ARK_GPU_LOGGING_H_ +#ifndef ARK_GPU_LOGGING_HPP_ +#define ARK_GPU_LOGGING_HPP_ -#include "gpu/gpu.h" +#include "gpu/gpu.hpp" #include "logging.h" #define GLOG(cmd) \ @@ -29,4 +29,4 @@ } \ } while (0) -#endif // ARK_GPU_LOGGING_H_ +#endif // ARK_GPU_LOGGING_HPP_ diff --git a/ark/gpu/gpu_manager.cpp b/ark/gpu/gpu_manager.cpp index fc841fa32..572932e35 100644 --- a/ark/gpu/gpu_manager.cpp +++ b/ark/gpu/gpu_manager.cpp @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_manager.h" +#include "gpu/gpu_manager.hpp" #include -#include "gpu/gpu_logging.h" +#include "gpu/gpu_logging.hpp" #include "utils/utils_string.hpp" namespace ark { diff --git a/ark/gpu/gpu_manager.h b/ark/gpu/gpu_manager.hpp similarity index 88% rename from ark/gpu/gpu_manager.h rename to ark/gpu/gpu_manager.hpp index 93a48cf7b..eeeda4d94 100644 --- a/ark/gpu/gpu_manager.h +++ b/ark/gpu/gpu_manager.hpp @@ -1,16 +1,16 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_GPU_MANAGER_H_ -#define ARK_GPU_MANAGER_H_ +#ifndef ARK_GPU_MANAGER_HPP_ +#define ARK_GPU_MANAGER_HPP_ #include #include "arch.hpp" -#include "gpu/gpu.h" -#include "gpu/gpu_event.h" -#include "gpu/gpu_memory.h" -#include "gpu/gpu_stream.h" +#include "gpu/gpu.hpp" +#include "gpu/gpu_event.hpp" +#include "gpu/gpu_memory.hpp" +#include "gpu/gpu_stream.hpp" namespace ark { @@ -62,4 +62,4 @@ class GpuManager { } // namespace ark -#endif // ARK_GPU_MANAGER_H_ +#endif // ARK_GPU_MANAGER_HPP_ diff --git a/ark/gpu/gpu_memory.cpp b/ark/gpu/gpu_memory.cpp index 184db457c..9a854f521 100644 --- a/ark/gpu/gpu_memory.cpp +++ b/ark/gpu/gpu_memory.cpp @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_memory.h" +#include "gpu/gpu_memory.hpp" -#include "gpu/gpu.h" -#include "gpu/gpu_logging.h" -#include "gpu/gpu_manager.h" +#include "gpu/gpu.hpp" +#include "gpu/gpu_logging.hpp" +#include "gpu/gpu_manager.hpp" namespace ark { diff --git a/ark/gpu/gpu_memory.h b/ark/gpu/gpu_memory.hpp similarity index 87% rename from ark/gpu/gpu_memory.h rename to ark/gpu/gpu_memory.hpp index cd7a6f04f..6b277d40b 100644 --- a/ark/gpu/gpu_memory.h +++ b/ark/gpu/gpu_memory.hpp @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_GPU_MEMORY_H_ -#define ARK_GPU_MEMORY_H_ +#ifndef ARK_GPU_MEMORY_HPP_ +#define ARK_GPU_MEMORY_HPP_ #include #include -#include "gpu/gpu.h" +#include "gpu/gpu.hpp" namespace ark { @@ -40,7 +40,7 @@ class GpuHostMemory { GpuHostMemory(const GpuHostMemory&) = delete; GpuHostMemory& operator=(const GpuHostMemory&) = delete; - template + template T* ref() const { return reinterpret_cast(ptr_); } @@ -54,4 +54,4 @@ class GpuHostMemory { } // namespace ark -#endif // ARK_GPU_MEMORY_H_ +#endif // ARK_GPU_MEMORY_HPP_ diff --git a/ark/gpu/gpu_stream.cpp b/ark/gpu/gpu_stream.cpp index 52502365a..17d4e21f5 100644 --- a/ark/gpu/gpu_stream.cpp +++ b/ark/gpu/gpu_stream.cpp @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "gpu/gpu_stream.h" +#include "gpu/gpu_stream.hpp" -#include "gpu/gpu_logging.h" -#include "gpu/gpu_manager.h" +#include "gpu/gpu_logging.hpp" +#include "gpu/gpu_manager.hpp" namespace ark { class GpuStream::Impl { diff --git a/ark/gpu/gpu_stream.h b/ark/gpu/gpu_stream.hpp similarity index 79% rename from ark/gpu/gpu_stream.h rename to ark/gpu/gpu_stream.hpp index e76f01827..9d8775f95 100644 --- a/ark/gpu/gpu_stream.h +++ b/ark/gpu/gpu_stream.hpp @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#ifndef ARK_GPU_STREAM_H_ -#define ARK_GPU_STREAM_H_ +#ifndef ARK_GPU_STREAM_HPP_ +#define ARK_GPU_STREAM_HPP_ #include -#include "gpu/gpu.h" +#include "gpu/gpu.hpp" namespace ark { @@ -30,4 +30,4 @@ class GpuStream { }; } // namespace ark -#endif // ARK_GPU_STREAM_H_ +#endif // ARK_GPU_STREAM_HPP_ diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 75dc81c17..f0a108a1f 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -20,7 +20,7 @@ class Executor { public: /// Constructor. Executor(int device_id, Stream stream, const std::string &name, - const std::string &plan); + const std::string &plan, bool loop_mode = true); /// Destructor. ~Executor(); @@ -96,7 +96,7 @@ class DefaultExecutor : public Executor { DefaultExecutor( const Model &model, int device_id = -1, Stream stream = nullptr, const std::vector &config_rules = {}, - const std::string &name = "DefaultExecutor"); + const std::string &name = "DefaultExecutor", bool loop_mode = true); }; } // namespace ark diff --git a/ark/include/kernels/kernel_template.in b/ark/include/kernels/kernel_template.in index ea1862920..a8a56f141 100644 --- a/ark/include/kernels/kernel_template.in +++ b/ark/include/kernels/kernel_template.in @@ -33,12 +33,12 @@ __device__ sync::State ARK_LOOP_SYNC_STATE; @DEFINITIONS@ -__device__ void ark_loop_body(char *_buf, int _iter) { +__device__ void ark_body(char *_buf, int _iter) { @BODY@ } extern "C" __global__ __launch_bounds__(ARK_WARPS_PER_BLOCK * Arch::ThreadsPerWarp, 1) -void @NAME@(char *_buf, int *_iter) { +void ark_loop_kernel@NAME@(char *_buf, int *_iter) { int *shared_mem = (int *)_ARK_SMEM; for (int i = threadIdx.x; i < ARK_SMEM_RESERVED_BYTES / sizeof(int); i += blockDim.x) { shared_mem[i] = 0; @@ -52,10 +52,10 @@ void @NAME@(char *_buf, int *_iter) { sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); if (ARK_ITER < 0) return; - ark_loop_body(_buf, 0); + ark_body(_buf, 0); for (int _i = 1; _i < ARK_ITER; ++_i) { sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); - ark_loop_body(_buf, _i); + ark_body(_buf, _i); } if (threadIdx.x == 0) { __threadfence_system(); @@ -67,3 +67,12 @@ void @NAME@(char *_buf, int *_iter) { sync_gpu<@NUM_BLOCKS@>(ARK_LOOP_SYNC_STATE); } } + +extern "C" __global__ __launch_bounds__(ARK_WARPS_PER_BLOCK * Arch::ThreadsPerWarp, 1) +void ark_kernel@NAME@(char *_buf, int _iter) { + int *shared_mem = (int *)_ARK_SMEM; + for (int i = threadIdx.x; i < ARK_SMEM_RESERVED_BYTES / sizeof(int); i += blockDim.x) { + shared_mem[i] = 0; + } + ark_body(_buf, _iter); +} diff --git a/ark/ops/ops_matmul_test.cpp b/ark/ops/ops_matmul_test.cpp index 4304a19e2..6d09b54d6 100644 --- a/ark/ops/ops_matmul_test.cpp +++ b/ark/ops/ops_matmul_test.cpp @@ -3,7 +3,7 @@ #include -#include "gpu/gpu.h" +#include "gpu/gpu.hpp" #include "logging.h" #include "model/model_node.hpp" #include "model/model_op.hpp" diff --git a/ark/ops/ops_test_common.cpp b/ark/ops/ops_test_common.cpp index 60ffc9dc2..bec69c456 100644 --- a/ark/ops/ops_test_common.cpp +++ b/ark/ops/ops_test_common.cpp @@ -10,7 +10,7 @@ #include "ark/planner.hpp" #include "ark/random.hpp" #include "env.h" -#include "gpu/gpu_logging.h" +#include "gpu/gpu_logging.hpp" #include "logging.h" #include "model/model_data_type.hpp" #include "model/model_tensor.hpp" diff --git a/python/ark/runtime.py b/python/ark/runtime.py index 33db1fb5c..d54f85c36 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -101,12 +101,11 @@ def running(self) -> bool: def launch( self, - rank: int = 0, - world_size: int = 1, gpu_id: int = 0, plan: str = "", plan_path: str = "", stream: int = 0, + loop_mode: bool = True, ): """ Create an executor and schedule the ARK model. The scheduler will generate @@ -135,6 +134,7 @@ def launch( stream, "ArkRuntime", plan, + loop_mode, ) self.executor = _RuntimeState.executor self.executor.compile() diff --git a/python/executor_py.cpp b/python/executor_py.cpp index 979cb2952..e782a99fe 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -43,9 +43,11 @@ static void tensor_read(ark::Executor *exe, const ark::Tensor &tensor, void register_executor(py::module &m) { py::class_(m, "_Executor") .def(py::init([](int device_id, uintptr_t stream, - const std::string &name, const std::string &plan) { - return new ark::Executor( - device_id, reinterpret_cast(stream), name, plan); + const std::string &name, const std::string &plan, + bool loop_mode) { + return new ark::Executor(device_id, + reinterpret_cast(stream), + name, plan, loop_mode); })) .def("device_id", &ark::Executor::device_id) .def("stream", From 55755bbe2e2fbc36195f7786280689bde3170ec2 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Sun, 14 Jul 2024 14:19:35 -0700 Subject: [PATCH 37/54] do not force noinline --- ark/codegen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ark/codegen.cpp b/ark/codegen.cpp index cd6206284..0d4b14a09 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -213,7 +213,7 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) { for (auto &op_json : task_json["Ops"]) { ss << this->def_op(op_json, task_json["Id"], op_idx++); } - ss << "__noinline__ __device__ void t" << task_json["Id"] + ss << "__device__ void t" << task_json["Id"] << "(char* _buf, int _idx, int _spw) {\n"; op_idx = 0; for (auto &op_json : task_json["Ops"]) { From b29eaaefb5b969a8e0ec8b8e3813e5e3245e7825 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Sun, 14 Jul 2024 21:25:20 +0000 Subject: [PATCH 38/54] wip --- arkprof.py | 4 +++- python/ark/profiler.py | 10 +++++----- python/ark/runtime.py | 11 +++++++++-- python/ark/tensor.py | 18 ++++++++++++------ python/unittest/unittest_common.py | 8 +++++++- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/arkprof.py b/arkprof.py index 782bba560..9e67c2dfc 100644 --- a/arkprof.py +++ b/arkprof.py @@ -1,4 +1,6 @@ import ark import sys -ark.Profiler(ark.Plan.from_file(sys.argv[1])).run(iter=1000, profile_processor_groups=False) +ark.Profiler(ark.Plan.from_file(sys.argv[1])).run( + iter=1000, profile_processor_groups=False +) diff --git a/python/ark/profiler.py b/python/ark/profiler.py index 56233247c..c161b24e6 100644 --- a/python/ark/profiler.py +++ b/python/ark/profiler.py @@ -8,9 +8,9 @@ from .planner import Plan -def timeit(plan: Plan, iter: int): +def timeit(plan: Plan, iter: int, loop_mode: bool): with Runtime() as rt: - rt.launch(plan=plan) + rt.launch(plan=plan, loop_mode=loop_mode) start_time = time.time() rt.run(iter=iter) end_time = time.time() @@ -21,8 +21,8 @@ class Profiler: def __init__(self, plan: Plan): self.plan = plan - def run(self, iter: int = 1000, profile_processor_groups: bool = False): - sys.stderr.write(f"End-to-end: {timeit(self.plan, iter):.6f} seconds/iter\n") + def run(self, iter: int = 1000, loop_mode: bool = True, profile_processor_groups: bool = False): + sys.stderr.write(f"End-to-end: {timeit(self.plan, iter, loop_mode):.6f} seconds/iter\n") if not profile_processor_groups: return @@ -38,7 +38,7 @@ def run(self, iter: int = 1000, profile_processor_groups: bool = False): } for i in range(num_processor_groups): new_plan["ProcessorGroups"][0] = self.plan.processor_groups[i] - lat_per_iter = timeit(Plan(new_plan), iter) + lat_per_iter = timeit(Plan(new_plan), iter, loop_mode) sys.stderr.write( f"Processor group {i}: {lat_per_iter:.6f} seconds/iter\n" ) diff --git a/python/ark/runtime.py b/python/ark/runtime.py index b3dbe7887..51a5b7905 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -48,8 +48,15 @@ def print_runtime_states(): class Executor(_Executor): - def __init__(self, device_id: int, stream: int, name: str, plan: Plan): - super().__init__(device_id, stream, name, str(plan)) + def __init__( + self, + device_id: int, + stream: int, + name: str, + plan: Plan, + loop_mode: bool = True, + ): + super().__init__(device_id, stream, name, str(plan), loop_mode) class Runtime: diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 335020769..657da1065 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -121,16 +121,22 @@ def to_torch( self.shape(), dtype=torch_type, device=torch.device(dev_name) ) elif list(tensor.shape) != self.shape(): - raise ValueError(f"torch tensor shape {list(tensor.shape)} " - f"does not match the tensor {self.shape()}") + raise ValueError( + f"torch tensor shape {list(tensor.shape)} " + f"does not match the tensor {self.shape()}" + ) elif tensor.dtype != torch_type: - raise ValueError(f"torch tensor dtype {tensor.dtype} " - f"does not match the tensor {torch_type}") + raise ValueError( + f"torch tensor dtype {tensor.dtype} " + f"does not match the tensor {torch_type}" + ) elif not tensor.is_contiguous(): raise ValueError("torch tensor is not contiguous in memory") elif tensor.numel() != self.nelems(): - raise ValueError(f"torch tensor size {tensor.numel()} " - f"does not match the tensor {self.nelems()}") + raise ValueError( + f"torch tensor size {tensor.numel()} " + f"does not match the tensor {self.nelems()}" + ) tensor_bytes = self.nelems() * self.dtype().element_size() rt.executor.tensor_read( self._tensor, tensor.data_ptr(), tensor_bytes, stream, True diff --git a/python/unittest/unittest_common.py b/python/unittest/unittest_common.py index 9548410b5..0c385e89a 100644 --- a/python/unittest/unittest_common.py +++ b/python/unittest/unittest_common.py @@ -9,14 +9,20 @@ def pytest_ark(need_torch: bool = False): """ Decorator for ARK unit tests. """ + def decorator(test_func): if need_torch: try: import torch except ImportError: - return pytest.mark.skip(reason="torch is not installed")(test_func) + return pytest.mark.skip(reason="torch is not installed")( + test_func + ) + def wrapper(*args, **kwargs): ark.init() test_func(*args, **kwargs) + return wrapper + return decorator From a7a5d46c001b143781022e2d28aaa3eee0c502b3 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Sun, 14 Jul 2024 23:56:21 +0000 Subject: [PATCH 39/54] Fix CK tile indexing --- third_party/patches/composable_kernel.patch | 89 +++++++++++++++++++-- 1 file changed, 83 insertions(+), 6 deletions(-) diff --git a/third_party/patches/composable_kernel.patch b/third_party/patches/composable_kernel.patch index 43b1afcaa..e12f19332 100644 --- a/third_party/patches/composable_kernel.patch +++ b/third_party/patches/composable_kernel.patch @@ -561,7 +561,7 @@ index 2d5dc90bf..160eef036 100644 }); diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp -index 7bb47e9d3..2b2e8c604 100644 +index 7bb47e9d3..d495c7297 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01 @@ -582,7 +582,84 @@ index 7bb47e9d3..2b2e8c604 100644 { return true; } -@@ -315,7 +315,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt +@@ -177,58 +177,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + +- const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; +- +- index_t idx_M00 = idx_M0 / M01_; +- index_t idx_M01 = idx_M0 % M01_; +- index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; +- +- /** +- * idxN0 +- * +- * |< mtx N >| +- * +- * NPerBlock NPerBlock NPerBlock NPerBlock +- * N_0 N_1 N_2 N_3 +- * - |-----------|-----------|-----------|-----|-----|- +- * ^ | - - 0 |/----> 2 | | | | +- * | | | / | | | | | M_0 MPerBlock +- * | M | /| | | | | | +- * |-0---|---/-|-----|-----|-----------|-----|-----|- +- * | 1 | / | | | blockid | | | +- * idxM0 | | | / | V | 5 | | | M_1 MPerBlock +- * | - V 1 | - 3 | | | | +- * |-----------|-----------|-----------|-----|-----|- +- * mtx M | | | | | | +- * | | | | | | M_2 MPerBlock +- * | | | | | | +- * |-----------|-----------|-----------|-----|-----|- +- * | | | | | | +- * | | | | | | M_3 MPerBlock +- * | | | | | | +- * |-----------|-----------|-----------|-----|-----|- +- * V | | | | | | +- * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock +- * | | | | | | +- * |-----------|-----------|-----------|-----|-----|- +- * Example: +- * assume: +- * M0 = 5 +- * N0 = 4 +- * block_1d_id = 5 +- * M01 = 2 +- * +- * idx_N0 = 1 +- * idx_M0 = 1 +- * M01_adapt = 2 +- * idx_M00 = 0 +- * idx_M01 = 1 +- * idx_N0_M01_local = 5 +- * output {1, 2} +- */ +- +- return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, +- idx_N0_M01_local / M01_adapt); ++ return make_tuple(idx_M0, idx_N0); + } + + template +@@ -297,15 +246,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + +- const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; +- +- index_t idx_M00 = idx_M0 / M01_; +- index_t idx_M01 = idx_M0 % M01_; +- index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; +- +- return make_tuple(idx_ksplit, +- idx_N0_M01_local % M01_adapt + idx_M00 * M01_, +- idx_N0_M01_local / M01_adapt); ++ return make_tuple(idx_ksplit, idx_M0, idx_N0); + } + + template +@@ -315,7 +256,7 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt return true; // always valid provided that user gets grid size from CalculateGridSize() } @@ -591,7 +668,7 @@ index 7bb47e9d3..2b2e8c604 100644 private: index_t M01_; -@@ -373,7 +373,7 @@ struct BlockToCTileMap_M00_N00_M01_N01 +@@ -373,7 +314,7 @@ struct BlockToCTileMap_M00_N00_M01_N01 return true; } @@ -600,7 +677,7 @@ index 7bb47e9d3..2b2e8c604 100644 { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel -@@ -485,7 +485,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 +@@ -485,7 +426,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 return true; } @@ -609,7 +686,7 @@ index 7bb47e9d3..2b2e8c604 100644 { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel -@@ -609,7 +609,7 @@ struct OffsettedBlockToCTileMap +@@ -609,7 +550,7 @@ struct OffsettedBlockToCTileMap } template @@ -618,7 +695,7 @@ index 7bb47e9d3..2b2e8c604 100644 { return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); } -@@ -666,7 +666,7 @@ struct BlockToCTileMap_3DGrid_KSplit +@@ -666,7 +607,7 @@ struct BlockToCTileMap_3DGrid_KSplit } template From 9c19a5ec8543863d159c96f05c007b63943c2566 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 29 Jul 2024 02:56:23 +0000 Subject: [PATCH 40/54] wip --- .vscode/settings.json | 2 - ark/api/context_manager.cpp | 42 +++++++++ ark/api/context_manager_test.cpp | 54 +++++++++++ ark/api/model.cpp | 4 +- ark/api/model_graph.cpp | 4 +- ark/api/model_test.cpp | 24 ++--- ark/api/planner.cpp | 4 +- ark/include/ark.hpp | 1 + ark/include/ark/context_manager.hpp | 24 +++++ ark/include/ark/model.hpp | 64 +++++++------ ark/include/ark/model_graph.hpp | 3 +- ark/model/model_graph_impl.cpp | 40 ++++++++- ark/model/model_graph_impl.hpp | 36 +++++++- ark/model/model_node.hpp | 3 + ark/model/model_op.cpp | 11 +++ ark/model/model_op.hpp | 9 +- ark/ops/ops_arithmetic.cpp | 20 +++-- ark/ops/ops_cast.cpp | 10 +-- ark/ops/ops_communication.cpp | 14 +-- ark/ops/ops_copy.cpp | 5 +- ark/ops/ops_embedding.cpp | 4 +- ark/ops/ops_identity.cpp | 2 +- ark/ops/ops_math.cpp | 31 ++++--- ark/ops/ops_matmul.cpp | 6 +- ark/ops/ops_noop.cpp | 2 +- ark/ops/ops_reduce.cpp | 12 +-- ark/ops/ops_refer.cpp | 2 +- ark/ops/ops_reshape.cpp | 4 +- ark/ops/ops_rope.cpp | 5 +- ark/ops/ops_scalar.cpp | 31 ++++--- ark/ops/ops_tensor.cpp | 2 +- ark/ops/ops_transpose.cpp | 5 +- arkprof.py | 1 + examples/tutorial/context_tutorial.py | 117 ++++++++++++++++++++++++ python/ark/__init__.py | 2 +- python/ark/context_manager.py | 24 +++++ python/ark/ops.py | 125 ++++++++++++++++++++------ python/ark/profiler.py | 11 ++- python/ark_py.cpp | 2 + python/context_manager_py.cpp | 15 ++++ python/model_py.cpp | 86 ++++++++++-------- 41 files changed, 676 insertions(+), 187 deletions(-) create mode 100644 ark/api/context_manager.cpp create mode 100644 ark/api/context_manager_test.cpp create mode 100644 ark/include/ark/context_manager.hpp create mode 100644 examples/tutorial/context_tutorial.py create mode 100644 python/ark/context_manager.py create mode 100644 python/context_manager_py.cpp diff --git a/.vscode/settings.json b/.vscode/settings.json index 640196a66..00260f078 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,8 +3,6 @@ "cmake.environment": { "ARK_ROOT": "${workspaceFolder}/build", "ARK_IGNORE_BINARY_CACHE": "1", - "ARK_DISABLE_GRAPH_OPT": "0", - "ARK_IPC_LISTEN_PORT_BASE": "42000", // "ARK_LOG_LEVEL": "DEBUG" }, "cmake.ctestArgs": [ diff --git a/ark/api/context_manager.cpp b/ark/api/context_manager.cpp new file mode 100644 index 000000000..6d16d9e79 --- /dev/null +++ b/ark/api/context_manager.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ark/context_manager.hpp" + +#include "model/model_graph_impl.hpp" + +namespace ark { + +class ContextManager::Impl { + public: + Impl(std::shared_ptr context_stack, + const std::map& context_map); + + ~Impl(); + + private: + std::shared_ptr context_stack_; + std::vector keys_; +}; + +ContextManager::Impl::Impl( + std::shared_ptr context_stack, + const std::map& context_map) + : context_stack_(context_stack) { + for (const auto& [key, value] : context_map) { + context_stack_->push(key, value); + keys_.push_back(key); + } +} + +ContextManager::Impl::~Impl() { + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + context_stack_->pop(*it); + } +} + +ContextManager::ContextManager( + Model& model, const std::map& context_map) + : impl_(std::make_shared(model.impl_->context_stack_, context_map)) {} + +} // namespace ark diff --git a/ark/api/context_manager_test.cpp b/ark/api/context_manager_test.cpp new file mode 100644 index 000000000..ff60b43bf --- /dev/null +++ b/ark/api/context_manager_test.cpp @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ark/model.hpp" +#include "ark/context_manager.hpp" + +#include "model/model_node.hpp" +#include "unittest/unittest_utils.h" + +ark::unittest::State test_context_manager() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + { + ark::ContextManager cm0_1(model, {{"key0", "val1"}}); + t3 = model.relu(t2); + + ark::ContextManager cm1_1(model, {{"key1", "val2"}}); + t4 = model.sqrt(t3); + } + { + ark::ContextManager cm0_2(model, {{"key0", "val3"}}); + t5 = model.exp(t2); + } + + UNITTEST_TRUE(model.verify()); + + auto compressed = model.compress(false); + UNITTEST_TRUE(compressed.verify()); + + auto nodes = compressed.nodes(); + UNITTEST_EQ(nodes.size(), 4); + + UNITTEST_EQ(nodes[0]->context.size(), 0); + UNITTEST_EQ(nodes[1]->context.size(), 1); + UNITTEST_EQ(nodes[1]->context.at("key0"), "val1"); + UNITTEST_EQ(nodes[2]->context.size(), 2); + UNITTEST_EQ(nodes[2]->context.at("key0"), "val1"); + UNITTEST_EQ(nodes[2]->context.at("key1"), "val2"); + UNITTEST_EQ(nodes[3]->context.size(), 1); + UNITTEST_EQ(nodes[3]->context.at("key0"), "val3"); + + return ark::unittest::SUCCESS; +} + +int main() { + UNITTEST(test_context_manager); + return 0; +} diff --git a/ark/api/model.cpp b/ark/api/model.cpp index ab536a33c..a5a258f71 100644 --- a/ark/api/model.cpp +++ b/ark/api/model.cpp @@ -9,9 +9,9 @@ namespace ark { -Model Model::compress() const { +Model Model::compress(bool merge_nodes) const { Model model(*this); - model.compress_nodes(); + model.compress_nodes(merge_nodes); return model; } diff --git a/ark/api/model_graph.cpp b/ark/api/model_graph.cpp index b6061a34e..d11808467 100644 --- a/ark/api/model_graph.cpp +++ b/ark/api/model_graph.cpp @@ -33,7 +33,9 @@ int ModelGraph::rank() const { return impl_->rank(); } int ModelGraph::world_size() const { return impl_->world_size(); } -void ModelGraph::compress_nodes() { impl_->compress_nodes(); } +void ModelGraph::compress_nodes(bool merge_nodes) { + impl_->compress_nodes(merge_nodes); +} bool ModelGraph::compressed() const { return impl_->compressed(); } diff --git a/ark/api/model_test.cpp b/ark/api/model_test.cpp index a9d332a97..785bfcd7b 100644 --- a/ark/api/model_test.cpp +++ b/ark/api/model_test.cpp @@ -36,7 +36,7 @@ ark::unittest::State test_model_basics() { // (AddOp,) // - compressed = model.compress(); + compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); UNITTEST_TRUE(compressed.compressed()); UNITTEST_EQ(compressed.nodes().size(), 1); @@ -70,7 +70,7 @@ ark::unittest::State test_model_basics() { // (AddOp,AddOp,) // - compressed = model.compress(); + compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); UNITTEST_EQ(compressed.nodes().size(), 1); @@ -104,7 +104,7 @@ ark::unittest::State test_model_basics() { // (AddOp,AddOp,ReluOp,) // - compressed = model.compress(); + compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); UNITTEST_EQ(compressed.nodes().size(), 1); @@ -143,7 +143,7 @@ ark::unittest::State test_model_basics() { // (AddOp,AddOp,ReluOp,AddOp,) // - compressed = model.compress(); + compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); auto nodes = compressed.nodes(); @@ -190,7 +190,7 @@ ark::unittest::State test_model_basics() { // (AddOp,) --+--> (AddOp,) // - compressed = model.compress(); + compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); nodes = compressed.nodes(); @@ -250,7 +250,7 @@ ark::unittest::State test_model_basics() { // (AddOp,) // - compressed = model.compress(); + compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); nodes = compressed.nodes(); @@ -312,7 +312,7 @@ ark::unittest::State test_model_basics() { // (AddOp,) // - compressed = model.compress(); + compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); nodes = compressed.nodes(); @@ -353,7 +353,7 @@ ark::unittest::State test_model_dependent_inputs() { ark::Tensor x4 = m.mul(x2, x3); ark::Tensor y = m.add(x0, x4); - auto compressed = m.compress(); + auto compressed = m.compress(true); auto nodes = compressed.nodes(); UNITTEST_EQ(nodes.size(), 4); auto nodes_iter = nodes.begin(); @@ -399,7 +399,7 @@ ark::unittest::State test_model_noop() { UNITTEST_TRUE(model.verify()); - auto compressed = model.compress(); + auto compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); UNITTEST_EQ(compressed.nodes().size(), 0); return ark::unittest::SUCCESS; @@ -425,7 +425,7 @@ ark::unittest::State test_model_identity() { ark::Tensor t4 = model.relu(t3); UNITTEST_TRUE(model.verify()); - auto compressed = model.compress(); + auto compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); auto nodes = compressed.nodes(); UNITTEST_EQ(nodes.size(), 3); @@ -478,7 +478,7 @@ ark::unittest::State test_model_sharding() { ark::Tensor t5 = model.relu(t4); UNITTEST_TRUE(model.verify()); - auto compressed = model.compress(); + auto compressed = model.compress(true); UNITTEST_TRUE(compressed.verify()); auto nodes = compressed.nodes(); UNITTEST_EQ(nodes.size(), 4); @@ -526,7 +526,7 @@ ark::unittest::State test_model_cumulate() { UNITTEST_TRUE(model.verify()); - auto compressed = model.compress(); + auto compressed = model.compress(true); auto nodes = compressed.nodes(); UNITTEST_EQ(nodes.size(), 5); diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index f4e7fa8ee..dba149a1e 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -69,7 +69,9 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { task_info["Id"] = next_node_id++; Json config; - if (!config_rules_.empty()) { + if (!op->config().empty()) { + config = op->config(); + } else if (!config_rules_.empty()) { const std::string op_str = op->serialize().dump(); for (auto &rule : config_rules_) { auto config_str = rule(op_str, gpu_info.arch->name()); diff --git a/ark/include/ark.hpp b/ark/include/ark.hpp index 2ca796172..e76687bce 100644 --- a/ark/include/ark.hpp +++ b/ark/include/ark.hpp @@ -8,6 +8,7 @@ #include // clang-format on +#include #include #include #include diff --git a/ark/include/ark/context_manager.hpp b/ark/include/ark/context_manager.hpp new file mode 100644 index 000000000..58271ea8c --- /dev/null +++ b/ark/include/ark/context_manager.hpp @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_CONTEXT_MANAGER_HPP +#define ARK_CONTEXT_MANAGER_HPP + +#include +#include + +namespace ark { + +class ContextManager { + public: + ContextManager(Model& model, + const std::map& context_map); + + private: + class Impl; + std::shared_ptr impl_; +}; + +} // namespace ark + +#endif // ARK_CONTEXT_MANAGER_HPP diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index 66551a037..35efe53d5 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -26,7 +26,7 @@ class Model : public ModelGraph { Model &operator=(const Model &other) = default; - Model compress() const; + Model compress(bool merge_nodes = false) const; int unique_tag(); @@ -87,23 +87,29 @@ class Model : public ModelGraph { // result in `output`. // Currently, only reduction along the last dimension is supported. Tensor reduce_sum(Tensor input, int axis, bool keepdims = true, - Tensor output = NullTensor, const std::string &name = ""); + Tensor output = NullTensor, + const std::string &config = "", + const std::string &name = ""); Tensor reduce_mean(Tensor input, int axis, bool keepdims = true, Tensor output = NullTensor, + const std::string &config = "", const std::string &name = ""); Tensor reduce_max(Tensor input, int axis, bool keepdims = true, - Tensor output = NullTensor, const std::string &name = ""); + Tensor output = NullTensor, + const std::string &config = "", + const std::string &name = ""); // Transposes the `input` tensor according to the given `permutation`. // For example, transpose(input, {0, 1 ,3, 2}) will swap the last two // dimensions of the input tensor. Currently, only 4D tensors are supported. Tensor transpose(Tensor input, const std::vector &permutation, - Tensor output = NullTensor, const std::string &name = ""); + Tensor output = NullTensor, const std::string &config = "", + const std::string &name = ""); // Performs matrix multiplication between the `input` tensor and another // `other` tensor, storing the result in `output`. Tensor matmul(Tensor input, Tensor other, Tensor output = NullTensor, bool trans_input = false, bool trans_other = false, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Implements the 'im2col' method for 2D convolution layers, which takes an // `input` tensor and reshapes it to a 2D matrix by extracting image patches // from the input tensor based on the provided parameters. @@ -120,72 +126,76 @@ class Model : public ModelGraph { Tensor output = NullTensor, const std::string &name = ""); // Calculates the exponential of the `input` tensor, element-wise. Tensor exp(Tensor input, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Calculates the square root of the `input` tensor, element-wise. Tensor sqrt(Tensor input, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Calculates the reverse square root of the `input` tensor, element-wise. Tensor rsqrt(Tensor input, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // ReLU activation Tensor relu(Tensor input, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Copy the `input` tensor to `output` tensor Tensor copy(Tensor input, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); Tensor copy(float val, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Applies the Gaussian Error Linear Unit (GELU) activation function to the // `input` tensor, element-wise. GELU is a smooth approximation of the // rectifier function and is widely used in deep learning models. Tensor gelu(Tensor input, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Sigmoid activation Tensor sigmoid(Tensor input, Tensor output = NullTensor, + const std::string &config = "", const std::string &name = ""); // Performs rotary position embedding (RoPE) on the `input` tensor Tensor rope(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Performs an element-wise addition operator between the `input` tensor // and the `other` tensor Tensor add(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); Tensor add(Tensor input, float value, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Performs an element-wise subtraction operator between the `input` tensor // and the `other` tensor Tensor sub(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); Tensor sub(Tensor input, float value, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Performs an element-wise multiplication operator between the `input` // tensor and the `other` tensor, Tensor mul(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); Tensor mul(Tensor input, float value, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Performs an element-wise division operator between the `input` // tensor and the `other` tensor, Tensor div(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); Tensor div(Tensor input, float value, Tensor output = NullTensor, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); Tensor send(Tensor input, int remote_rank, int tag, - Tensor output = NullTensor, const std::string &name = ""); + Tensor output = NullTensor, const std::string &config = "", + const std::string &name = ""); // Blocks the execution until the corresponding 'send' operator with the // specified `id` is completed. - Tensor send_done(Tensor input, const std::string &name = ""); + Tensor send_done(Tensor input, const std::string &config = "", + const std::string &name = ""); // Receives a tensor from a source rank (@p src_rank), identified by the // `id` parameter. Blocks the execution until the corresponding 'recv' // operator is completed. Tensor recv(Tensor output, int remote_rank, int tag, - const std::string &name = ""); + const std::string &config = "", const std::string &name = ""); // Tensor put_packet(Tensor input, Tensor local_tmp_buf, Tensor recv_buf, int id, int rank, int dst_rank, size_t dst_offset, - int flag, const std::string &name = ""); + int flag, const std::string &config = "", + const std::string &name = ""); // Performs an all-reduce operator across all ranks, aggregating the input // tensors. Takes the `input` tensor, the current GPU's rank, and the // total number of ranks `rank_num`. @@ -200,10 +210,12 @@ class Model : public ModelGraph { const std::string &name = ""); /// Embedding layer. Tensor embedding(Tensor input, Tensor weight, Tensor output = NullTensor, + const std::string &config = "", const std::string &name = ""); /// Tensor type casting. Tensor cast(Tensor input, const DataType &data_type, - Tensor output = NullTensor, const std::string &name = ""); + Tensor output = NullTensor, const std::string &config = "", + const std::string &name = ""); // sync across multi devices Tensor device_sync(Tensor input, int npeers, const std::string &name = ""); diff --git a/ark/include/ark/model_graph.hpp b/ark/include/ark/model_graph.hpp index bd7c59033..f6390a2a9 100644 --- a/ark/include/ark/model_graph.hpp +++ b/ark/include/ark/model_graph.hpp @@ -25,7 +25,7 @@ class ModelGraph { int world_size() const; - void compress_nodes(); + void compress_nodes(bool merge_nodes = false); bool compressed() const; @@ -38,6 +38,7 @@ class ModelGraph { protected: friend class Model; + friend class ContextManager; class Impl; std::unique_ptr impl_; diff --git a/ark/model/model_graph_impl.cpp b/ark/model/model_graph_impl.cpp index 17410d23f..53a7fa851 100644 --- a/ark/model/model_graph_impl.cpp +++ b/ark/model/model_graph_impl.cpp @@ -17,6 +17,36 @@ namespace ark { +ModelGraphContextStack::ModelGraphContextStack(const ModelGraphContextStack &other) { + for (const auto &pair : other.storage_) { + for (const auto &value : pair.second) { + this->storage_[pair.first].push_back(value); + } + } +} + +void ModelGraphContextStack::push(const std::string &key, const std::string &value) { + this->storage_[key].push_back(std::make_shared(value)); +} + +void ModelGraphContextStack::pop(const std::string &key) { + auto it = this->storage_.find(key); + if (it == this->storage_.end() || it->second.empty()) { + ERR(ModelError, "context stack is empty"); + } + it->second.pop_back(); +} + +std::map ModelGraphContextStack::current_context() const { + std::map cur; + for (const auto &pair : this->storage_) { + if (!pair.second.empty()) { + cur[pair.first] = *pair.second.back(); + } + } + return cur; +} + ModelGraph::Impl::Impl(const ModelGraph::Impl &other) { *this = other; } ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { @@ -25,6 +55,7 @@ ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { for (const auto &node : other.nodes_) { ModelNodeRef new_node = std::make_shared(); new_node->ops = node->ops; + new_node->context = node->context; node_map.emplace(node, new_node); nodes_.push_back(new_node); } @@ -61,13 +92,16 @@ ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { rank_ = other.rank_; world_size_ = other.world_size_; compressed_ = other.compressed_; + context_stack_ = std::make_shared(*(other.context_stack_)); return *this; } -void ModelGraph::Impl::compress_nodes() { +void ModelGraph::Impl::compress_nodes(bool merge_nodes) { if (!compressed_) { this->recursive_remove_virtual_nodes(); - this->recursive_merge_nodes(); + if (merge_nodes) { + this->recursive_merge_nodes(); + } compressed_ = true; } } @@ -171,6 +205,8 @@ ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) { producer->consumers.push_back(node); } + node->context = context_stack_->current_context(); + nodes_.push_back(node); return node; } diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp index 6c109b51e..fbfc54c7e 100644 --- a/ark/model/model_graph_impl.hpp +++ b/ark/model/model_graph_impl.hpp @@ -4,6 +4,7 @@ #ifndef ARK_MODEL_GRAPH_IMPL_HPP_ #define ARK_MODEL_GRAPH_IMPL_HPP_ +#include #include #include #include @@ -18,17 +19,39 @@ namespace ark { +class ModelGraphContextStack { + private: + std::map>> storage_; + + public: + ModelGraphContextStack() = default; + + ModelGraphContextStack(const ModelGraphContextStack &other); + + ~ModelGraphContextStack() = default; + + void push(const std::string &key, const std::string &value); + + void pop(const std::string &key); + + std::map current_context() const; +}; + class ModelGraph::Impl { public: Impl(int rank, int world_size) - : rank_(rank), world_size_(world_size), compressed_(false){}; + : rank_(rank), + world_size_(world_size), + compressed_(false), + context_stack_(std::make_shared()) {}; Impl(const Impl &other); Impl &operator=(const Impl &other); template - ModelOpRef create_op(const std::string &name, Args &&... args) { + ModelOpRef create_op(const std::string &config, const std::string &name, + Args &&...args) { ModelOpRef op = std::make_shared(std::forward(args)...); std::string name_copy; if (name.empty()) { @@ -41,6 +64,7 @@ class ModelGraph::Impl { if (count > 0) { name_copy += "_" + std::to_string(count); } + op->set_config(config); op->set_name(name_copy); add_op(op); return op; @@ -50,7 +74,7 @@ class ModelGraph::Impl { int world_size() const { return world_size_; } - void compress_nodes(); + void compress_nodes(bool merge_nodes = false); bool compressed() const { return compressed_; } @@ -100,6 +124,12 @@ class ModelGraph::Impl { /// True if `compress_nodes` has been called. bool compressed_; + + protected: + friend class ContextManager; + + /// Graph context stack. + std::shared_ptr context_stack_; }; } // namespace ark diff --git a/ark/model/model_node.hpp b/ark/model/model_node.hpp index 7838ca120..c86b4d29a 100644 --- a/ark/model/model_node.hpp +++ b/ark/model/model_node.hpp @@ -26,6 +26,9 @@ class ModelNode { /// The list of @ref ModelNode that this @ref ModelNode depends on. UniqueList producers; + + /// Graph context of this node. + std::map context; }; } // namespace ark diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index b5a0645c8..e9689cdcb 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -87,6 +87,14 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { return it->second; } +void ModelOp::set_config(const std::string &config) { + if (!config.empty()) { + config_ = Json::parse(config); + } else { + config_.clear(); + } +} + std::vector ModelOp::input_tensors() const { // input_tensors = read_tensors || write_tensors std::set input_tensors; @@ -179,6 +187,9 @@ Json ModelOp::serialize() const { for (auto &arg : args_) { j["Args"][arg.first] = arg.second.serialize(); } + if (!config_.empty()) { + j["Config"] = config_; + } return j; } diff --git a/ark/model/model_op.hpp b/ark/model/model_op.hpp index e8c220258..091a9f163 100644 --- a/ark/model/model_op.hpp +++ b/ark/model/model_op.hpp @@ -50,8 +50,8 @@ class ModelOp { return ""; } - virtual std::vector impl_args([ - [maybe_unused]] const Json &config) const { + virtual std::vector impl_args( + [[maybe_unused]] const Json &config) const { return {}; } @@ -60,10 +60,14 @@ class ModelOp { return {{"NumTasks", 0}, {"NumWarps", 0}, {"SramBytes", 0}}; } + void set_config(const std::string &config); + void set_name(const std::string &name) { name_ = name; } ModelOpType type() const { return type_; } + const Json &config() const { return config_; } + const std::string &name() const { return name_; } bool is_virtual() const { return is_virtual_; } @@ -100,6 +104,7 @@ class ModelOp { const std::vector &template_args = {}); ModelOpType type_; + Json config_; std::string name_; bool is_virtual_; std::vector read_tensors_; diff --git a/ark/ops/ops_arithmetic.cpp b/ark/ops/ops_arithmetic.cpp index aeece0d77..ef85b5d22 100644 --- a/ark/ops/ops_arithmetic.cpp +++ b/ark/ops/ops_arithmetic.cpp @@ -12,9 +12,10 @@ ModelOpAdd::ModelOpAdd(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Add", input, other, output) {} Tensor Model::add(Tensor input, Tensor other, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, other.ref_, output.ref_) + ->create_op(config, name, input.ref_, other.ref_, + output.ref_) ->result_tensors()[0]; } @@ -23,9 +24,10 @@ ModelOpMul::ModelOpMul(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Mul", input, other, output) {} Tensor Model::mul(Tensor input, Tensor other, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, other.ref_, output.ref_) + ->create_op(config, name, input.ref_, other.ref_, + output.ref_) ->result_tensors()[0]; } @@ -34,9 +36,10 @@ ModelOpSub::ModelOpSub(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Sub", input, other, output) {} Tensor Model::sub(Tensor input, Tensor other, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, other.ref_, output.ref_) + ->create_op(config, name, input.ref_, other.ref_, + output.ref_) ->result_tensors()[0]; } @@ -45,9 +48,10 @@ ModelOpDiv::ModelOpDiv(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Div", input, other, output) {} Tensor Model::div(Tensor input, Tensor other, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, other.ref_, output.ref_) + ->create_op(config, name, input.ref_, other.ref_, + output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_cast.cpp b/ark/ops/ops_cast.cpp index 9873c8367..e9527ad8c 100644 --- a/ark/ops/ops_cast.cpp +++ b/ark/ops/ops_cast.cpp @@ -105,7 +105,7 @@ ModelOpByteCast::ModelOpByteCast(ModelTensorRef input, ModelDataType data_type, } Tensor Model::cast(Tensor input, const DataType &data_type, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { check_null(input.ref()); if (output.is_null()) { if (input.data_type() == data_type) { @@ -119,14 +119,14 @@ Tensor Model::cast(Tensor input, const DataType &data_type, Tensor output, byte_cast_helper(input.ref(), data_type.ref(), new_shape, new_strides, new_offsets, new_padded_shape); return impl_ - ->create_op(name, input.ref(), data_type.ref(), - new_shape, new_strides, - new_offsets, new_padded_shape) + ->create_op( + config, name, input.ref(), data_type.ref(), new_shape, + new_strides, new_offsets, new_padded_shape) ->result_tensors()[0]; } } return impl_ - ->create_op(name, input.ref(), data_type.ref(), + ->create_op(config, name, input.ref(), data_type.ref(), output.ref()) ->result_tensors()[0]; } diff --git a/ark/ops/ops_communication.cpp b/ark/ops/ops_communication.cpp index e335f869e..4e76d2ede 100644 --- a/ark/ops/ops_communication.cpp +++ b/ark/ops/ops_communication.cpp @@ -157,23 +157,25 @@ Json ModelOpRecv::default_config([[maybe_unused]] const ArchRef arch) const { } Tensor Model::send(Tensor input, int remote_rank, int tag, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { tags_.insert(tag); return impl_ - ->create_op(name, input.ref(), remote_rank, tag, + ->create_op(config, name, input.ref(), remote_rank, tag, output.ref()) ->result_tensors()[0]; } -Tensor Model::send_done(Tensor input, const std::string &name) { - return impl_->create_op(name, input.ref()) +Tensor Model::send_done(Tensor input, const std::string &config, + const std::string &name) { + return impl_->create_op(config, name, input.ref()) ->result_tensors()[0]; } Tensor Model::recv(Tensor output, int remote_rank, int tag, - const std::string &name) { + const std::string &config, const std::string &name) { tags_.insert(tag); - return impl_->create_op(name, output.ref(), remote_rank, tag) + return impl_ + ->create_op(config, name, output.ref(), remote_rank, tag) ->result_tensors()[0]; } diff --git a/ark/ops/ops_copy.cpp b/ark/ops/ops_copy.cpp index 4f32966b8..4914c34a4 100644 --- a/ark/ops/ops_copy.cpp +++ b/ark/ops/ops_copy.cpp @@ -20,8 +20,9 @@ ModelOpCopy::ModelOpCopy(ModelTensorRef input, ModelTensorRef output) verify(); } -Tensor Model::copy(Tensor input, Tensor output, const std::string &name) { - return impl_->create_op(name, input.ref_, output.ref_) +Tensor Model::copy(Tensor input, Tensor output, const std::string &config, + const std::string &name) { + return impl_->create_op(config, name, input.ref_, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_embedding.cpp b/ark/ops/ops_embedding.cpp index 542c0fcac..466b9a4e5 100644 --- a/ark/ops/ops_embedding.cpp +++ b/ark/ops/ops_embedding.cpp @@ -70,9 +70,9 @@ Json ModelOpEmbedding::default_config([ } Tensor Model::embedding(Tensor input, Tensor weight, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, weight.ref_, + ->create_op(config, name, input.ref_, weight.ref_, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_identity.cpp b/ark/ops/ops_identity.cpp index 065cd9a52..dd398d8a5 100644 --- a/ark/ops/ops_identity.cpp +++ b/ark/ops/ops_identity.cpp @@ -31,7 +31,7 @@ Tensor Model::identity(Tensor input, const std::vector &deps, for (auto &dep : deps) { deps_ref.emplace_back(dep.ref_); } - return impl_->create_op(name, input.ref_, deps_ref) + return impl_->create_op("", name, input.ref_, deps_ref) ->result_tensors()[0]; } diff --git a/ark/ops/ops_math.cpp b/ark/ops/ops_math.cpp index 1067c561a..b2833dcca 100644 --- a/ark/ops/ops_math.cpp +++ b/ark/ops/ops_math.cpp @@ -24,48 +24,55 @@ ModelOpMath::ModelOpMath(const std::string &type_name, ModelTensorRef input, ModelOpExp::ModelOpExp(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Exp", input, output) {} -Tensor Model::exp(Tensor input, Tensor output, const std::string &name) { - return impl_->create_op(name, input.ref_, output.ref_) +Tensor Model::exp(Tensor input, Tensor output, const std::string &config, + const std::string &name) { + return impl_->create_op(config, name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpGelu::ModelOpGelu(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Gelu", input, output) {} -Tensor Model::gelu(Tensor input, Tensor output, const std::string &name) { - return impl_->create_op(name, input.ref_, output.ref_) +Tensor Model::gelu(Tensor input, Tensor output, const std::string &config, + const std::string &name) { + return impl_->create_op(config, name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpRelu::ModelOpRelu(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Relu", input, output) {} -Tensor Model::relu(Tensor input, Tensor output, const std::string &name) { - return impl_->create_op(name, input.ref_, output.ref_) +Tensor Model::relu(Tensor input, Tensor output, const std::string &config, + const std::string &name) { + return impl_->create_op(config, name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpRsqrt::ModelOpRsqrt(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Rsqrt", input, output) {} -Tensor Model::rsqrt(Tensor input, Tensor output, const std::string &name) { - return impl_->create_op(name, input.ref_, output.ref_) +Tensor Model::rsqrt(Tensor input, Tensor output, const std::string &config, + const std::string &name) { + return impl_->create_op(config, name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpSigmoid::ModelOpSigmoid(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Sigmoid", input, output) {} -Tensor Model::sigmoid(Tensor input, Tensor output, const std::string &name) { - return impl_->create_op(name, input.ref_, output.ref_) +Tensor Model::sigmoid(Tensor input, Tensor output, const std::string &config, + const std::string &name) { + return impl_ + ->create_op(config, name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpSqrt::ModelOpSqrt(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Sqrt", input, output) {} -Tensor Model::sqrt(Tensor input, Tensor output, const std::string &name) { - return impl_->create_op(name, input.ref_, output.ref_) +Tensor Model::sqrt(Tensor input, Tensor output, const std::string &config, + const std::string &name) { + return impl_->create_op(config, name, input.ref_, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_matmul.cpp b/ark/ops/ops_matmul.cpp index a24b95d72..1976699a1 100644 --- a/ark/ops/ops_matmul.cpp +++ b/ark/ops/ops_matmul.cpp @@ -255,10 +255,10 @@ Json ModelOpMatmul::default_config(const ArchRef arch) const { Tensor Model::matmul(Tensor input, Tensor other, Tensor output, bool trans_input, bool trans_other, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref(), other.ref(), output.ref(), - trans_input, trans_other) + ->create_op(config, name, input.ref(), other.ref(), + output.ref(), trans_input, trans_other) ->result_tensors()[0]; } diff --git a/ark/ops/ops_noop.cpp b/ark/ops/ops_noop.cpp index 894ab29be..42fe5fdf5 100644 --- a/ark/ops/ops_noop.cpp +++ b/ark/ops/ops_noop.cpp @@ -30,7 +30,7 @@ Json ModelOpNoop::default_config([[maybe_unused]] const ArchRef arch) const { } void Model::noop(Tensor input, const std::string &name) { - impl_->create_op(name, input.ref_); + impl_->create_op("", name, input.ref_); } } // namespace ark diff --git a/ark/ops/ops_reduce.cpp b/ark/ops/ops_reduce.cpp index 1c91a2f0b..dadd049d2 100644 --- a/ark/ops/ops_reduce.cpp +++ b/ark/ops/ops_reduce.cpp @@ -128,25 +128,25 @@ Json ModelOpReduce::default_config([[maybe_unused]] const ArchRef arch) const { } Tensor Model::reduce_max(Tensor input, int axis, bool keepdims, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, axis, keepdims, + ->create_op(config, name, input.ref_, axis, keepdims, output.ref_) ->result_tensors()[0]; } Tensor Model::reduce_mean(Tensor input, int axis, bool keepdims, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, axis, keepdims, + ->create_op(config, name, input.ref_, axis, keepdims, output.ref_) ->result_tensors()[0]; } Tensor Model::reduce_sum(Tensor input, int axis, bool keepdims, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, axis, keepdims, + ->create_op(config, name, input.ref_, axis, keepdims, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_refer.cpp b/ark/ops/ops_refer.cpp index 782d6708c..68c61b30f 100644 --- a/ark/ops/ops_refer.cpp +++ b/ark/ops/ops_refer.cpp @@ -20,7 +20,7 @@ Tensor Model::refer(Tensor input, const Dims &shape, const Dims &strides, const Dims &offsets, const Dims &padded_shape, const std::string &name) { return impl_ - ->create_op(name, input.ref_, shape, strides, offsets, + ->create_op("", name, input.ref_, shape, strides, offsets, padded_shape) ->result_tensors()[0]; } diff --git a/ark/ops/ops_reshape.cpp b/ark/ops/ops_reshape.cpp index c4e192908..6ecbba466 100644 --- a/ark/ops/ops_reshape.cpp +++ b/ark/ops/ops_reshape.cpp @@ -199,8 +199,8 @@ Tensor Model::reshape(Tensor input, const Dims &shape, bool allowzero, reshape_helper(input.ref_, Dims{inferred_shape}, allowzero, new_shape, new_strides, new_offs); return impl_ - ->create_op(name, input.ref_, new_shape, new_strides, - new_offs) + ->create_op("", name, input.ref_, new_shape, + new_strides, new_offs) ->result_tensors()[0]; } diff --git a/ark/ops/ops_rope.cpp b/ark/ops/ops_rope.cpp index 06c1c915e..36015aae5 100644 --- a/ark/ops/ops_rope.cpp +++ b/ark/ops/ops_rope.cpp @@ -12,9 +12,10 @@ ModelOpRope::ModelOpRope(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Rope", input, other, output) {} Tensor Model::rope(Tensor input, Tensor other, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, other.ref_, output.ref_) + ->create_op(config, name, input.ref_, other.ref_, + output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_scalar.cpp b/ark/ops/ops_scalar.cpp index 944a7247c..b5c10f1c3 100644 --- a/ark/ops/ops_scalar.cpp +++ b/ark/ops/ops_scalar.cpp @@ -115,20 +115,21 @@ std::vector ModelOpScalarMul::impl_args([ Tensor Model::constant(float val, const Dims &shape, DataType data_type, const std::string &name) { return impl_ - ->create_op(name, val, shape, data_type.ref(), + ->create_op("", name, val, shape, data_type.ref(), nullptr) ->result_tensors()[0]; } -Tensor Model::copy(float val, Tensor output, const std::string &name) { +Tensor Model::copy(float val, Tensor output, const std::string &config, + const std::string &name) { if (output == NullTensor) { return impl_ - ->create_op(name, val, Dims{1}, FP32.ref(), - output.ref()) + ->create_op(config, name, val, Dims{1}, + FP32.ref(), output.ref()) ->result_tensors()[0]; } else { return impl_ - ->create_op(name, val, output.shape(), + ->create_op(config, name, val, output.shape(), output.data_type().ref(), output.ref()) ->result_tensors()[0]; @@ -136,30 +137,34 @@ Tensor Model::copy(float val, Tensor output, const std::string &name) { } Tensor Model::add(Tensor input, float value, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, value, output.ref_) + ->create_op(config, name, input.ref_, value, + output.ref_) ->result_tensors()[0]; } Tensor Model::sub(Tensor input, float value, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, -value, output.ref_) + ->create_op(config, name, input.ref_, -value, + output.ref_) ->result_tensors()[0]; } Tensor Model::mul(Tensor input, float value, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, value, output.ref_) + ->create_op(config, name, input.ref_, value, + output.ref_) ->result_tensors()[0]; } Tensor Model::div(Tensor input, float value, Tensor output, - const std::string &name) { + const std::string &config, const std::string &name) { return impl_ - ->create_op(name, input.ref_, 1 / value, output.ref_) + ->create_op(config, name, input.ref_, 1 / value, + output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_tensor.cpp b/ark/ops/ops_tensor.cpp index 0279ab311..77091fa57 100644 --- a/ark/ops/ops_tensor.cpp +++ b/ark/ops/ops_tensor.cpp @@ -27,7 +27,7 @@ Tensor Model::tensor(const Dims &shape, const DataType &data_type, const Dims &strides, const Dims &offsets, const Dims &padded_shape, const std::string &name) { return impl_ - ->create_op(name, nullptr, shape, data_type.ref(), + ->create_op("", name, nullptr, shape, data_type.ref(), strides, offsets, padded_shape) ->result_tensors()[0]; } diff --git a/ark/ops/ops_transpose.cpp b/ark/ops/ops_transpose.cpp index 3f0ed0131..f099c7fb7 100644 --- a/ark/ops/ops_transpose.cpp +++ b/ark/ops/ops_transpose.cpp @@ -124,9 +124,10 @@ Json ModelOpTranspose::default_config([ } Tensor Model::transpose(Tensor input, const std::vector &permutation, - Tensor output, const std::string &name) { + Tensor output, const std::string &config, + const std::string &name) { return impl_ - ->create_op(name, input.ref_, permutation, + ->create_op(config, name, input.ref_, permutation, output.ref_) ->result_tensors()[0]; } diff --git a/arkprof.py b/arkprof.py index 9e67c2dfc..5fb62e118 100644 --- a/arkprof.py +++ b/arkprof.py @@ -1,6 +1,7 @@ import ark import sys +ark.init() ark.Profiler(ark.Plan.from_file(sys.argv[1])).run( iter=1000, profile_processor_groups=False ) diff --git a/examples/tutorial/context_tutorial.py b/examples/tutorial/context_tutorial.py new file mode 100644 index 000000000..fb01f0a0c --- /dev/null +++ b/examples/tutorial/context_tutorial.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import ark +import time +import torch +import torch.nn.functional as F + + +class VanillaSoftmax(ark.Module): + def __init__(self): + super(Softmax, self).__init__() + + def forward(self, input): + max = ark.reduce_max(input, axis=-1) + output = ark.sub(input, max) + output = ark.exp(output) + sum = ark.reduce_sum(output, axis=-1) + output = ark.div(output, sum) + return output + + +class Softmax(ark.Module): + def __init__(self): + super(Softmax, self).__init__() + + def forward(self, input): + with ark.ContextManager( + processor_range=[0, 304], + warp_range=[0, 8], + sram_range=[0, 0], + task_id=0, + ): + max = ark.reduce_max( + input, + axis=-1, + config={ + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 0, + "NumTasks": 65536, + }, + ) + output = ark.sub( + input, + max, + config={ + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1, 2048], + "NumTasks": 65536, + }, + ) + output = ark.exp( + output, + config={ + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1, 2048], + "NumTasks": 65536, + }, + ) + sum = ark.reduce_sum( + output, + axis=-1, + config={ + "NumWarps": 1, + "ImplType": "WarpWise", + "SramBytes": 0, + "NumTasks": 65536, + }, + ) + output = ark.div( + output, + sum, + config={ + "NumWarps": 1, + "SramBytes": 0, + "Tile": [1, 2048], + "NumTasks": 65536, + }, + ) + return output + + +def eval(tensor: ark.Tensor): + with ark.Runtime() as rt: + rt.launch() + rt.run() + return tensor.to_torch() + + +def perf(): + with ark.Runtime() as rt: + rt.launch() + + start = time.time() + rt.run(iter=1000) + end = time.time() + return (end - start) / 1000 + + +if __name__ == "__main__": + ark.init() + + shape = (32, 2048, 2048) + + input = torch.randn(*shape).to("cuda:0") + + output = Softmax()(ark.Tensor.from_torch(input)) + + if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5): + print("Correct result") + else: + print("Incorrect result") + + print(f"Performance: {(perf() * 1e3):.3f} ms/iter") diff --git a/python/ark/__init__.py b/python/ark/__init__.py index e96972906..00370e683 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import sys import os if os.environ.get("ARK_ROOT", None) is None: @@ -102,3 +101,4 @@ def set_world_size(world_size): ) from .planner import DefaultPlanner, Plan from .profiler import Profiler +from .context_manager import ContextManager diff --git a/python/ark/context_manager.py b/python/ark/context_manager.py new file mode 100644 index 000000000..443e1ca5d --- /dev/null +++ b/python/ark/context_manager.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +from .model import Model +from ._ark_core import _ContextManager + + +class ContextManager(_ContextManager): + def __init__(self, **kwargs): + context_map = {key: json.dumps(value) for key, value in kwargs.items()} + super().__init__(Model.get_model(), context_map) + + def __enter__(self) -> "ContextManager": + """ + Enter the context manager. + """ + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + """ + Exit the context manager. + """ + del self diff --git a/python/ark/ops.py b/python/ark/ops.py index 86b021aef..509e3c891 100644 --- a/python/ark/ops.py +++ b/python/ark/ops.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import List, Iterable, Union +import json +from typing import Any, Dict, List, Iterable, Union from .tensor import Dims, Tensor, Parameter, NullTensor from .data_type import DataType, fp32 @@ -12,6 +13,12 @@ def _is_list_or_tuple(obj): return isinstance(obj, list) or isinstance(obj, tuple) +def _config_to_str(config: Union[str, Dict[str, Any]]) -> str: + if isinstance(config, str): + return config + return json.dumps(config) + + def _tensor( shape: Iterable[int], dtype: DataType = fp32, @@ -50,6 +57,7 @@ def add( input: Union[Tensor, float], other: Union[Tensor, float], output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "add", ) -> Union[Tensor, float]: """ @@ -73,12 +81,15 @@ def add( return input + other else: return Tensor( - Model.get_model().copy(input + other, output._tensor, name) + Model.get_model().copy( + input + other, output._tensor, _config_to_str(config), name + ) ) if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().add(a, b, output, name), runtime_id=input.runtime_id + Model.get_model().add(a, b, output, _config_to_str(config), name), + runtime_id=input.runtime_id, ) @@ -86,13 +97,16 @@ def cast( input: Tensor, dtype: DataType, output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "cast", ) -> Tensor: """Type casting.""" if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().cast(input._tensor, dtype.ctype(), output, name), + Model.get_model().cast( + input._tensor, dtype.ctype(), output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -112,7 +126,10 @@ def constant( def copy( - input: Union[Tensor, float], output: Tensor = NullTensor, name: str = "copy" + input: Union[Tensor, float], + output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", + name: str = "copy", ) -> Tensor: """Data caopy.""" if output is not NullTensor: @@ -120,7 +137,7 @@ def copy( if isinstance(input, Tensor): intput = intput._tensor return Tensor( - Model.get_model().copy(intput, output, name), + Model.get_model().copy(intput, output, _config_to_str(config), name), runtime_id=input.runtime_id, ) @@ -129,6 +146,7 @@ def div( input: Tensor, other: Union[Tensor, float], output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "div", ) -> Tensor: """ @@ -144,7 +162,9 @@ def div( raise ValueError("Tensors must be on the same runtime") other = other._tensor return Tensor( - Model.get_model().div(input._tensor, other, output, name), + Model.get_model().div( + input._tensor, other, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -153,6 +173,7 @@ def embedding( input: Tensor, weight: Tensor, output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "embedding", ) -> Tensor: """Embedding layer.""" @@ -162,14 +183,17 @@ def embedding( output = output._tensor return Tensor( Model.get_model().embedding( - input._tensor, weight._tensor, output, name + input._tensor, weight._tensor, output, _config_to_str(config), name ), runtime_id=input.runtime_id, ) def exp( - input: Tensor, output: Tensor = NullTensor, name: str = "exp" + input: Tensor, + output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", + name: str = "exp", ) -> Tensor: """ Calculates the exponential of the `input` tensor, element-wise. @@ -179,13 +203,18 @@ def exp( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().exp(input._tensor, output, name), + Model.get_model().exp( + input._tensor, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) def gelu( - input: Tensor, output: Tensor = NullTensor, name: str = "gelu" + input: Tensor, + output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", + name: str = "gelu", ) -> Tensor: """ Applies the Gaussian Error Linear Unit (GELU) activation @@ -198,7 +227,9 @@ def gelu( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().gelu(input._tensor, output, name), + Model.get_model().gelu( + input._tensor, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -230,6 +261,7 @@ def matmul( output: Tensor = NullTensor, transpose_input: bool = False, transpose_other: bool = False, + config: Union[str, Dict[str, Any]] = "", name: str = "matmul", ) -> Tensor: """ @@ -252,6 +284,7 @@ def matmul( output, transpose_input, transpose_other, + _config_to_str(config), name, ), runtime_id=input.runtime_id, @@ -262,6 +295,7 @@ def mul( input: Tensor, other: Union[Tensor, float], output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "mul", ) -> Tensor: """ @@ -277,7 +311,9 @@ def mul( raise ValueError("Tensors must be on the same runtime") other = other._tensor return Tensor( - Model.get_model().mul(input._tensor, other, output, name), + Model.get_model().mul( + input._tensor, other, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -294,6 +330,7 @@ def reduce_max( axis: int, keepdims: bool = True, output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "reduce_max", ) -> Tensor: """ @@ -306,7 +343,7 @@ def reduce_max( output = output._tensor return Tensor( Model.get_model().reduce_max( - input._tensor, axis, keepdims, output, name + input._tensor, axis, keepdims, output, _config_to_str(config), name ), runtime_id=input.runtime_id, ) @@ -317,6 +354,7 @@ def reduce_mean( axis: int, keepdims: bool = True, output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "reduce_mean", ) -> Tensor: """ @@ -329,7 +367,7 @@ def reduce_mean( output = output._tensor return Tensor( Model.get_model().reduce_mean( - input._tensor, axis, keepdims, output, name + input._tensor, axis, keepdims, output, _config_to_str(config), name ), runtime_id=input.runtime_id, ) @@ -340,6 +378,7 @@ def reduce_sum( axis: int, keepdims: bool = True, output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "reduce_sum", ) -> Tensor: """ @@ -354,14 +393,17 @@ def reduce_sum( output = output._tensor return Tensor( Model.get_model().reduce_sum( - input._tensor, axis, keepdims, output, name + input._tensor, axis, keepdims, output, _config_to_str(config), name ), runtime_id=input.runtime_id, ) def relu( - input: Tensor, output: Tensor = NullTensor, name: str = "relu" + input: Tensor, + output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", + name: str = "relu", ) -> Tensor: """ Applies the ReLU activation function to the `input` tensor, @@ -372,7 +414,9 @@ def relu( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().relu(input._tensor, output, name), + Model.get_model().relu( + input._tensor, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -411,6 +455,7 @@ def rope( input: Tensor, other: Tensor, output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "rope", ) -> Tensor: """ @@ -423,13 +468,18 @@ def rope( if input.runtime_id != other.runtime_id: raise ValueError("Tensors must be on the same runtime") return Tensor( - Model.get_model().rope(input._tensor, other._tensor, output, name), + Model.get_model().rope( + input._tensor, other._tensor, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) def rsqrt( - input: Tensor, output: Tensor = NullTensor, name: str = "rsqrt" + input: Tensor, + output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", + name: str = "rsqrt", ) -> Tensor: """ Calculates the square root of the `input` tensor, element-wise. @@ -439,7 +489,9 @@ def rsqrt( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().rsqrt(input._tensor, output, name), + Model.get_model().rsqrt( + input._tensor, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -465,7 +517,10 @@ def sharding( def sigmoid( - input: Tensor, output: Tensor = NullTensor, name: str = "sigmoid" + input: Tensor, + output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", + name: str = "sigmoid", ) -> Tensor: """ Applies the Sigmoid activation function to the `input` tensor, @@ -476,13 +531,18 @@ def sigmoid( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().sigmoid(input._tensor, output, name), + Model.get_model().sigmoid( + input._tensor, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) def sqrt( - input: Tensor, output: Tensor = NullTensor, name: str = "sqrt" + input: Tensor, + output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", + name: str = "sqrt", ) -> Tensor: """ Calculates the square root of the `input` tensor, element-wise. @@ -492,7 +552,9 @@ def sqrt( if output is not NullTensor: output = output._tensor return Tensor( - Model.get_model().sqrt(input._tensor, output, name), + Model.get_model().sqrt( + input._tensor, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -501,6 +563,7 @@ def sub( input: Tensor, other: Union[Tensor, float], output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "sub", ) -> Tensor: """ @@ -516,7 +579,9 @@ def sub( raise ValueError("Tensors must be on the same runtime") other = other._tensor return Tensor( - Model.get_model().sub(input._tensor, other, output, name), + Model.get_model().sub( + input._tensor, other, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -546,6 +611,7 @@ def transpose( input: Tensor, perm: Iterable[int], output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "transpose", ) -> Tensor: """ @@ -565,7 +631,9 @@ def transpose( if len(perm) > 4: raise ValueError("Only support perm up to 4 dimensions") return Tensor( - Model.get_model().transpose(input._tensor, perm, output, name), + Model.get_model().transpose( + input._tensor, perm, output, _config_to_str(config), name + ), runtime_id=input.runtime_id, ) @@ -578,10 +646,11 @@ def mean( axis: int, keepdims: bool = True, output: Tensor = NullTensor, + config: Union[str, Dict[str, Any]] = "", name: str = "mean", ) -> Tensor: """Alias of reduce_mean.""" - return reduce_mean(input, axis, keepdims, output, name) + return reduce_mean(input, axis, keepdims, output, config, name) def ones( diff --git a/python/ark/profiler.py b/python/ark/profiler.py index c161b24e6..e47f5b7aa 100644 --- a/python/ark/profiler.py +++ b/python/ark/profiler.py @@ -21,8 +21,15 @@ class Profiler: def __init__(self, plan: Plan): self.plan = plan - def run(self, iter: int = 1000, loop_mode: bool = True, profile_processor_groups: bool = False): - sys.stderr.write(f"End-to-end: {timeit(self.plan, iter, loop_mode):.6f} seconds/iter\n") + def run( + self, + iter: int = 1000, + loop_mode: bool = True, + profile_processor_groups: bool = False, + ): + sys.stderr.write( + f"End-to-end: {timeit(self.plan, iter, loop_mode):.6f} seconds/iter\n" + ) if not profile_processor_groups: return diff --git a/python/ark_py.cpp b/python/ark_py.cpp index 1bc4255d6..7acd4ad1a 100644 --- a/python/ark_py.cpp +++ b/python/ark_py.cpp @@ -7,6 +7,7 @@ namespace py = pybind11; +extern void register_context_manager(py::module &m); extern void register_data_type(py::module &m); extern void register_dims(py::module &m); extern void register_error(py::module &m); @@ -22,6 +23,7 @@ extern void register_version(py::module &m); PYBIND11_MODULE(_ark_core, m) { m.doc() = "Bind ARK C++ APIs to Python"; + register_context_manager(m); register_data_type(m); register_dims(m); register_error(m); diff --git a/python/context_manager_py.cpp b/python/context_manager_py.cpp new file mode 100644 index 000000000..3d703a4bc --- /dev/null +++ b/python/context_manager_py.cpp @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include +#include + +#include + +namespace py = pybind11; + +void register_context_manager(py::module &m) { + py::class_(m, "_ContextManager") + .def(py::init&>()); +} diff --git a/python/model_py.cpp b/python/model_py.cpp index 2d1e5f634..ba17251d8 100644 --- a/python/model_py.cpp +++ b/python/model_py.cpp @@ -15,97 +15,109 @@ void register_model(py::module &m) { .def(py::init(), py::arg("rank"), py::arg("world_size")) .def("rank", &ark::Model::rank) .def("world_size", &ark::Model::world_size) - .def("compress", &ark::Model::compress) + .def("compress", &ark::Model::compress, py::arg("merge_nodes") = false) .def("add", py::overload_cast(&ark::Model::add), + const std::string &, const std::string &>( + &ark::Model::add), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("add", py::overload_cast(&ark::Model::add), + const std::string &, const std::string &>( + &ark::Model::add), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("cast", &ark::Model::cast, py::arg("input"), py::arg("data_type"), - py::arg("output"), py::arg("name")) + py::arg("output"), py::arg("config"), py::arg("name")) .def("constant", &ark::Model::constant, py::arg("value"), py::arg("shape"), py::arg("data_type"), py::arg("name")) .def("copy", - py::overload_cast( - &ark::Model::copy), - py::arg("input"), py::arg("output"), py::arg("name")) + py::overload_cast(&ark::Model::copy), + py::arg("input"), py::arg("output"), py::arg("config"), + py::arg("name")) .def("copy", - py::overload_cast( - &ark::Model::copy), - py::arg("input"), py::arg("output"), py::arg("name")) + py::overload_cast(&ark::Model::copy), + py::arg("input"), py::arg("output"), py::arg("config"), + py::arg("name")) .def("div", py::overload_cast(&ark::Model::div), + const std::string &, const std::string &>( + &ark::Model::div), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("div", py::overload_cast(&ark::Model::div), + const std::string &, const std::string &>( + &ark::Model::div), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("embedding", &ark::Model::embedding, py::arg("input"), - py::arg("weight"), py::arg("output"), py::arg("name")) - .def("exp", &ark::Model::exp, py::arg("input"), py::arg("output"), + py::arg("weight"), py::arg("output"), py::arg("config"), py::arg("name")) + .def("exp", &ark::Model::exp, py::arg("input"), py::arg("output"), + py::arg("config"), py::arg("name")) .def("gelu", &ark::Model::gelu, py::arg("input"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("identity", &ark::Model::identity, py::arg("input"), py::arg("deps"), py::arg("name")) .def("matmul", &ark::Model::matmul, py::arg("input"), py::arg("other"), py::arg("output"), py::arg("trans_input"), py::arg("trans_other"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("mul", py::overload_cast(&ark::Model::mul), + const std::string &, const std::string &>( + &ark::Model::mul), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("mul", py::overload_cast(&ark::Model::mul), + const std::string &, const std::string &>( + &ark::Model::mul), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("noop", &ark::Model::noop, py::arg("input"), py::arg("name")) .def("reduce_max", &ark::Model::reduce_max, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("reduce_mean", &ark::Model::reduce_mean, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("reduce_sum", &ark::Model::reduce_sum, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("relu", &ark::Model::relu, py::arg("input"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("reshape", &ark::Model::reshape, py::arg("input"), py::arg("shape"), py::arg("allowzero"), py::arg("name")) .def("rope", &ark::Model::rope, py::arg("input"), py::arg("other"), - py::arg("output"), py::arg("name")) + py::arg("output"), py::arg("config"), py::arg("name")) .def("rsqrt", &ark::Model::rsqrt, py::arg("input"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("sharding", &ark::Model::sharding, py::arg("input"), py::arg("axis"), py::arg("dim_per_shard"), py::arg("name")) .def("sigmoid", &ark::Model::sigmoid, py::arg("input"), - py::arg("output"), py::arg("name")) + py::arg("output"), py::arg("config"), py::arg("name")) .def("sqrt", &ark::Model::sqrt, py::arg("input"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("sub", py::overload_cast(&ark::Model::sub), + const std::string &, const std::string &>( + &ark::Model::sub), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("sub", py::overload_cast(&ark::Model::sub), + const std::string &, const std::string &>( + &ark::Model::sub), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("name")) + py::arg("config"), py::arg("name")) .def("tensor", &ark::Model::tensor, py::arg("shape"), py::arg("data_type"), py::arg("strides"), py::arg("offsets"), py::arg("padded_shape"), py::arg("name")) .def("transpose", &ark::Model::transpose, py::arg("input"), - py::arg("permutation"), py::arg("output"), py::arg("name")); + py::arg("permutation"), py::arg("output"), py::arg("config"), + py::arg("name")); } From ef3bb84e8ebb3bb86e256767802401e39d617a85 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 29 Jul 2024 20:31:14 +0000 Subject: [PATCH 41/54] plan manager --- ark/api/context_manager_test.cpp | 1 - ark/api/model.cpp | 9 ++ ark/api/plan_manager.cpp | 97 ++++++++++++++++ ark/api/plan_manager_test.cpp | 58 ++++++++++ ark/api/planner.cpp | 125 +++++++++++++++------ ark/include/ark/model.hpp | 9 +- ark/include/ark/model_graph.hpp | 1 + ark/include/ark/plan_manager.hpp | 25 +++++ ark/model/model_graph_impl.cpp | 16 ++- ark/model/model_graph_impl.hpp | 6 +- examples/tutorial/context_tutorial.py | 117 ------------------- examples/tutorial/plan_manager_tutorial.py | 82 ++++++++++++++ python/ark/__init__.py | 2 +- python/ark/context_manager.py | 24 ---- python/ark/plan_manager.py | 34 ++++++ python/ark_py.cpp | 4 +- python/context_manager_py.cpp | 15 --- python/plan_manager_py.cpp | 15 +++ 18 files changed, 440 insertions(+), 200 deletions(-) create mode 100644 ark/api/plan_manager.cpp create mode 100644 ark/api/plan_manager_test.cpp create mode 100644 ark/include/ark/plan_manager.hpp delete mode 100644 examples/tutorial/context_tutorial.py create mode 100644 examples/tutorial/plan_manager_tutorial.py delete mode 100644 python/ark/context_manager.py create mode 100644 python/ark/plan_manager.py delete mode 100644 python/context_manager_py.cpp create mode 100644 python/plan_manager_py.cpp diff --git a/ark/api/context_manager_test.cpp b/ark/api/context_manager_test.cpp index ff60b43bf..5fff94f34 100644 --- a/ark/api/context_manager_test.cpp +++ b/ark/api/context_manager_test.cpp @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -#include "ark/model.hpp" #include "ark/context_manager.hpp" #include "model/model_node.hpp" diff --git a/ark/api/model.cpp b/ark/api/model.cpp index a5a258f71..e9604c341 100644 --- a/ark/api/model.cpp +++ b/ark/api/model.cpp @@ -9,6 +9,15 @@ namespace ark { +Model::Model(int rank, int world_size) : ModelGraph(rank, world_size) { + static size_t next_id = 0; + id_ = next_id++; +} + +Model::Model(const Model &other) : ModelGraph(other), id_(other.id()) {} + +size_t Model::id() const { return id_; } + Model Model::compress(bool merge_nodes) const { Model model(*this); model.compress_nodes(merge_nodes); diff --git a/ark/api/plan_manager.cpp b/ark/api/plan_manager.cpp new file mode 100644 index 000000000..aee8d4f7b --- /dev/null +++ b/ark/api/plan_manager.cpp @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ark/plan_manager.hpp" + +#include "logging.h" +#include "model/model_json.hpp" +#include "model/model_graph_impl.hpp" + +namespace ark { + +class PlanManagerState { + public: + PlanManagerState() : sync(true) {} + bool sync; +}; + +static std::map gPlanManagerStates; + +PlanManager::PlanManager(Model& model, const std::string& plan_context) : model_id_(model.id()), stop_sync_(false) { + auto ctx = Json::parse(plan_context); + if (!ctx.is_object()) { + ERR(ModelError, "plan context must be a JSON object"); + } + if (gPlanManagerStates.find(model_id_) == gPlanManagerStates.end()) { + gPlanManagerStates.emplace(model_id_, PlanManagerState()); + } + auto& state = gPlanManagerStates[model_id_]; + bool async = !state.sync; + std::map context_map; + for (const auto& [key, value] : ctx.items()) { + if (key == "sync") { + if (!value.is_boolean()) { + ERR(ModelError, "sync must be a boolean"); + } + if (state.sync && !value.get()) { + stop_sync_ = true; + state.sync = false; + context_map["AppendTask"] = "true"; + } else if (!state.sync) { + context_map["AppendTask"] = "true"; + } + } else if (key == "processor_range") { + if (!value.is_array()) { + ERR(ModelError, "processor_range must be an array"); + } + if (async) { + LOG(WARN, "Ignoring processor_range under sync=false context"); + continue; + } + context_map["ProcessorRange"] = value.dump(); + } else if (key == "warp_range") { + if (!value.is_array()) { + ERR(ModelError, "warp_range must be an array"); + } + if (async) { + LOG(WARN, "Ignoring warp_range under sync=false context"); + continue; + } + context_map["WarpRange"] = value.dump(); + } else if (key == "sram_range") { + if (!value.is_array()) { + ERR(ModelError, "sram_range must be an array"); + } + if (async) { + LOG(WARN, "Ignoring sram_range under sync=false context"); + continue; + } + context_map["SramRange"] = value.dump(); + } else if (key == "config") { + if (!value.is_object()) { + ERR(ModelError, "config must be an object"); + } + auto cfg = model.impl_->get_context("Config"); + if (cfg.empty()) { + context_map["Config"] = value.dump(); + } else { + auto cfg_obj = Json::parse(cfg); + for (const auto& [k, v] : value.items()) { + cfg_obj[k] = v; + } + context_map["Config"] = cfg_obj.dump(); + } + } else { + LOG(WARN, "Ignoring unknown plan context key: ", key); + } + } + context_manager_ = std::make_shared(model, context_map); +} + +PlanManager::~PlanManager() { + if (stop_sync_) { + gPlanManagerStates[model_id_].sync = true; + } +} + +} // namespace ark diff --git a/ark/api/plan_manager_test.cpp b/ark/api/plan_manager_test.cpp new file mode 100644 index 000000000..78f5d4cb8 --- /dev/null +++ b/ark/api/plan_manager_test.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ark/plan_manager.hpp" +#include "ark/planner.hpp" + +#include "model/model_json.hpp" +#include "unittest/unittest_utils.h" + +ark::unittest::State test_plan_manager() { + ark::Model model; + ark::Tensor t0 = model.tensor({1}, ark::FP32); + ark::Tensor t1 = model.tensor({1}, ark::FP32); + ark::Tensor t2 = model.add(t0, t1); + + ark::Tensor t3; + ark::Tensor t4; + ark::Tensor t5; + ark::Tensor t6; + { + ark::PlanManager pm_0(model, ark::Json({ + {"processor_range", {0, 2}}, + {"warp_range", {0, 4}}, + {"sram_range", {0, 0}}, + {"sync", false} + }).dump()); + t3 = model.relu(t2); + t4 = model.sqrt(t3); + } + { + ark::PlanManager pm_0(model, ark::Json({ + {"processor_range", {2, 4}}, + {"warp_range", {0, 4}}, + {"sram_range", {0, 0}} + }).dump()); + t5 = model.exp(t2); + + ark::PlanManager pm_1(model, ark::Json({ + {"processor_range", {2, 3}} + }).dump()); + t6 = model.rsqrt(t5); + } + + UNITTEST_TRUE(model.verify()); + + ark::DefaultPlanner planner(model, 0); + auto plan_str = planner.plan(); + ark::Json plan = ark::Json::parse(plan_str); + + UNITTEST_LOG(plan_str); + + return ark::unittest::SUCCESS; +} + +int main() { + UNITTEST(test_plan_manager); + return 0; +} diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index dba149a1e..1c40e5301 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -58,19 +58,35 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { size_t num_sm = gpu_info.num_sm; Json task_infos = Json::array(); Json processor_groups = Json::array(); - size_t max_num_warps = 1; - size_t max_num_processors = 1; - size_t next_node_id = 0; + size_t max_processor_id = 1; + size_t max_warp_id = 1; + size_t next_task_id = 0; + bool prev_append_task = false; + bool first_op = true; + + auto get_context = [&](const ModelNodeRef &node, + const std::string &key) -> Json { + if (node->context.find(key) != node->context.end()) { + return Json::parse(node->context.at(key)); + } + return Json(); + }; + for (const auto &node : model_.nodes()) { + std::string context = ""; + for (const auto &[key, value] : node->context) { + context += key + "=" + value + ","; + } + context += "prev_append_task=" + std::to_string(prev_append_task); + LOG(INFO, context); + for (const auto &op : node->ops) { if (op->is_virtual()) continue; - Json task_info; - task_info["Id"] = next_node_id++; - + auto ctx_config = get_context(node, "Config"); Json config; - if (!op->config().empty()) { - config = op->config(); + if (!ctx_config.empty()) { + config = ctx_config; } else if (!config_rules_.empty()) { const std::string op_str = op->serialize().dump(); for (auto &rule : config_rules_) { @@ -90,31 +106,70 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { size_t num_warps = config["NumWarps"]; size_t num_tasks = config["NumTasks"]; size_t sram_bytes = config["SramBytes"]; - task_info["NumWarps"] = num_warps; - task_info["SramBytes"] = sram_bytes; - - max_num_warps = std::max(max_num_warps, num_warps); - - task_info["Ops"] = Json::array(); - task_info["Ops"].push_back(op->serialize()); - task_info["Ops"][0]["Config"] = config; - task_infos.push_back(task_info); - - Json resource_group; - size_t num_processors = std::min(num_sm, num_tasks); - max_num_processors = std::max(max_num_processors, num_processors); - resource_group["ProcessorRange"] = {0, num_processors}; - resource_group["WarpRange"] = {0, num_warps}; - resource_group["SramRange"] = {0, sram_bytes}; - resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]}, - {"TaskRange", {0, num_tasks}}, - {"Granularity", 1}}}; - - Json processor_group; - processor_group["ProcessorRange"] = {0, num_processors}; - processor_group["ResourceGroups"] = Json::array(); - processor_group["ResourceGroups"].push_back(resource_group); - processor_groups.push_back(processor_group); + + auto ctx_append_task = get_context(node, "AppendTask"); + if (!ctx_append_task.empty() && ctx_append_task.get() && + prev_append_task) { + auto &task_info = task_infos.back(); + task_info["NumWarps"] = + std::max(task_info["NumWarps"].get(), num_warps); + task_info["SramBytes"] = + std::max(task_info["SramBytes"].get(), sram_bytes); + task_info["Ops"].push_back(op->serialize()); + task_info["Ops"].back()["Config"] = config; + } else { + Json task_info; + task_info["Id"] = first_op ? next_task_id : ++next_task_id; + task_info["NumWarps"] = num_warps; + task_info["SramBytes"] = sram_bytes; + task_info["Ops"] = Json::array(); + task_info["Ops"].push_back(op->serialize()); + task_info["Ops"][0]["Config"] = config; + task_infos.push_back(task_info); + + auto ctx_processor_range = get_context(node, "ProcessorRange"); + auto ctx_warp_range = get_context(node, "WarpRange"); + auto ctx_sram_range = get_context(node, "SramRange"); + + Json processor_group; + if (!ctx_processor_range.empty()) { + processor_group["ProcessorRange"] = ctx_processor_range; + max_processor_id = std::max( + max_processor_id, ctx_processor_range[1].get()); + } else { + size_t num_processors = std::min(num_sm, num_tasks); + processor_group["ProcessorRange"] = {0, num_processors}; + max_processor_id = + std::max(max_processor_id, num_processors); + } + + Json resource_group; + resource_group["ProcessorRange"] = + processor_group["ProcessorRange"]; + if (!ctx_warp_range.empty()) { + resource_group["WarpRange"] = ctx_warp_range; + max_warp_id = + std::max(max_warp_id, ctx_warp_range[1].get()); + } else { + resource_group["WarpRange"] = {0, num_warps}; + max_warp_id = std::max(max_warp_id, num_warps); + } + if (!ctx_sram_range.empty()) { + resource_group["SramRange"] = ctx_sram_range; + } else { + resource_group["SramRange"] = {0, sram_bytes}; + } + resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]}, + {"TaskRange", {0, num_tasks}}, + {"Granularity", 1}}}; + + processor_group["ResourceGroups"] = Json::array(); + processor_group["ResourceGroups"].push_back(resource_group); + processor_groups.push_back(processor_group); + } + prev_append_task = + !ctx_append_task.empty() && ctx_append_task.get(); + first_op = false; } } @@ -122,8 +177,8 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { plan["Rank"] = model_.rank(); plan["WorldSize"] = model_.world_size(); plan["Architecture"] = gpu_info.arch->name(); - plan["NumProcessors"] = max_num_processors; - plan["NumWarpsPerProcessor"] = max_num_warps; + plan["NumProcessors"] = max_processor_id; + plan["NumWarpsPerProcessor"] = max_warp_id; plan["TaskInfos"] = task_infos; plan["ProcessorGroups"] = processor_groups; diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index 35efe53d5..e0b17be52 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -17,15 +17,20 @@ namespace ark { class Model : public ModelGraph { private: + size_t id_; std::set tags_; public: - Model(int rank = 0, int world_size = 1) : ModelGraph(rank, world_size) {} - Model(const Model &other) : ModelGraph(other) {} + Model(int rank = 0, int world_size = 1); + + Model(const Model &other); + ~Model() {} Model &operator=(const Model &other) = default; + size_t id() const; + Model compress(bool merge_nodes = false) const; int unique_tag(); diff --git a/ark/include/ark/model_graph.hpp b/ark/include/ark/model_graph.hpp index f6390a2a9..c53c98c3a 100644 --- a/ark/include/ark/model_graph.hpp +++ b/ark/include/ark/model_graph.hpp @@ -38,6 +38,7 @@ class ModelGraph { protected: friend class Model; + friend class PlanManager; friend class ContextManager; class Impl; diff --git a/ark/include/ark/plan_manager.hpp b/ark/include/ark/plan_manager.hpp new file mode 100644 index 000000000..3952a1c06 --- /dev/null +++ b/ark/include/ark/plan_manager.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_PLAN_MANAGER_HPP +#define ARK_PLAN_MANAGER_HPP + +#include + +namespace ark { + +class PlanManager { + public: + PlanManager(Model& model, const std::string& plan_context); + + ~PlanManager(); + + private: + size_t model_id_; + bool stop_sync_; + std::shared_ptr context_manager_; +}; + +} // namespace ark + +#endif // ARK_PLAN_MANAGER_HPP diff --git a/ark/model/model_graph_impl.cpp b/ark/model/model_graph_impl.cpp index 53a7fa851..385424e57 100644 --- a/ark/model/model_graph_impl.cpp +++ b/ark/model/model_graph_impl.cpp @@ -37,7 +37,15 @@ void ModelGraphContextStack::pop(const std::string &key) { it->second.pop_back(); } -std::map ModelGraphContextStack::current_context() const { +std::string ModelGraphContextStack::get_context(const std::string &key) const { + if (this->storage_.find(key) == this->storage_.end() || + this->storage_.at(key).empty()) { + return ""; + } + return *this->storage_.at(key).back(); +} + +std::map ModelGraphContextStack::get_context_all() const { std::map cur; for (const auto &pair : this->storage_) { if (!pair.second.empty()) { @@ -167,6 +175,10 @@ bool ModelGraph::Impl::verify() const { return true; } +std::string ModelGraph::Impl::get_context(const std::string &key) const { + return context_stack_->get_context(key); +} + ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) { for (auto &tns : op->input_tensors()) { if (tensor_to_producer_op_.find(tns) == tensor_to_producer_op_.end()) { @@ -205,7 +217,7 @@ ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) { producer->consumers.push_back(node); } - node->context = context_stack_->current_context(); + node->context = context_stack_->get_context_all(); nodes_.push_back(node); return node; diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp index fbfc54c7e..ec255423e 100644 --- a/ark/model/model_graph_impl.hpp +++ b/ark/model/model_graph_impl.hpp @@ -34,7 +34,9 @@ class ModelGraphContextStack { void pop(const std::string &key); - std::map current_context() const; + std::string get_context(const std::string &key) const; + + std::map get_context_all() const; }; class ModelGraph::Impl { @@ -80,6 +82,8 @@ class ModelGraph::Impl { bool verify() const; + std::string get_context(const std::string &key) const; + std::string serialize(bool pretty = true) const; std::vector nodes() const; diff --git a/examples/tutorial/context_tutorial.py b/examples/tutorial/context_tutorial.py deleted file mode 100644 index fb01f0a0c..000000000 --- a/examples/tutorial/context_tutorial.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import ark -import time -import torch -import torch.nn.functional as F - - -class VanillaSoftmax(ark.Module): - def __init__(self): - super(Softmax, self).__init__() - - def forward(self, input): - max = ark.reduce_max(input, axis=-1) - output = ark.sub(input, max) - output = ark.exp(output) - sum = ark.reduce_sum(output, axis=-1) - output = ark.div(output, sum) - return output - - -class Softmax(ark.Module): - def __init__(self): - super(Softmax, self).__init__() - - def forward(self, input): - with ark.ContextManager( - processor_range=[0, 304], - warp_range=[0, 8], - sram_range=[0, 0], - task_id=0, - ): - max = ark.reduce_max( - input, - axis=-1, - config={ - "NumWarps": 1, - "ImplType": "WarpWise", - "SramBytes": 0, - "NumTasks": 65536, - }, - ) - output = ark.sub( - input, - max, - config={ - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1, 2048], - "NumTasks": 65536, - }, - ) - output = ark.exp( - output, - config={ - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1, 2048], - "NumTasks": 65536, - }, - ) - sum = ark.reduce_sum( - output, - axis=-1, - config={ - "NumWarps": 1, - "ImplType": "WarpWise", - "SramBytes": 0, - "NumTasks": 65536, - }, - ) - output = ark.div( - output, - sum, - config={ - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1, 2048], - "NumTasks": 65536, - }, - ) - return output - - -def eval(tensor: ark.Tensor): - with ark.Runtime() as rt: - rt.launch() - rt.run() - return tensor.to_torch() - - -def perf(): - with ark.Runtime() as rt: - rt.launch() - - start = time.time() - rt.run(iter=1000) - end = time.time() - return (end - start) / 1000 - - -if __name__ == "__main__": - ark.init() - - shape = (32, 2048, 2048) - - input = torch.randn(*shape).to("cuda:0") - - output = Softmax()(ark.Tensor.from_torch(input)) - - if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5): - print("Correct result") - else: - print("Incorrect result") - - print(f"Performance: {(perf() * 1e3):.3f} ms/iter") diff --git a/examples/tutorial/plan_manager_tutorial.py b/examples/tutorial/plan_manager_tutorial.py new file mode 100644 index 000000000..25aca7af6 --- /dev/null +++ b/examples/tutorial/plan_manager_tutorial.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import ark +import time +import torch +import torch.nn.functional as F + + +class VanillaSoftmax(ark.Module): + def __init__(self): + super(Softmax, self).__init__() + + def forward(self, input): + max = ark.reduce_max(input, axis=-1) + output = ark.sub(input, max) + output = ark.exp(output) + sum = ark.reduce_sum(output, axis=-1) + output = ark.div(output, sum) + return output + + +class Softmax(ark.Module): + def __init__(self): + super(Softmax, self).__init__() + + def forward(self, input): + with ark.PlanManager( + processor_range=[0, 304], + warp_range=[0, 8], + sram_range=[0, 0], + sync=False, + config={ + "NumWarps": 1, + "SramBytes": 0, + "NumTasks": 65536, + } + ): + with ark.PlanManager(config={"ImplType": "WarpWise"}): + max = ark.reduce_max(input, axis=-1) + with ark.PlanManager(config={"Tile": [1, 2048]}): + output = ark.sub(input, max) + output = ark.exp(output) + with ark.PlanManager(config={"ImplType": "WarpWise"}): + sum = ark.reduce_sum(output, axis=-1) + with ark.PlanManager(config={"Tile": [1, 2048]}): + output = ark.div(output, sum) + return output + + +def eval(tensor: ark.Tensor): + with ark.Runtime() as rt: + rt.launch() + rt.run() + return tensor.to_torch() + + +def perf(): + with ark.Runtime() as rt: + rt.launch() + + start = time.time() + rt.run(iter=1000) + end = time.time() + return (end - start) / 1000 + + +if __name__ == "__main__": + ark.init() + + shape = (32, 2048, 2048) + + input = torch.randn(*shape).to("cuda:0") + + output = Softmax()(ark.Tensor.from_torch(input)) + + if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5): + print("Correct result") + else: + print("Incorrect result") + + print(f"Performance: {(perf() * 1e3):.3f} ms/iter") diff --git a/python/ark/__init__.py b/python/ark/__init__.py index 00370e683..db19b59d4 100644 --- a/python/ark/__init__.py +++ b/python/ark/__init__.py @@ -101,4 +101,4 @@ def set_world_size(world_size): ) from .planner import DefaultPlanner, Plan from .profiler import Profiler -from .context_manager import ContextManager +from .plan_manager import PlanManager diff --git a/python/ark/context_manager.py b/python/ark/context_manager.py deleted file mode 100644 index 443e1ca5d..000000000 --- a/python/ark/context_manager.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import json -from .model import Model -from ._ark_core import _ContextManager - - -class ContextManager(_ContextManager): - def __init__(self, **kwargs): - context_map = {key: json.dumps(value) for key, value in kwargs.items()} - super().__init__(Model.get_model(), context_map) - - def __enter__(self) -> "ContextManager": - """ - Enter the context manager. - """ - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - """ - Exit the context manager. - """ - del self diff --git a/python/ark/plan_manager.py b/python/ark/plan_manager.py new file mode 100644 index 000000000..80e615ab8 --- /dev/null +++ b/python/ark/plan_manager.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +from typing import List, Dict, Any +from .model import Model +from ._ark_core import _PlanManager + + +class PlanManager(_PlanManager): + def __init__(self, **kwargs): + """ + Plan manager for specifying the parallelization and tiling configuration of the operators in the context. + + Args: + processor_range (List[int], optional): The range of processors to be used. Defaults to None. + warp_range (List[int], optional): The range of warps to be used. Defaults to None. + sram_range (List[int], optional): The range of SRAMs to be used. Defaults to None. + sync (bool, optional): Whether to synchronize the execution. Defaults to True. + config (Dict[str, Any], optional): The configuration for the operators. Defaults to None. + """ + super().__init__(Model.get_model(), json.dumps(kwargs)) + + def __enter__(self) -> "PlanManager": + """ + Enter the plan manager. + """ + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + """ + Exit the plan manager. + """ + del self diff --git a/python/ark_py.cpp b/python/ark_py.cpp index 7acd4ad1a..75788ba55 100644 --- a/python/ark_py.cpp +++ b/python/ark_py.cpp @@ -7,7 +7,7 @@ namespace py = pybind11; -extern void register_context_manager(py::module &m); +extern void register_plan_manager(py::module &m); extern void register_data_type(py::module &m); extern void register_dims(py::module &m); extern void register_error(py::module &m); @@ -23,7 +23,7 @@ extern void register_version(py::module &m); PYBIND11_MODULE(_ark_core, m) { m.doc() = "Bind ARK C++ APIs to Python"; - register_context_manager(m); + register_plan_manager(m); register_data_type(m); register_dims(m); register_error(m); diff --git a/python/context_manager_py.cpp b/python/context_manager_py.cpp deleted file mode 100644 index 3d703a4bc..000000000 --- a/python/context_manager_py.cpp +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include -#include -#include - -#include - -namespace py = pybind11; - -void register_context_manager(py::module &m) { - py::class_(m, "_ContextManager") - .def(py::init&>()); -} diff --git a/python/plan_manager_py.cpp b/python/plan_manager_py.cpp new file mode 100644 index 000000000..34aa0b77c --- /dev/null +++ b/python/plan_manager_py.cpp @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include +#include + +#include + +namespace py = pybind11; + +void register_plan_manager(py::module &m) { + py::class_(m, "_PlanManager") + .def(py::init()); +} From 7a7f70e43d3e6e327abf5fe835fad1902c803ca0 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 30 Jul 2024 04:45:27 +0000 Subject: [PATCH 42/54] fix --- ark/api/plan_manager.cpp | 8 ++++---- ark/api/planner.cpp | 22 ++++++++-------------- examples/tutorial/plan_manager_tutorial.py | 3 +-- python/ark/tensor.py | 7 +++++-- 4 files changed, 18 insertions(+), 22 deletions(-) diff --git a/ark/api/plan_manager.cpp b/ark/api/plan_manager.cpp index aee8d4f7b..8cb1940b1 100644 --- a/ark/api/plan_manager.cpp +++ b/ark/api/plan_manager.cpp @@ -17,7 +17,9 @@ class PlanManagerState { static std::map gPlanManagerStates; -PlanManager::PlanManager(Model& model, const std::string& plan_context) : model_id_(model.id()), stop_sync_(false) { +PlanManager::PlanManager(Model& model, const std::string& plan_context) + : model_id_(model.id()), stop_sync_(false) { + static int task_group_id = 0; auto ctx = Json::parse(plan_context); if (!ctx.is_object()) { ERR(ModelError, "plan context must be a JSON object"); @@ -36,9 +38,7 @@ PlanManager::PlanManager(Model& model, const std::string& plan_context) : model_ if (state.sync && !value.get()) { stop_sync_ = true; state.sync = false; - context_map["AppendTask"] = "true"; - } else if (!state.sync) { - context_map["AppendTask"] = "true"; + context_map["TaskGroupId"] = std::to_string(task_group_id++); } } else if (key == "processor_range") { if (!value.is_array()) { diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index 1c40e5301..032be0d6f 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -61,7 +61,7 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { size_t max_processor_id = 1; size_t max_warp_id = 1; size_t next_task_id = 0; - bool prev_append_task = false; + int prev_task_group_id = -1; bool first_op = true; auto get_context = [&](const ModelNodeRef &node, @@ -73,13 +73,6 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { }; for (const auto &node : model_.nodes()) { - std::string context = ""; - for (const auto &[key, value] : node->context) { - context += key + "=" + value + ","; - } - context += "prev_append_task=" + std::to_string(prev_append_task); - LOG(INFO, context); - for (const auto &op : node->ops) { if (op->is_virtual()) continue; @@ -106,10 +99,12 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { size_t num_warps = config["NumWarps"]; size_t num_tasks = config["NumTasks"]; size_t sram_bytes = config["SramBytes"]; + size_t granularity = config.value("Granularity", 1); - auto ctx_append_task = get_context(node, "AppendTask"); - if (!ctx_append_task.empty() && ctx_append_task.get() && - prev_append_task) { + auto ctx_task_group_id = get_context(node, "TaskGroupId"); + int task_group_id = + ctx_task_group_id.empty() ? -1 : ctx_task_group_id.get(); + if (task_group_id != -1 && task_group_id == prev_task_group_id) { auto &task_info = task_infos.back(); task_info["NumWarps"] = std::max(task_info["NumWarps"].get(), num_warps); @@ -161,14 +156,13 @@ std::string DefaultPlanner::Impl::plan(bool pretty) const { } resource_group["TaskGroups"] = {{{"TaskId", task_info["Id"]}, {"TaskRange", {0, num_tasks}}, - {"Granularity", 1}}}; + {"Granularity", granularity}}}; processor_group["ResourceGroups"] = Json::array(); processor_group["ResourceGroups"].push_back(resource_group); processor_groups.push_back(processor_group); } - prev_append_task = - !ctx_append_task.empty() && ctx_append_task.get(); + prev_task_group_id = task_group_id; first_op = false; } } diff --git a/examples/tutorial/plan_manager_tutorial.py b/examples/tutorial/plan_manager_tutorial.py index 25aca7af6..c840ce0c0 100644 --- a/examples/tutorial/plan_manager_tutorial.py +++ b/examples/tutorial/plan_manager_tutorial.py @@ -26,7 +26,6 @@ def __init__(self): def forward(self, input): with ark.PlanManager( - processor_range=[0, 304], warp_range=[0, 8], sram_range=[0, 0], sync=False, @@ -34,7 +33,7 @@ def forward(self, input): "NumWarps": 1, "SramBytes": 0, "NumTasks": 65536, - } + }, ): with ark.PlanManager(config={"ImplType": "WarpWise"}): max = ark.reduce_max(input, axis=-1) diff --git a/python/ark/tensor.py b/python/ark/tensor.py index 657da1065..eed7a4259 100644 --- a/python/ark/tensor.py +++ b/python/ark/tensor.py @@ -193,7 +193,9 @@ def from_torch(tensor: torch.Tensor, runtime_id: int = -1) -> "Tensor": ark_tensor = _Tensor(dl_capsule, ark_dtype.ctype()) return Tensor(ark_tensor, runtime_id=runtime_id) - def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": + def copy( + self, data: Union[np.ndarray, torch.Tensor], stream: int = 0 + ) -> "Tensor": """ Copies data into this tensor. The data type may differ, but the size must match. @@ -214,6 +216,7 @@ def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": self._tensor, data.data_ptr(), tensor_bytes, + stream, data.device.type == "cuda", ) elif isinstance(data, np.ndarray): @@ -221,7 +224,7 @@ def copy(self, data: Union[np.ndarray, torch.Tensor]) -> "Tensor": data = np.ascontiguousarray(data) if data.nbytes != tensor_bytes: raise ValueError("data size does not match the tensor") - rt.executor.tensor_write(self._tensor, data) + rt.executor.tensor_write(self._tensor, data, stream) else: raise ValueError("data must be a numpy array or a torch tensor") return self From a77a2ea6b864562f4e916dbaaf30f82e080aad93 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 30 Jul 2024 05:48:00 +0000 Subject: [PATCH 43/54] llama example --- examples/llama/model_7b_b1_s2048.py | 704 ++++++++++++++++++++++++++++ examples/llama/model_test.py | 6 +- 2 files changed, 708 insertions(+), 2 deletions(-) create mode 100644 examples/llama/model_7b_b1_s2048.py diff --git a/examples/llama/model_7b_b1_s2048.py b/examples/llama/model_7b_b1_s2048.py new file mode 100644 index 000000000..f41304e85 --- /dev/null +++ b/examples/llama/model_7b_b1_s2048.py @@ -0,0 +1,704 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""LLaMA 2 Transformer model. + Correspond to https://github.com/facebookresearch/llama/blob/main/llama/model.py +""" + +import ark +import math +from dataclasses import dataclass +from typing import Optional +import os + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = ( + 256 # make SwiGLU hidden layer size multiple of large power of 2 + ) + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +@dataclass +class ModelArgs7B(ModelArgs): + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = ( + 256 # make SwiGLU hidden layer size multiple of large power of 2 + ) + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +@dataclass +class ModelArgs13B(ModelArgs): + dim: int = 5120 + n_layers: int = 40 + n_heads: int = 40 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = ( + 256 # make SwiGLU hidden layer size multiple of large power of 2 + ) + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +@dataclass +class ModelArgs70B(ModelArgs): + dim: int = 8192 + n_layers: int = 80 + n_heads: int = 64 + n_kv_heads: Optional[int] = 8 + vocab_size: int = -1 + multiple_of: int = ( + 4096 # make SwiGLU hidden layer size multiple of large power of 2 + ) + ffn_dim_multiplier: Optional[float] = 1.3 + norm_eps: float = 1e-5 + max_batch_size: int = 32 + max_seq_len: int = 4096 + + +class RMSNorm(ark.Module): + """ + Root mean square layer normalization (RMSNorm). + """ + + def __init__( + self, dim: int, eps: float = 1e-6, dtype: ark.DataType = ark.fp16 + ): + super().__init__() + self.eps = eps + self.dtype = dtype + self.weight = ark.parameter([1, 1, dim], ark.fp32) + + def forward(self, x): + with ark.PlanManager( + warp_range=[0, 8], + sync=False, + config={ + "NumWarps": 1, + "SramBytes": 0, + "NumTasks": 2048, + "Granularity": 7, + }, + ): + with ark.PlanManager(config={"Tile": [1, 4096]}): + x = ark.cast(x, ark.fp32) + x2 = ark.mul(x, x) + with ark.PlanManager(config={"ImplType": "WarpWise"}): + mean = ark.reduce_mean(x2, axis=-1) + with ark.PlanManager( + config={ + "NumWarps": 1, + "SramBytes": 0, + "Tile": [64, 1], + "NumTasks": 32, + } + ): + rrms = ark.rsqrt(mean) + with ark.PlanManager( + warp_range=[0, 8], + sync=False, + config={ + "NumWarps": 1, + "SramBytes": 0, + "NumTasks": 2048, + "Tile": [1, 4096], + "Granularity": 7, + }, + ): + x = ark.mul(x, rrms) + x = ark.mul(x, self.weight, x) + return ark.cast(x, self.dtype) + + +class ColumnParallelLinear(ark.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + Here the weight = A^T, so we need to partition the weight matrix along + its first dimension. + + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dtype: ark.DataType = ark.fp16, + gather_output: bool = True, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.dtype = dtype + self.local_rank = local_rank + self.world_size = world_size + self.gather_output = gather_output + + self.weight = ark.parameter([out_dim // world_size, in_dim], dtype) + + def forward(self, x): + if self.world_size == 1 or self.gather_output == False: + return ark.matmul(x, self.weight, transpose_other=True) + # We need to concat the output_tensor_shards along the last dimension + output_tensor = ark.tensor( + [x.shape()[0], x.shape()[1], self.out_dim], self.dtype + ) + output_tensor_shards = ark.sharding( + output_tensor, + axis=2, + dim_per_shard=self.out_dim // self.world_size, + ) + local_result = ark.identity( + output_tensor_shards[self.local_rank], deps=output_tensor_shards + ) + # (batch_size, seq_len, out_dim // world_size) + local_result = ark.matmul( + x, self.weight, local_result, transpose_other=True + ) + gather_input = ark.identity(output_tensor, deps=[local_result]) + # return gather_input + gather_reshape = ark.reshape( + gather_input, [x.shape()[0] * x.shape()[1], self.out_dim] + ) + gather_out = ark.local_all_gather( + gather_reshape, self.local_rank, self.world_size, 1 + ) + return ark.reshape( + gather_out, [x.shape()[0], x.shape()[1], self.out_dim] + ) + + +class RowParallelLinear(ark.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + + Here the weight = A^T, so we need to partition the weight matrix along + its second dimension. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dtype: ark.DataType = ark.fp16, + input_is_parallel: bool = False, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.dtype = dtype + self.local_rank = local_rank + self.world_size = world_size + self.input_is_parallel = input_is_parallel + + self.weight = ark.parameter([out_dim, in_dim // world_size], dtype) + + def forward(self, x): + if self.world_size == 1: + return ark.matmul(x, self.weight, transpose_other=True) + x_ndims = len(x.shape()) + if self.input_is_parallel: + input_parallel = x + else: + x_shards = ark.sharding( + x, x_ndims - 1, self.in_dim // self.world_size + ) + input_parallel = x_shards[self.local_rank] + local_result = ark.matmul( + input_parallel, self.weight, transpose_other=True + ) + reduced_result = ark.local_all_reduce( + local_result, self.local_rank, self.world_size + ) + return reduced_result + + +class ParallelEmbedding(ark.Module): + """Embedding layer.""" + + def __init__( + self, + vocab_size: int, + dim: int, + dtype: ark.DataType, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.vocab_size = vocab_size + self.dim = dim + self.weight = ark.parameter([vocab_size, dim // world_size], dtype) + self.out_dim = dim + self.dtype = dtype + self.world_size = world_size + self.local_rank = local_rank + + def forward(self, x): + if self.world_size == 1: + return ark.embedding(x, self.weight) + + output_tensor = ark.tensor( + [x.shape()[0], x.shape()[1], self.out_dim], self.dtype + ) + output_tensor_shards = ark.sharding( + output_tensor, axis=2, dim_per_shard=self.out_dim // self.world_size + ) + local_result = ark.identity( + output_tensor_shards[self.local_rank], deps=output_tensor_shards + ) + local_result = ark.embedding(x, self.weight, local_result) + gather_input = ark.identity(output_tensor, deps=[local_result]) + gather_reshape = ark.reshape( + gather_input, [x.shape()[0] * x.shape()[1], self.out_dim] + ) + gather_out = ark.local_all_gather( + gather_reshape, self.local_rank, self.world_size, 1 + ) + return ark.reshape( + gather_out, [x.shape()[0], x.shape()[1], self.out_dim] + ) + + +class Linear(ark.Module): + """ + Linear layer module with weights and no bias. + """ + + def __init__( + self, in_dim: int, out_dim: int, dtype: ark.DataType = ark.fp16 + ): + super().__init__() + self.dtype = dtype + self.weight = ark.parameter([out_dim, in_dim], dtype) + + def forward(self, x): + return ark.matmul(x, self.weight, transpose_other=True) + + +class Silu(ark.Module): + """ + Silu activation function, silu(x) = x * sigmoid(x) + """ + + def __init__(self): + super().__init__() + + def forward(self, x: ark.Tensor): + # We need to specify output tensor so that the sigmoid op will not be an in-place operator + output = ark.tensor(x.shape(), x.dtype()) + x1 = ark.sigmoid(x, output) + return ark.mul(x, x1) + + +class FeedForward(ark.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ( + (hidden_dim + multiple_of - 1) // multiple_of + ) + + self.w1 = ColumnParallelLinear( + dim, hidden_dim, dtype, False, local_rank, world_size + ) + self.w2 = RowParallelLinear( + hidden_dim, dim, dtype, True, local_rank, world_size + ) + self.w3 = ColumnParallelLinear( + dim, hidden_dim, dtype, False, local_rank, world_size + ) + + def forward(self, x): + # self.w2(F.silu(self.w1(x)) * self.w3(x)) + with ark.PlanManager( + warp_range=[0, 8], + sram_range=[0, 49344], + sync=False, + config={ + "NumWarps": 4, + "NumTasks": 688, + }, + ): + with ark.PlanManager( + config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} + ): + x1 = self.w1(x) + with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + x1 = Silu()(x1) + with ark.PlanManager( + warp_range=[0, 8], + sram_range=[0, 49344], + sync=False, + config={ + "NumWarps": 4, + "NumTasks": 688, + }, + ): + with ark.PlanManager( + config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} + ): + x2 = self.w3(x) + with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + x3 = ark.mul(x1, x2) + x4 = self.w2(x3) + return x4 + + +def apply_rotary_emb(xq, xk, freqs_cis): + """ + Apply rotary embeddings to xq and xk. + """ + xq_out = ark.rope(xq, freqs_cis) + xk_out = ark.rope(xk, freqs_cis) + return xq_out, xk_out + + +class Softmax(ark.Module): + def __init__(self): + super(Softmax, self).__init__() + + def forward(self, input): + with ark.PlanManager( + warp_range=[0, 8], + sram_range=[0, 0], + sync=False, + config={ + "NumWarps": 1, + "SramBytes": 0, + "NumTasks": 65536, + }, + ): + with ark.PlanManager(config={"ImplType": "WarpWise"}): + max = ark.reduce_max(input, axis=-1) + with ark.PlanManager(config={"Tile": [1, 2048]}): + output = ark.sub(input, max) + output = ark.exp(output) + with ark.PlanManager(config={"ImplType": "WarpWise"}): + sum = ark.reduce_sum(output, axis=-1) + with ark.PlanManager(config={"Tile": [1, 2048]}): + output = ark.div(output, sum) + return output + + +class Attention(ark.Module): + def __init__( + self, + args: ModelArgs, + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.n_kv_heads = ( + args.n_heads if args.n_kv_heads is None else args.n_kv_heads + ) + model_parallel_size = world_size + self.dtype = dtype + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + self.wq = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + dtype, + False, + local_rank, + world_size, + ) + self.wk = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + dtype, + False, + local_rank, + world_size, + ) + self.wv = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + dtype, + False, + local_rank, + world_size, + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + dtype, + True, + local_rank, + world_size, + ) + + def forward( + self, + x: ark.Tensor, + start_pos: int, + freqs_cis: ark.Tensor, + mask: Optional[ark.Tensor], + ): + bsz, seqlen, _ = x.shape() + + with ark.PlanManager( + warp_range=[0, 4], + sram_range=[0, 24672], + sync=False, + config={"NumWarps": 4, "NumTasks": 256}, + ): + with ark.PlanManager( + config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} + ): + xq = self.wq(x) + xq = ark.reshape( + xq, [bsz, seqlen, self.n_local_heads, self.head_dim] + ) + with ark.PlanManager( + config={"SramBytes": 0, "Tile": [256, 1, 128]} + ): + if freqs_cis is not None: + xq = ark.rope(xq, freqs_cis) + with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + xq = ark.transpose(xq, [0, 2, 1, 3]) + + with ark.PlanManager( + warp_range=[0, 4], + sram_range=[0, 24672], + sync=False, + config={"NumWarps": 4, "NumTasks": 256}, + ): + with ark.PlanManager( + config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} + ): + xk = self.wk(x) + xk = ark.reshape( + xk, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] + ) + with ark.PlanManager( + config={"SramBytes": 0, "Tile": [256, 1, 128]} + ): + if freqs_cis is not None: + xk = ark.rope(xk, freqs_cis) + keys = xk + with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + keys = ark.transpose(keys, [0, 2, 1, 3]) + + with ark.PlanManager( + warp_range=[0, 4], + sram_range=[0, 24672], + sync=False, + config={ + "NumWarps": 4, + "NumTasks": 256, + "SramBytes": 24672, + "TileShapeMNK": [256, 128, 32], + }, + ): + with ark.PlanManager( + config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} + ): + xv = self.wv(x) + xv = ark.reshape( + xv, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] + ) + values = xv + with ark.PlanManager( + config={"SramBytes": 0, "Tile": [256, 1, 128]} + ): + values = ark.transpose(values, [0, 2, 1, 3]) + + with ark.PlanManager( + warp_range=[0, 8], + sram_range=[0, 49344], + sync=False, + config={ + "NumWarps": 4, + "NumTasks": 4096, + "Granularity": 2, + }, + ): + with ark.PlanManager( + config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} + ): + scores = ark.matmul(xq, keys, transpose_other=True) + with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + scores = ark.mul(scores, 1.0 / math.sqrt(self.head_dim)) + + if mask is not None: + scores = ark.add(scores, mask) + + scores = Softmax()(scores) + + with ark.PlanManager( + warp_range=[0, 4], + sram_range=[0, 24672], + sync=False, + config={ + "NumWarps": 4, + "NumTasks": 256, + }, + ): + with ark.PlanManager( + config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} + ): + output = ark.matmul(scores, values) + with ark.PlanManager( + config={"SramBytes": 0, "Tile": [256, 1, 128]} + ): + output = ark.transpose(output, [0, 2, 1, 3]) + output = ark.reshape( + output, [bsz, seqlen, self.head_dim * self.n_local_heads] + ) + return self.wo(output) + + +class TransformerBlock(ark.Module): + def __init__( + self, + layer_id: int, + args: ModelArgs, + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args, dtype, local_rank, world_size) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + dtype=dtype, + local_rank=local_rank, + world_size=world_size, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps, dtype=dtype) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps, dtype=dtype) + + def forward( + self, + x: ark.Tensor, + start_pos: int, + freqs_cis: ark.Tensor, + mask: Optional[ark.Tensor], + ): + attention_norm_x = self.attention_norm(x) + h = self.attention.forward(attention_norm_x, start_pos, freqs_cis, mask) + with ark.PlanManager( + warp_range=[0, 4], + config={ + "NumWarps": 4, + "Tile": [256, 128], + "NumTasks": 256, + "SramBytes": 0, + }, + ): + h = ark.add(x, h) + ff = self.feed_forward(self.ffn_norm(h)) + with ark.PlanManager( + warp_range=[0, 4], + config={ + "NumWarps": 4, + "Tile": [256, 128], + "NumTasks": 256, + "SramBytes": 0, + }, + ): + out = ark.add(h, ff) + return out + + +class Transformer(ark.Module): + def __init__( + self, + params: ModelArgs, + dtype: ark.DataType = ark.fp16, + local_rank: int = 0, + world_size: int = 1, + ): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = ParallelEmbedding( + params.vocab_size, params.dim, dtype, local_rank, world_size + ) + + self.layers = [] + for layer_id in range(self.n_layers): + self.layers.append( + TransformerBlock( + layer_id, params, dtype, local_rank, world_size + ) + ) + self.register_module(f"layers.{layer_id}", self.layers[layer_id]) + self.norm = RMSNorm(params.dim, eps=params.norm_eps, dtype=dtype) + self.output = ColumnParallelLinear( + params.dim, params.vocab_size, dtype, True, local_rank, world_size + ) + + def forward( + self, + tokens: ark.Tensor, + start_pos: int, + freqs_cis: ark.Tensor, + mask: Optional[ark.Tensor], + ): + h = self.tok_embeddings(tokens) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h) + return output diff --git a/examples/llama/model_test.py b/examples/llama/model_test.py index 19c680854..f559a826b 100644 --- a/examples/llama/model_test.py +++ b/examples/llama/model_test.py @@ -59,8 +59,10 @@ def run_ark( output = module(*module_inputs) with ark.Runtime() as rt: - plan = ark.Plan.from_file("plan_llama2_7b_b1_s2048.json") - rt.launch(plan) + plan = ark.DefaultPlanner().plan() + with open("plan.json", "w") as f: + f.write(str(plan)) + rt.launch(plan=plan) # Load model parameters if state_dict: From 78ac0dacb70e26ef5dc8704c0bb69c7c47240cbd Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 08:06:32 +0000 Subject: [PATCH 44/54] fix merge --- ark/include/ark/executor.hpp | 2 +- ark/ops/ops_test_common.cpp | 2 +- ark/ops/ops_test_common.hpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index f0a108a1f..3744c33db 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -95,7 +95,7 @@ class DefaultExecutor : public Executor { public: DefaultExecutor( const Model &model, int device_id = -1, Stream stream = nullptr, - const std::vector &config_rules = {}, + const std::vector &config_rules = {}, const std::string &name = "DefaultExecutor", bool loop_mode = true); }; diff --git a/ark/ops/ops_test_common.cpp b/ark/ops/ops_test_common.cpp index 2bd9ce2e7..4e94d06a7 100644 --- a/ark/ops/ops_test_common.cpp +++ b/ark/ops/ops_test_common.cpp @@ -35,7 +35,7 @@ OpsTestResult op_test( const std::string &test_name_prefix, const Model &model, const std::vector &inputs, const std::vector &outputs, OpsTestBaseline baseline, const std::vector &inputs_data, - const std::vector &config_rules, + const std::vector &config_rules, bool print_on_error) { DefaultExecutor exe(model, -1, nullptr, config_rules); exe.compile(); diff --git a/ark/ops/ops_test_common.hpp b/ark/ops/ops_test_common.hpp index c5d640f3b..3848773e6 100644 --- a/ark/ops/ops_test_common.hpp +++ b/ark/ops/ops_test_common.hpp @@ -171,7 +171,7 @@ OpsTestResult op_test( const std::string &test_name_prefix, const Model &model, const std::vector &inputs, const std::vector &outputs, OpsTestBaseline baseline, const std::vector &inputs_data = {}, - const std::vector &config_rules = {}, + const std::vector &config_rules = {}, bool print_on_error = false); OpsTestGpuMem to_gpu(void *host_ptr, size_t size); From afb518a7622363b000e9fc1d21c4cf8178c3461d Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 08:09:48 +0000 Subject: [PATCH 45/54] fix merge --- ark/api/executor.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 58d058d25..42ed45128 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -233,8 +233,6 @@ void Executor::Impl::init(const PlanJson &plan_json) { } auto gpu_manager = GpuManager::get_instance(device_id_); - - auto gpu_manager = GpuManager::get_instance(gpu_id_); if (!gpu_manager->info().arch->belongs_to( Arch::from_name(plan_json.at("Architecture")))) { LOG(WARN, "Architecture name of the plan `", @@ -779,7 +777,7 @@ void Executor::Impl::barrier() { uintptr_t Executor::Impl::tensor_address(const Tensor tensor) const { size_t buffer_id = tensor.ref()->buffer()->id(); if (buffer_id_to_offset_.find(buffer_id) == buffer_id_to_offset_.end()) { - ERR(NotFoundError, "Invalid buffer ID: ", buffer_id); + ERR(InternalError, "Invalid buffer ID: ", buffer_id); } size_t offset = buffer_id_to_offset_.at(buffer_id); return reinterpret_cast(buffer_->ref(offset)); From 762bf4aa439510dbc04e4f9ee83da84c7a32a03a Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 16:30:57 +0000 Subject: [PATCH 46/54] fix merge --- ark/ops/ops_all_reduce_test.cpp | 15 +++++++-------- ark/ops/ops_communication_test.cpp | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/ark/ops/ops_all_reduce_test.cpp b/ark/ops/ops_all_reduce_test.cpp index 90814d036..8cf68b085 100644 --- a/ark/ops/ops_all_reduce_test.cpp +++ b/ark/ops/ops_all_reduce_test.cpp @@ -125,10 +125,9 @@ void test_all_reduce_packet_internal(ark::DimType nelem) { std::vector ones_vec(ones.shape().nelems(), ark::half_t(1.0f)); - auto result = - ark::op_test("all_reduce_packet", m, {ones}, {output}, - baseline_all_reduce, - {ones_vec.data()}, false, gpu_id, NumGpus); + auto result = ark::op_test( + "all_reduce_packet", m, {ones}, {output}, + baseline_all_reduce, {ones_vec.data()}); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; @@ -232,10 +231,10 @@ void test_all_reduce_sm_internal(ark::DimType nelem) { std::vector ones_vec(ones.shape().nelems(), ark::half_t(1.0f)); - auto result = ark::op_test( - "all_reduce_sm", m, {ones}, {output}, - baseline_all_reduce, {ones_vec.data()}, - false, gpu_id, NumGpus, config_rule); + auto result = + ark::op_test("all_reduce_sm", m, {ones}, {output}, + baseline_all_reduce, + {ones_vec.data()}, {config_rule}); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; diff --git a/ark/ops/ops_communication_test.cpp b/ark/ops/ops_communication_test.cpp index db384c1f4..8cdad41b2 100644 --- a/ark/ops/ops_communication_test.cpp +++ b/ark/ops/ops_communication_test.cpp @@ -433,7 +433,7 @@ ark::unittest::State test_communication_send_recv_reduce() { ark::Planner planner(model, gpu_id); planner.install_config_rule(config_rule); - ark::Executor exe(gpu_id, 2, gpu_id, "Executor", planner.plan()); + ark::Executor exe(gpu_id, nullptr, "Executor", planner.plan()); exe.compile(); std::vector data(1024); From f654f0b08d48931acd5645c16300c1a6f3ebe88e Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 16:34:21 +0000 Subject: [PATCH 47/54] add a python method --- python/executor_py.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/python/executor_py.cpp b/python/executor_py.cpp index e782a99fe..a3f2a078b 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -63,6 +63,7 @@ void register_executor(py::module &m) { .def("barrier", &ark::Executor::barrier) .def("destroy", &ark::Executor::destroy) .def("destroyed", &ark::Executor::destroyed) + .def("tensor_address", &ark::Executor::tensor_address) .def("tensor_read", py::overload_cast(&tensor_read), From 498926c6242a35a38ffd6a8c406b4f3cf1ff84c6 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 16:35:28 +0000 Subject: [PATCH 48/54] submodule update --- third_party/mscclpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/mscclpp b/third_party/mscclpp index cddffbc8b..40cb19655 160000 --- a/third_party/mscclpp +++ b/third_party/mscclpp @@ -1 +1 @@ -Subproject commit cddffbc8b6dfa6facf7c64c1b7d73acf30e600b3 +Subproject commit 40cb1965538ab98fea3cc9fe004f730e23e84829 From 3e331a2e2f5487502daccc32890ef49c5d86eb12 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 17:12:15 +0000 Subject: [PATCH 49/54] fix --- ark/model/model_json.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ark/model/model_json.cpp b/ark/model/model_json.cpp index b82f9e484..c2099e2c9 100644 --- a/ark/model/model_json.cpp +++ b/ark/model/model_json.cpp @@ -287,6 +287,7 @@ PlanJson::PlanJson(const Json &json) : Json((json != nullptr) ? json : Json{{"Rank", 0}, {"WorldSize", 1}, + {"Architecture", "ANY"}, {"NumProcessors", 1}, {"NumWarpsPerProcessor", 1}, {"TaskInfos", Json::array()}, From 10bfa75dbd40a96ffca69fb22e89127e1839b940 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 17:14:47 +0000 Subject: [PATCH 50/54] Rename CMake environments --- .github/workflows/codeql.yml | 4 ++-- .github/workflows/ut-cuda.yml | 2 +- CMakeLists.txt | 32 ++++++++++++++++---------------- ark/CMakeLists.txt | 10 +++++----- pyproject.toml | 2 +- third_party/CMakeLists.txt | 9 +++++++-- third_party/mscclpp | 2 +- 7 files changed, 33 insertions(+), 28 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 0d7094c36..272cb8ebe 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -49,7 +49,7 @@ jobs: - name: Build run: | mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=Debug -DBUILD_PYTHON=ON -DBYPASS_GPU_CHECK=ON -DUSE_CUDA=ON -DBUILD_TESTS=OFF .. + cmake -DCMAKE_BUILD_TYPE=Debug -DARK_BUILD_PYTHON=ON -DARK_BYPASS_GPU_CHECK=ON -DARK_USE_CUDA=ON -DARK_BUILD_TESTS=OFF .. make -j build ark_py - name: Perform CodeQL Analysis @@ -95,7 +95,7 @@ jobs: - name: Build run: | mkdir build && cd build - CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_BUILD_TYPE=Debug -DBUILD_PYTHON=ON -DBYPASS_GPU_CHECK=ON -DUSE_ROCM=ON -DBUILD_TESTS=OFF .. + CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_BUILD_TYPE=Debug -DARK_BUILD_PYTHON=ON -DARK_BYPASS_GPU_CHECK=ON -DARK_USE_ROCM=ON -DARK_BUILD_TESTS=OFF .. make -j build ark_py - name: Perform CodeQL Analysis diff --git a/.github/workflows/ut-cuda.yml b/.github/workflows/ut-cuda.yml index 4e573adfb..c2e8e7c50 100644 --- a/.github/workflows/ut-cuda.yml +++ b/.github/workflows/ut-cuda.yml @@ -44,7 +44,7 @@ jobs: - name: Build run: | mkdir build && cd build - cmake -DCMAKE_BUILD_TYPE=Debug -DBUILD_PYTHON=ON .. + cmake -DCMAKE_BUILD_TYPE=Debug -DARK_BUILD_PYTHON=ON .. make -j ut ark_py - name: Run C++ UT diff --git a/CMakeLists.txt b/CMakeLists.txt index ee1e3566e..2e80ea1e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,16 +13,16 @@ enable_language(CXX) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") -option(USE_CUDA "Use NVIDIA/CUDA." OFF) -option(USE_ROCM "Use AMD/ROCm." OFF) -option(BYPASS_GPU_CHECK "Bypass GPU check." OFF) -option(BUILD_TESTS "Build unit tests." ON) +option(ARK_USE_CUDA "Use NVIDIA/CUDA." OFF) +option(ARK_USE_ROCM "Use AMD/ROCm." OFF) +option(ARK_BYPASS_GPU_CHECK "Bypass GPU check." OFF) +option(ARK_BUILD_TESTS "Build unit tests." ON) -if(BYPASS_GPU_CHECK) - if(USE_CUDA) +if(ARK_BYPASS_GPU_CHECK) + if(ARK_USE_CUDA) message("Bypassing GPU check: using NVIDIA/CUDA.") find_package(CUDAToolkit REQUIRED) - elseif(USE_ROCM) + elseif(ARK_USE_ROCM) message("Bypassing GPU check: using AMD/ROCm.") set(CMAKE_PREFIX_PATH "/opt/rocm;${CMAKE_PREFIX_PATH}") find_package(hip REQUIRED) @@ -35,16 +35,16 @@ else() include(CheckAmdGpu) if(NVIDIA_FOUND AND AMD_FOUND) message("Detected NVIDIA/CUDA and AMD/ROCm: prioritizing NVIDIA/CUDA.") - set(USE_CUDA ON) - set(USE_ROCM OFF) + set(ARK_USE_CUDA ON) + set(ARK_USE_ROCM OFF) elseif(NVIDIA_FOUND) message("Detected NVIDIA/CUDA.") - set(USE_CUDA ON) - set(USE_ROCM OFF) + set(ARK_USE_CUDA ON) + set(ARK_USE_ROCM OFF) elseif(AMD_FOUND) message("Detected AMD/ROCm.") - set(USE_CUDA OFF) - set(USE_ROCM ON) + set(ARK_USE_CUDA OFF) + set(ARK_USE_ROCM ON) else() message(FATAL_ERROR "Neither NVIDIA/CUDA nor AMD/ROCm is found.") endif() @@ -53,7 +53,7 @@ endif() # Declare project set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-deprecated-declarations") -if(USE_CUDA) +if(ARK_USE_CUDA) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall,-Wextra") project(ark LANGUAGES CXX CUDA) @@ -72,7 +72,7 @@ if(USE_CUDA) if(CUDAToolkit_VERSION_MAJOR GREATER_EQUAL 12) set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES} 90) endif() -else() # USE_ROCM +else() # ARK_USE_ROCM set(CMAKE_HIP_STANDARD 17) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Wall -Wextra") project(ark LANGUAGES CXX HIP) @@ -145,7 +145,7 @@ add_custom_target(ut) # Details add_subdirectory(ark) -if(BUILD_PYTHON) +if(ARK_BUILD_PYTHON) # Install Python module add_subdirectory(python) add_dependencies(ark_py build) diff --git a/ark/CMakeLists.txt b/ark/CMakeLists.txt index 4457d3c0b..208d9f9cb 100644 --- a/ark/CMakeLists.txt +++ b/ark/CMakeLists.txt @@ -6,7 +6,7 @@ file(GLOB_RECURSE UT_SOURCES CONFIGURE_DEPENDS *_test.cpp) file(GLOB_RECURSE UT_COMMON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/unittest/*.cpp) list(REMOVE_ITEM SOURCES ${UT_SOURCES} ${UT_COMMON_SOURCES}) -if(USE_ROCM) +if(ARK_USE_ROCM) file(GLOB_RECURSE CU_SOURCES CONFIGURE_DEPENDS *.cu) set_source_files_properties(${CU_SOURCES} PROPERTIES LANGUAGE CXX) endif() @@ -23,7 +23,7 @@ target_include_directories(ark_obj SYSTEM PRIVATE ${NUMA_INCLUDE_DIRS} ) -if(USE_CUDA) +if(ARK_USE_CUDA) list(APPEND COMMON_LIBS CUDA::cuda_driver) target_include_directories(ark_obj SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/cutlass/include @@ -32,7 +32,7 @@ if(USE_CUDA) target_compile_definitions(ark_obj PUBLIC ARK_CUDA) endif() -if(USE_ROCM) +if(ARK_USE_ROCM) list(APPEND COMMON_LIBS hip::host) target_include_directories(ark_obj SYSTEM PRIVATE ${PROJECT_SOURCE_DIR}/third_party/cutlass/include @@ -45,7 +45,7 @@ target_sources(ark_obj PRIVATE ${SOURCES}) target_link_libraries(ark_obj PUBLIC mscclpp_static PRIVATE ${COMMON_LIBS}) # ARK unit tests -if(BUILD_TESTS) +if(ARK_BUILD_TESTS) foreach(ut_source IN ITEMS ${UT_SOURCES}) get_filename_component(exe_name ${ut_source} NAME_WE) add_executable(${exe_name} ${ut_source} ${UT_COMMON_SOURCES}) @@ -58,7 +58,7 @@ if(BUILD_TESTS) ${NUMA_INCLUDE_DIRS} ) - if(USE_CUDA) + if(ARK_USE_CUDA) target_link_libraries(${exe_name} PRIVATE ark_obj ${COMMON_LIBS} CUDA::cudart CUDA::cublas) target_include_directories(${exe_name} SYSTEM PRIVATE ${CUDAToolkit_INCLUDE_DIRS} diff --git a/pyproject.toml b/pyproject.toml index 1f9386c73..d9fb4502e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ install.strip = true build-dir = "build/{wheel_tag}" [tool.scikit-build.cmake.define] -BUILD_PYTHON = "ON" +ARK_BUILD_PYTHON = "ON" [tool.black] line-length = 80 diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 12ae74298..96e442289 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -14,7 +14,12 @@ FetchContent_Declare( GIT_TAG v0.5.2 SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/mscclpp ) +set(BUILD_TESTS OFF CACHE BOOL "" FORCE) set(BUILD_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE) +set(BUILD_APPS_NCCL OFF CACHE BOOL "" FORCE) +set(USE_CUDA ${ARK_USE_CUDA} CACHE BOOL "" FORCE) +set(USE_ROCM ${ARK_USE_ROCM} CACHE BOOL "" FORCE) +set(BYPASS_GPU_CHECK ON CACHE BOOL "" FORCE) set(INSTALL_PREFIX "ark") FetchContent_GetProperties(mscclpp) if (NOT mscclpp_POPULATED) @@ -35,7 +40,7 @@ if (NOT json_POPULATED) endif() set(JSON_INCLUDE_DIRS ${json_SOURCE_DIR}/include PARENT_SCOPE) -if(USE_CUDA) +if(ARK_USE_CUDA) # Configure CUTLASS FetchContent_Declare( cutlass @@ -58,7 +63,7 @@ if(USE_CUDA) endif() -if(USE_ROCM) +if(ARK_USE_ROCM) # Configure CK FetchContent_Declare( ck diff --git a/third_party/mscclpp b/third_party/mscclpp index cddffbc8b..40cb19655 160000 --- a/third_party/mscclpp +++ b/third_party/mscclpp @@ -1 +1 @@ -Subproject commit cddffbc8b6dfa6facf7c64c1b7d73acf30e600b3 +Subproject commit 40cb1965538ab98fea3cc9fe004f730e23e84829 From 3dda44a8dc310560333de0cf9090d7da0013e21f Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 18:15:09 +0000 Subject: [PATCH 51/54] A few fixes & improved coverage --- ark/api/executor.cpp | 21 +++-- ark/api/executor_test.cpp | 150 +++++++++++++++++++++++++++++++++++ ark/include/ark/executor.hpp | 2 +- python/executor_py.cpp | 2 +- 4 files changed, 161 insertions(+), 14 deletions(-) create mode 100644 ark/api/executor_test.cpp diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 42ed45128..16d369bc8 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -141,7 +141,7 @@ static size_t tensor_stride_bytes(const Json &tensor) { class Executor::Impl { public: Impl(int device_id, Stream stream, const std::string &name, bool loop_mode); - ~Impl() = default; + ~Impl(); void init(const PlanJson& plan); @@ -152,7 +152,7 @@ class Executor::Impl { std::string plan() const { return plan_json_.dump_pretty(); } void compile(); - void launch(int64_t max_spin_count); + void launch(); void run(int iter); void wait(int64_t max_spin_count); float stop(int64_t max_spin_count); @@ -219,6 +219,10 @@ Executor::Impl::Impl(int device_id, Stream stream, const std::string &name, } } +Executor::Impl::~Impl() { + if (is_launched_) stop(-1); +} + void Executor::Impl::init(const PlanJson &plan_json) { plan_json_ = plan_json; rank_ = plan_json_["Rank"].get(); @@ -620,13 +624,12 @@ void Executor::Impl::init_channels(const std::set &remote_ranks) { void Executor::Impl::compile() { kernel_->compile(); } -void Executor::Impl::launch(int64_t max_spin_count) { +void Executor::Impl::launch() { if (!kernel_->is_compiled()) { ERR(InvalidUsageError, "Need to compile first before initialization."); } if (is_launched_) { - // Wait until previous works finish. - this->wait(max_spin_count); + LOG(WARN, "Ignore launching twice."); return; } auto get_global_rt = [&](const std::string &symbol) { @@ -674,12 +677,6 @@ void Executor::Impl::launch(int64_t max_spin_count) { } elapsed_msec_ = -1; - if (!kernel_->is_compiled()) { - ERR(InvalidUsageError, "Need to compile first before initialization."); - } else if (is_launched_) { - LOG(WARN, "Ignore launching twice."); - return; - } timer_begin_->record(stream_raw_); if (world_size_ > 1) { @@ -911,7 +908,7 @@ std::string Executor::plan() const { return impl_->plan(); } void Executor::compile() { impl_->compile(); } -void Executor::launch(int64_t max_spin_count) { impl_->launch(max_spin_count); } +void Executor::launch() { impl_->launch(); } void Executor::run(int iter) { impl_->run(iter); } diff --git a/ark/api/executor_test.cpp b/ark/api/executor_test.cpp new file mode 100644 index 000000000..b0b398ac9 --- /dev/null +++ b/ark/api/executor_test.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ark/executor.hpp" + +#include "gpu/gpu.hpp" +#include "model/model_json.hpp" +#include "unittest/unittest_utils.h" + +template +ark::unittest::State test_executor() { + ark::gpuStream stream; + UNITTEST_EQ( + ark::gpuStreamCreateWithFlags(&stream, ark::gpuStreamNonBlocking), + ark::gpuSuccess); + + ark::Model empty; + { + ark::DefaultExecutor executor(empty, 0, stream, {}, "test", LoopMode); + UNITTEST_EQ(executor.device_id(), 0); + UNITTEST_EQ(executor.stream(), stream); + + executor.compile(); + executor.launch(); + executor.run(1); + executor.wait(); + executor.stop(); + executor.destroy(); + } + { + ark::DefaultExecutor executor(empty, 0, stream, {}, "test", LoopMode); + executor.compile(); + executor.launch(); + executor.run(1); + executor.wait(); + executor.stop(); + + executor.launch(); + executor.run(1); + executor.wait(); + executor.stop(); + + executor.destroy(); + } + { + ark::DefaultExecutor executor(empty, 0, stream, {}, "test", LoopMode); + UNITTEST_THROW(executor.launch(), ark::InvalidUsageError); + + executor.compile(); + executor.launch(); + executor.launch(); // Will be ignored with a warning. + executor.run(1); + executor.wait(); + executor.wait(); // nothing to do + + // Stop & destroy automatically. + } + + UNITTEST_EQ(ark::gpuStreamDestroy(stream), ark::gpuSuccess); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_executor_loop() { return test_executor(); } + +ark::unittest::State test_executor_no_loop() { return test_executor(); } + +ark::unittest::State test_executor_tensor_read_write() { + // Alloc CPU array + std::vector host_data(1024); + void *host_ptr = host_data.data(); + for (size_t i = 0; i < host_data.size(); ++i) { + host_data[i] = static_cast(i); + } + + // Alloc GPU array + void *dev_ptr; + UNITTEST_EQ(ark::gpuMalloc(&dev_ptr, 1024 * sizeof(float)), + ark::gpuSuccess); + + // Create an ARK tensor + ark::Model m; + auto tensor = m.tensor({1024}, ark::FP32); + m.noop(tensor); + + ark::DefaultExecutor executor(m, 0); + executor.compile(); + executor.launch(); + + // Copy data from CPU array to ARK tensor + executor.tensor_write(tensor, host_ptr, 1024 * sizeof(float)); + + // Copy data from ARK tensor to GPU array + executor.tensor_read(tensor, dev_ptr, 1024 * sizeof(float), nullptr, true); + + // Check the data + std::vector dev_data(1024); + executor.tensor_read(tensor, dev_data.data(), 1024 * sizeof(float)); + for (size_t i = 0; i < dev_data.size(); ++i) { + UNITTEST_EQ(dev_data[i], static_cast(i)); + dev_data[i] = -1; + } + + UNITTEST_EQ(ark::gpuMemcpy(dev_data.data(), dev_ptr, 1024 * sizeof(float), + ark::gpuMemcpyDeviceToHost), + ark::gpuSuccess); + for (size_t i = 0; i < dev_data.size(); ++i) { + UNITTEST_EQ(dev_data[i], static_cast(i)); + dev_data[i] = -1; + } + + // Copy -1s back to GPU array + UNITTEST_EQ(ark::gpuMemcpy(dev_ptr, dev_data.data(), 1024 * sizeof(float), + ark::gpuMemcpyHostToDevice), + ark::gpuSuccess); + + // Copy data from GPU array to ARK tensor + executor.tensor_write(tensor, dev_ptr, 1024 * sizeof(float), nullptr, true); + + // Copy data from ARK tensor to CPU array + executor.tensor_read(tensor, host_ptr, 1024 * sizeof(float)); + + // Check the data + for (size_t i = 0; i < host_data.size(); ++i) { + UNITTEST_EQ(host_data[i], -1); + } + + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_executor_invalid() { + // Invalid device ID. + UNITTEST_THROW(ark::Executor(-1, nullptr, "test", ""), + ark::InvalidUsageError); + + // Invalid rank. + ark::PlanJson plan; + plan["Rank"] = 1; + UNITTEST_THROW(ark::Executor(0, nullptr, "test", plan.dump(), true), + ark::InvalidUsageError); + + return ark::unittest::SUCCESS; +} + +int main() { + UNITTEST(test_executor_loop); + UNITTEST(test_executor_no_loop); + UNITTEST(test_executor_tensor_read_write); + UNITTEST(test_executor_invalid); + return 0; +} diff --git a/ark/include/ark/executor.hpp b/ark/include/ark/executor.hpp index 3744c33db..7f30f39ed 100644 --- a/ark/include/ark/executor.hpp +++ b/ark/include/ark/executor.hpp @@ -39,7 +39,7 @@ class Executor { /// Launch the model (not running yet). This must be called after /// `compile()`. - void launch(int64_t max_spin_count = -1); + void launch(); /// Run the model for `iter` iterations. void run(int iter); diff --git a/python/executor_py.cpp b/python/executor_py.cpp index a3f2a078b..36e1c435e 100644 --- a/python/executor_py.cpp +++ b/python/executor_py.cpp @@ -56,7 +56,7 @@ void register_executor(py::module &m) { }) .def("plan", &ark::Executor::plan) .def("compile", &ark::Executor::compile) - .def("launch", &ark::Executor::launch, py::arg("max_spin_count") = -1) + .def("launch", &ark::Executor::launch) .def("run", &ark::Executor::run, py::arg("iter")) .def("wait", &ark::Executor::wait, py::arg("max_spin_count") = -1) .def("stop", &ark::Executor::stop, py::arg("max_spin_count") = -1) From 4971601b09880e29adc85ab305a739edf55ccbb0 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 19:03:08 +0000 Subject: [PATCH 52/54] fix merge --- ark/api/context_manager.cpp | 42 ---------- ark/api/context_manager_test.cpp | 53 ------------ ark/api/executor.cpp | 8 -- ark/api/model.cpp | 2 +- ark/api/model_graph.cpp | 4 +- ark/api/plan_manager.cpp | 97 ---------------------- ark/api/plan_manager_test.cpp | 58 ------------- ark/codegen.cpp | 1 - ark/include/ark/context_manager.hpp | 24 ------ ark/include/ark/error.hpp | 15 +++- ark/include/ark/model.hpp | 57 +++++-------- ark/include/ark/model_graph.hpp | 2 +- ark/include/ark/plan_manager.hpp | 25 ------ ark/model/model_graph_impl.cpp | 6 +- ark/model/model_graph_impl.hpp | 8 +- ark/model/model_op.cpp | 11 --- ark/model/model_op.hpp | 9 +- ark/ops/ops_arithmetic.cpp | 20 ++--- ark/ops/ops_arithmetic_test.cpp | 48 ++++------- ark/ops/ops_cast.cpp | 10 +-- ark/ops/ops_communication.cpp | 14 ++-- ark/ops/ops_copy.cpp | 5 +- ark/ops/ops_embedding.cpp | 4 +- ark/ops/ops_identity.cpp | 2 +- ark/ops/ops_math.cpp | 31 +++---- ark/ops/ops_matmul.cpp | 6 +- ark/ops/ops_noop.cpp | 2 +- ark/ops/ops_reduce.cpp | 12 +-- ark/ops/ops_refer.cpp | 2 +- ark/ops/ops_reshape.cpp | 4 +- ark/ops/ops_rope.cpp | 5 +- ark/ops/ops_scalar.cpp | 31 +++---- ark/ops/ops_transpose.cpp | 5 +- examples/llama/model_7b_b1_s2048.py | 70 ++++++++-------- examples/tutorial/plan_manager_tutorial.py | 81 ------------------ python/ark/plan_manager.py | 34 -------- python/ark/runtime.py | 1 + python/model_py.cpp | 79 ++++++++---------- python/plan_manager_py.cpp | 15 ---- 39 files changed, 195 insertions(+), 708 deletions(-) delete mode 100644 ark/api/context_manager.cpp delete mode 100644 ark/api/context_manager_test.cpp delete mode 100644 ark/api/plan_manager.cpp delete mode 100644 ark/api/plan_manager_test.cpp delete mode 100644 ark/include/ark/context_manager.hpp delete mode 100644 ark/include/ark/plan_manager.hpp delete mode 100644 examples/tutorial/plan_manager_tutorial.py delete mode 100644 python/ark/plan_manager.py delete mode 100644 python/plan_manager_py.cpp diff --git a/ark/api/context_manager.cpp b/ark/api/context_manager.cpp deleted file mode 100644 index 6d16d9e79..000000000 --- a/ark/api/context_manager.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include "ark/context_manager.hpp" - -#include "model/model_graph_impl.hpp" - -namespace ark { - -class ContextManager::Impl { - public: - Impl(std::shared_ptr context_stack, - const std::map& context_map); - - ~Impl(); - - private: - std::shared_ptr context_stack_; - std::vector keys_; -}; - -ContextManager::Impl::Impl( - std::shared_ptr context_stack, - const std::map& context_map) - : context_stack_(context_stack) { - for (const auto& [key, value] : context_map) { - context_stack_->push(key, value); - keys_.push_back(key); - } -} - -ContextManager::Impl::~Impl() { - for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { - context_stack_->pop(*it); - } -} - -ContextManager::ContextManager( - Model& model, const std::map& context_map) - : impl_(std::make_shared(model.impl_->context_stack_, context_map)) {} - -} // namespace ark diff --git a/ark/api/context_manager_test.cpp b/ark/api/context_manager_test.cpp deleted file mode 100644 index 5fff94f34..000000000 --- a/ark/api/context_manager_test.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include "ark/context_manager.hpp" - -#include "model/model_node.hpp" -#include "unittest/unittest_utils.h" - -ark::unittest::State test_context_manager() { - ark::Model model; - ark::Tensor t0 = model.tensor({1}, ark::FP32); - ark::Tensor t1 = model.tensor({1}, ark::FP32); - ark::Tensor t2 = model.add(t0, t1); - - ark::Tensor t3; - ark::Tensor t4; - ark::Tensor t5; - { - ark::ContextManager cm0_1(model, {{"key0", "val1"}}); - t3 = model.relu(t2); - - ark::ContextManager cm1_1(model, {{"key1", "val2"}}); - t4 = model.sqrt(t3); - } - { - ark::ContextManager cm0_2(model, {{"key0", "val3"}}); - t5 = model.exp(t2); - } - - UNITTEST_TRUE(model.verify()); - - auto compressed = model.compress(false); - UNITTEST_TRUE(compressed.verify()); - - auto nodes = compressed.nodes(); - UNITTEST_EQ(nodes.size(), 4); - - UNITTEST_EQ(nodes[0]->context.size(), 0); - UNITTEST_EQ(nodes[1]->context.size(), 1); - UNITTEST_EQ(nodes[1]->context.at("key0"), "val1"); - UNITTEST_EQ(nodes[2]->context.size(), 2); - UNITTEST_EQ(nodes[2]->context.at("key0"), "val1"); - UNITTEST_EQ(nodes[2]->context.at("key1"), "val2"); - UNITTEST_EQ(nodes[3]->context.size(), 1); - UNITTEST_EQ(nodes[3]->context.at("key0"), "val3"); - - return ark::unittest::SUCCESS; -} - -int main() { - UNITTEST(test_context_manager); - return 0; -} diff --git a/ark/api/executor.cpp b/ark/api/executor.cpp index 6fb2b5f2e..17d579763 100644 --- a/ark/api/executor.cpp +++ b/ark/api/executor.cpp @@ -250,14 +250,6 @@ void Executor::Impl::init(const PlanJson &plan_json) { gpu_manager->info().arch->name(), "`."); } - if (!gpu_manager->info().arch->belongs_to( - Arch::from_name(plan_json_.at("Architecture")))) { - LOG(WARN, "Architecture name of the plan `", - plan_json_.at("Architecture").get(), - "` is not compatible with the GPU architecture `", - gpu_manager->info().arch->name(), "`."); - } - buffer_id_to_offset_ = init_buffers(plan_json_); std::string buffer_id_to_offset_str; diff --git a/ark/api/model.cpp b/ark/api/model.cpp index 8227ea848..dcbd4940e 100644 --- a/ark/api/model.cpp +++ b/ark/api/model.cpp @@ -20,7 +20,7 @@ size_t Model::id() const { return id_; } Model Model::compress() const { Model model(*this); - model.compress_nodes(merge_nodes); + model.compress_nodes(); return model; } diff --git a/ark/api/model_graph.cpp b/ark/api/model_graph.cpp index a4477b8e6..e07565141 100644 --- a/ark/api/model_graph.cpp +++ b/ark/api/model_graph.cpp @@ -33,9 +33,7 @@ int ModelGraph::rank() const { return impl_->rank(); } int ModelGraph::world_size() const { return impl_->world_size(); } -void ModelGraph::compress_nodes(bool merge_nodes) { - impl_->compress_nodes(merge_nodes); -} +void ModelGraph::compress_nodes() { impl_->compress_nodes(); } bool ModelGraph::compressed() const { return impl_->compressed(); } diff --git a/ark/api/plan_manager.cpp b/ark/api/plan_manager.cpp deleted file mode 100644 index 8cb1940b1..000000000 --- a/ark/api/plan_manager.cpp +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include "ark/plan_manager.hpp" - -#include "logging.h" -#include "model/model_json.hpp" -#include "model/model_graph_impl.hpp" - -namespace ark { - -class PlanManagerState { - public: - PlanManagerState() : sync(true) {} - bool sync; -}; - -static std::map gPlanManagerStates; - -PlanManager::PlanManager(Model& model, const std::string& plan_context) - : model_id_(model.id()), stop_sync_(false) { - static int task_group_id = 0; - auto ctx = Json::parse(plan_context); - if (!ctx.is_object()) { - ERR(ModelError, "plan context must be a JSON object"); - } - if (gPlanManagerStates.find(model_id_) == gPlanManagerStates.end()) { - gPlanManagerStates.emplace(model_id_, PlanManagerState()); - } - auto& state = gPlanManagerStates[model_id_]; - bool async = !state.sync; - std::map context_map; - for (const auto& [key, value] : ctx.items()) { - if (key == "sync") { - if (!value.is_boolean()) { - ERR(ModelError, "sync must be a boolean"); - } - if (state.sync && !value.get()) { - stop_sync_ = true; - state.sync = false; - context_map["TaskGroupId"] = std::to_string(task_group_id++); - } - } else if (key == "processor_range") { - if (!value.is_array()) { - ERR(ModelError, "processor_range must be an array"); - } - if (async) { - LOG(WARN, "Ignoring processor_range under sync=false context"); - continue; - } - context_map["ProcessorRange"] = value.dump(); - } else if (key == "warp_range") { - if (!value.is_array()) { - ERR(ModelError, "warp_range must be an array"); - } - if (async) { - LOG(WARN, "Ignoring warp_range under sync=false context"); - continue; - } - context_map["WarpRange"] = value.dump(); - } else if (key == "sram_range") { - if (!value.is_array()) { - ERR(ModelError, "sram_range must be an array"); - } - if (async) { - LOG(WARN, "Ignoring sram_range under sync=false context"); - continue; - } - context_map["SramRange"] = value.dump(); - } else if (key == "config") { - if (!value.is_object()) { - ERR(ModelError, "config must be an object"); - } - auto cfg = model.impl_->get_context("Config"); - if (cfg.empty()) { - context_map["Config"] = value.dump(); - } else { - auto cfg_obj = Json::parse(cfg); - for (const auto& [k, v] : value.items()) { - cfg_obj[k] = v; - } - context_map["Config"] = cfg_obj.dump(); - } - } else { - LOG(WARN, "Ignoring unknown plan context key: ", key); - } - } - context_manager_ = std::make_shared(model, context_map); -} - -PlanManager::~PlanManager() { - if (stop_sync_) { - gPlanManagerStates[model_id_].sync = true; - } -} - -} // namespace ark diff --git a/ark/api/plan_manager_test.cpp b/ark/api/plan_manager_test.cpp deleted file mode 100644 index 78f5d4cb8..000000000 --- a/ark/api/plan_manager_test.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include "ark/plan_manager.hpp" -#include "ark/planner.hpp" - -#include "model/model_json.hpp" -#include "unittest/unittest_utils.h" - -ark::unittest::State test_plan_manager() { - ark::Model model; - ark::Tensor t0 = model.tensor({1}, ark::FP32); - ark::Tensor t1 = model.tensor({1}, ark::FP32); - ark::Tensor t2 = model.add(t0, t1); - - ark::Tensor t3; - ark::Tensor t4; - ark::Tensor t5; - ark::Tensor t6; - { - ark::PlanManager pm_0(model, ark::Json({ - {"processor_range", {0, 2}}, - {"warp_range", {0, 4}}, - {"sram_range", {0, 0}}, - {"sync", false} - }).dump()); - t3 = model.relu(t2); - t4 = model.sqrt(t3); - } - { - ark::PlanManager pm_0(model, ark::Json({ - {"processor_range", {2, 4}}, - {"warp_range", {0, 4}}, - {"sram_range", {0, 0}} - }).dump()); - t5 = model.exp(t2); - - ark::PlanManager pm_1(model, ark::Json({ - {"processor_range", {2, 3}} - }).dump()); - t6 = model.rsqrt(t5); - } - - UNITTEST_TRUE(model.verify()); - - ark::DefaultPlanner planner(model, 0); - auto plan_str = planner.plan(); - ark::Json plan = ark::Json::parse(plan_str); - - UNITTEST_LOG(plan_str); - - return ark::unittest::SUCCESS; -} - -int main() { - UNITTEST(test_plan_manager); - return 0; -} diff --git a/ark/codegen.cpp b/ark/codegen.cpp index bc43584cb..1619b863f 100644 --- a/ark/codegen.cpp +++ b/ark/codegen.cpp @@ -87,7 +87,6 @@ CodeGenerator::Impl::Impl(const PlanJson &plan, num_warps_per_proc_ = plan.at("NumWarpsPerProcessor"); std::stringstream definitions_ss; - for (auto &task_json : plan.at("TaskInfos")) { definitions_ss << this->def_task(task_json); } diff --git a/ark/include/ark/context_manager.hpp b/ark/include/ark/context_manager.hpp deleted file mode 100644 index 58271ea8c..000000000 --- a/ark/include/ark/context_manager.hpp +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#ifndef ARK_CONTEXT_MANAGER_HPP -#define ARK_CONTEXT_MANAGER_HPP - -#include -#include - -namespace ark { - -class ContextManager { - public: - ContextManager(Model& model, - const std::map& context_map); - - private: - class Impl; - std::shared_ptr impl_; -}; - -} // namespace ark - -#endif // ARK_CONTEXT_MANAGER_HPP diff --git a/ark/include/ark/error.hpp b/ark/include/ark/error.hpp index 78d02cab3..965b1c0bc 100644 --- a/ark/include/ark/error.hpp +++ b/ark/include/ark/error.hpp @@ -9,6 +9,7 @@ namespace ark { +/// Base class for all ARK errors. class BaseError : public std::exception { private: std::string msg_; @@ -24,15 +25,21 @@ class BaseError : public std::exception { _name(const std::string &msg) : BaseError(msg) {} \ }; +/// Internal error in ARK, likely a bug. REGISTER_ERROR_TYPE(InternalError) +/// Invalid usage of ARK API. REGISTER_ERROR_TYPE(InvalidUsageError) -REGISTER_ERROR_TYPE(NotFoundError) +/// Invalid ARK model definition or usage. REGISTER_ERROR_TYPE(ModelError) -REGISTER_ERROR_TYPE(SchedulerError) -REGISTER_ERROR_TYPE(ExecutorError) +/// Invalid ARK plan definition or usage. +REGISTER_ERROR_TYPE(PlanError) +/// Unsupported feature triggered. +REGISTER_ERROR_TYPE(UnsupportedError) +/// Error from invalid system state such as a system call failure. REGISTER_ERROR_TYPE(SystemError) +/// Error from a CUDA/HIP API call. REGISTER_ERROR_TYPE(GpuError) -REGISTER_ERROR_TYPE(RuntimeError) +/// Error from a unit test. REGISTER_ERROR_TYPE(UnitTestError) } // namespace ark diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index cbbff7f95..3c4f22e22 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -103,29 +103,23 @@ class Model : public ModelGraph { // result in `output`. // Currently, only reduction along the last dimension is supported. Tensor reduce_sum(Tensor input, int axis, bool keepdims = true, - Tensor output = NullTensor, - const std::string &config = "", - const std::string &name = ""); + Tensor output = NullTensor, const std::string &name = ""); Tensor reduce_mean(Tensor input, int axis, bool keepdims = true, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); Tensor reduce_max(Tensor input, int axis, bool keepdims = true, - Tensor output = NullTensor, - const std::string &config = "", - const std::string &name = ""); + Tensor output = NullTensor, const std::string &name = ""); // Transposes the `input` tensor according to the given `permutation`. // For example, transpose(input, {0, 1 ,3, 2}) will swap the last two // dimensions of the input tensor. Currently, only 4D tensors are supported. Tensor transpose(Tensor input, const std::vector &permutation, - Tensor output = NullTensor, const std::string &config = "", - const std::string &name = ""); + Tensor output = NullTensor, const std::string &name = ""); // Performs matrix multiplication between the `input` tensor and another // `other` tensor, storing the result in `output`. Tensor matmul(Tensor input, Tensor other, Tensor output = NullTensor, bool trans_input = false, bool trans_other = false, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Implements the 'im2col' method for 2D convolution layers, which takes an // `input` tensor and reshapes it to a 2D matrix by extracting image patches // from the input tensor based on the provided parameters. @@ -142,66 +136,63 @@ class Model : public ModelGraph { Tensor output = NullTensor, const std::string &name = ""); // Calculates the exponential of the `input` tensor, element-wise. Tensor exp(Tensor input, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Calculates the square root of the `input` tensor, element-wise. Tensor sqrt(Tensor input, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Calculates the reverse square root of the `input` tensor, element-wise. Tensor rsqrt(Tensor input, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // ReLU activation Tensor relu(Tensor input, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Copy the `input` tensor to `output` tensor Tensor copy(Tensor input, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); Tensor copy(float val, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Applies the Gaussian Error Linear Unit (GELU) activation function to the // `input` tensor, element-wise. GELU is a smooth approximation of the // rectifier function and is widely used in deep learning models. Tensor gelu(Tensor input, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Sigmoid activation Tensor sigmoid(Tensor input, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); // Performs rotary position embedding (RoPE) on the `input` tensor Tensor rope(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Performs an element-wise addition operator between the `input` tensor // and the `other` tensor Tensor add(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); Tensor add(Tensor input, float value, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Performs an element-wise subtraction operator between the `input` tensor // and the `other` tensor Tensor sub(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); Tensor sub(Tensor input, float value, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Performs an element-wise multiplication operator between the `input` // tensor and the `other` tensor, Tensor mul(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); Tensor mul(Tensor input, float value, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); // Performs an element-wise division operator between the `input` // tensor and the `other` tensor, Tensor div(Tensor input, Tensor other, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); Tensor div(Tensor input, float value, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); + const std::string &name = ""); Tensor send(Tensor input, int remote_rank, int tag, - Tensor output = NullTensor, const std::string &config = "", - const std::string &name = ""); + Tensor output = NullTensor, const std::string &name = ""); // Blocks the execution until the corresponding 'send' operator with the // specified `id` is completed. - Tensor send_done(Tensor input, const std::string &config = "", - const std::string &name = ""); + Tensor send_done(Tensor input, const std::string &name = ""); // Receives a tensor from a source rank (@p src_rank), identified by the // `id` parameter. Blocks the execution until the corresponding 'recv' // operator is completed. @@ -238,12 +229,10 @@ class Model : public ModelGraph { const std::string &name = ""); /// Embedding layer. Tensor embedding(Tensor input, Tensor weight, Tensor output = NullTensor, - const std::string &config = "", const std::string &name = ""); /// Tensor type casting. Tensor cast(Tensor input, const DataType &data_type, - Tensor output = NullTensor, const std::string &config = "", - const std::string &name = ""); + Tensor output = NullTensor, const std::string &name = ""); // sync across multi devices Tensor device_sync(Tensor input, int rank, int rank_num, diff --git a/ark/include/ark/model_graph.hpp b/ark/include/ark/model_graph.hpp index 598bf343a..29074630c 100644 --- a/ark/include/ark/model_graph.hpp +++ b/ark/include/ark/model_graph.hpp @@ -25,7 +25,7 @@ class ModelGraph { int world_size() const; - void compress_nodes(bool merge_nodes = false); + void compress_nodes(); bool compressed() const; diff --git a/ark/include/ark/plan_manager.hpp b/ark/include/ark/plan_manager.hpp deleted file mode 100644 index 3952a1c06..000000000 --- a/ark/include/ark/plan_manager.hpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#ifndef ARK_PLAN_MANAGER_HPP -#define ARK_PLAN_MANAGER_HPP - -#include - -namespace ark { - -class PlanManager { - public: - PlanManager(Model& model, const std::string& plan_context); - - ~PlanManager(); - - private: - size_t model_id_; - bool stop_sync_; - std::shared_ptr context_manager_; -}; - -} // namespace ark - -#endif // ARK_PLAN_MANAGER_HPP diff --git a/ark/model/model_graph_impl.cpp b/ark/model/model_graph_impl.cpp index 81359439a..7c1ea3fb5 100644 --- a/ark/model/model_graph_impl.cpp +++ b/ark/model/model_graph_impl.cpp @@ -112,7 +112,7 @@ ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { return *this; } -void ModelGraph::Impl::compress_nodes(bool merge_nodes) { +void ModelGraph::Impl::compress_nodes() { if (!compressed_) { this->recursive_remove_virtual_nodes(); compressed_ = true; @@ -178,10 +178,6 @@ bool ModelGraph::Impl::verify() const { return true; } -std::string ModelGraph::Impl::get_context(const std::string &key) const { - return context_stack_->get_context(key); -} - ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) { for (auto &tns : op->input_tensors()) { if (tensor_to_producer_op_.find(tns) == tensor_to_producer_op_.end()) { diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp index c7080ab73..62944f999 100644 --- a/ark/model/model_graph_impl.hpp +++ b/ark/model/model_graph_impl.hpp @@ -54,8 +54,7 @@ class ModelGraph::Impl { Impl &operator=(const Impl &other); template - ModelOpRef create_op(const std::string &config, const std::string &name, - Args &&...args) { + ModelOpRef create_op(const std::string &name, Args &&... args) { ModelOpRef op = std::make_shared(std::forward(args)...); std::string name_copy; if (name.empty()) { @@ -68,7 +67,6 @@ class ModelGraph::Impl { if (count > 0) { name_copy += "_" + std::to_string(count); } - op->set_config(config); op->set_name(name_copy); add_op(op); return op; @@ -78,14 +76,12 @@ class ModelGraph::Impl { int world_size() const { return world_size_; } - void compress_nodes(bool merge_nodes = false); + void compress_nodes(); bool compressed() const { return compressed_; } bool verify() const; - std::string get_context(const std::string &key) const; - std::string serialize(bool pretty = true) const; std::vector nodes() const; diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index dc4906235..5db8576e8 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -92,14 +92,6 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { return it->second; } -void ModelOp::set_config(const std::string &config) { - if (!config.empty()) { - config_ = Json::parse(config); - } else { - config_.clear(); - } -} - std::vector ModelOp::input_tensors() const { // input_tensors = read_tensors || write_tensors std::set input_tensors; @@ -192,9 +184,6 @@ Json ModelOp::serialize() const { for (auto &arg : args_) { j["Args"][arg.first] = arg.second.serialize(); } - if (!config_.empty()) { - j["Config"] = config_; - } return j; } diff --git a/ark/model/model_op.hpp b/ark/model/model_op.hpp index d048375c2..f7323d6c0 100644 --- a/ark/model/model_op.hpp +++ b/ark/model/model_op.hpp @@ -50,8 +50,8 @@ class ModelOp { return ""; } - virtual std::vector impl_args( - [[maybe_unused]] const Json &config) const { + virtual std::vector impl_args([ + [maybe_unused]] const Json &config) const { return {}; } @@ -60,14 +60,10 @@ class ModelOp { return {{"NumTasks", 0}, {"NumWarps", 0}, {"SramBytes", 0}}; } - void set_config(const std::string &config); - void set_name(const std::string &name) { name_ = name; } ModelOpType type() const { return type_; } - const Json &config() const { return config_; } - const std::string &name() const { return name_; } bool is_virtual() const { return is_virtual_; } @@ -104,7 +100,6 @@ class ModelOp { const std::vector &template_args = {}); ModelOpType type_; - Json config_; std::string name_; bool is_virtual_; std::vector read_tensors_; diff --git a/ark/ops/ops_arithmetic.cpp b/ark/ops/ops_arithmetic.cpp index ef85b5d22..aeece0d77 100644 --- a/ark/ops/ops_arithmetic.cpp +++ b/ark/ops/ops_arithmetic.cpp @@ -12,10 +12,9 @@ ModelOpAdd::ModelOpAdd(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Add", input, other, output) {} Tensor Model::add(Tensor input, Tensor other, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, other.ref_, - output.ref_) + ->create_op(name, input.ref_, other.ref_, output.ref_) ->result_tensors()[0]; } @@ -24,10 +23,9 @@ ModelOpMul::ModelOpMul(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Mul", input, other, output) {} Tensor Model::mul(Tensor input, Tensor other, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, other.ref_, - output.ref_) + ->create_op(name, input.ref_, other.ref_, output.ref_) ->result_tensors()[0]; } @@ -36,10 +34,9 @@ ModelOpSub::ModelOpSub(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Sub", input, other, output) {} Tensor Model::sub(Tensor input, Tensor other, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, other.ref_, - output.ref_) + ->create_op(name, input.ref_, other.ref_, output.ref_) ->result_tensors()[0]; } @@ -48,10 +45,9 @@ ModelOpDiv::ModelOpDiv(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Div", input, other, output) {} Tensor Model::div(Tensor input, Tensor other, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, other.ref_, - output.ref_) + ->create_op(name, input.ref_, other.ref_, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_arithmetic_test.cpp b/ark/ops/ops_arithmetic_test.cpp index fd6a05b1a..772da3276 100644 --- a/ark/ops/ops_arithmetic_test.cpp +++ b/ark/ops/ops_arithmetic_test.cpp @@ -2,7 +2,6 @@ // Licensed under the MIT license. #include "ops_test_common.hpp" -#include "model/model_json.hpp" template void baseline_add(std::vector &outputs, @@ -143,25 +142,12 @@ ark::unittest::State test_add_fp32() { ark::unittest::State test_add_fp16() { ark::Model m; - ark::Tensor t0 = m.tensor({32, 2048, 2048}, ark::FP16); - ark::Tensor t1 = m.tensor({32, 2048, 2048}, ark::FP16); + ark::Tensor t0 = m.tensor({8192}, ark::FP16); + ark::Tensor t1 = m.tensor({8192}, ark::FP16); ark::Tensor out = m.add(t0, t1); auto result = - ark::op_test("add_fp16", m, {t0, t1}, {out}, baseline_add, {}, - { - ark::DefaultPlanner::ConfigRule([](const std::string op_str, const std::string) { - auto op = ark::Json::parse(op_str); - ark::Json config; - if (op.at("Type") == "Add") { - config["NumWarps"] = 4; - config["SramBytes"] = 0; - config["Tile"] = {128, 256}; - config["NumTasks"] = 4096; - } - return config.dump(); - }) - }); + ark::op_test("add_fp16", m, {t0, t1}, {out}, baseline_add); UNITTEST_LOG(result); UNITTEST_EQ(result.max_diff[0], 0.0f); return ark::unittest::SUCCESS; @@ -430,20 +416,20 @@ ark::unittest::State test_div_invalid() { int main() { ark::init(); - // UNITTEST(test_add_fp32); + UNITTEST(test_add_fp32); UNITTEST(test_add_fp16); - // UNITTEST(test_add_bf16); - // UNITTEST(test_add_overwrite); - // UNITTEST(test_add_broadcast); - // UNITTEST(test_add_invalid); - // UNITTEST(test_sub_fp32); - // UNITTEST(test_sub_invalid); - // UNITTEST(test_mul_fp32); - // UNITTEST(test_mul_fp16); - // UNITTEST(test_mul_overwrite); - // UNITTEST(test_mul_broadcast); - // UNITTEST(test_mul_invalid); - // UNITTEST(test_div_fp32); - // UNITTEST(test_div_invalid); + UNITTEST(test_add_bf16); + UNITTEST(test_add_overwrite); + UNITTEST(test_add_broadcast); + UNITTEST(test_add_invalid); + UNITTEST(test_sub_fp32); + UNITTEST(test_sub_invalid); + UNITTEST(test_mul_fp32); + UNITTEST(test_mul_fp16); + UNITTEST(test_mul_overwrite); + UNITTEST(test_mul_broadcast); + UNITTEST(test_mul_invalid); + UNITTEST(test_div_fp32); + UNITTEST(test_div_invalid); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_cast.cpp b/ark/ops/ops_cast.cpp index 96146217e..e94fec989 100644 --- a/ark/ops/ops_cast.cpp +++ b/ark/ops/ops_cast.cpp @@ -105,7 +105,7 @@ ModelOpByteCast::ModelOpByteCast(ModelTensorRef input, ModelDataType data_type, } Tensor Model::cast(Tensor input, const DataType &data_type, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { check_null(input.ref()); if (output.is_null()) { if (input.data_type() == data_type) { @@ -119,14 +119,14 @@ Tensor Model::cast(Tensor input, const DataType &data_type, Tensor output, byte_cast_helper(input.ref(), data_type.ref(), new_shape, new_strides, new_offsets, new_padded_shape); return impl_ - ->create_op( - config, name, input.ref(), data_type.ref(), new_shape, - new_strides, new_offsets, new_padded_shape) + ->create_op(name, input.ref(), data_type.ref(), + new_shape, new_strides, + new_offsets, new_padded_shape) ->result_tensors()[0]; } } return impl_ - ->create_op(config, name, input.ref(), data_type.ref(), + ->create_op(name, input.ref(), data_type.ref(), output.ref()) ->result_tensors()[0]; } diff --git a/ark/ops/ops_communication.cpp b/ark/ops/ops_communication.cpp index e42c96d9c..baf7aafa2 100644 --- a/ark/ops/ops_communication.cpp +++ b/ark/ops/ops_communication.cpp @@ -589,25 +589,23 @@ Json ModelOpDeviceSync::default_config([[maybe_unused]] const ArchRef arch) cons } Tensor Model::send(Tensor input, int remote_rank, int tag, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { tags_.insert(tag); return impl_ - ->create_op(config, name, input.ref(), remote_rank, tag, + ->create_op(name, input.ref(), remote_rank, tag, output.ref()) ->result_tensors()[0]; } -Tensor Model::send_done(Tensor input, const std::string &config, - const std::string &name) { - return impl_->create_op(config, name, input.ref()) +Tensor Model::send_done(Tensor input, const std::string &name) { + return impl_->create_op(name, input.ref()) ->result_tensors()[0]; } Tensor Model::recv(Tensor output, int remote_rank, int tag, - const std::string &config, const std::string &name) { + const std::string &name) { tags_.insert(tag); - return impl_ - ->create_op(config, name, output.ref(), remote_rank, tag) + return impl_->create_op(name, output.ref(), remote_rank, tag) ->result_tensors()[0]; } diff --git a/ark/ops/ops_copy.cpp b/ark/ops/ops_copy.cpp index 4914c34a4..4f32966b8 100644 --- a/ark/ops/ops_copy.cpp +++ b/ark/ops/ops_copy.cpp @@ -20,9 +20,8 @@ ModelOpCopy::ModelOpCopy(ModelTensorRef input, ModelTensorRef output) verify(); } -Tensor Model::copy(Tensor input, Tensor output, const std::string &config, - const std::string &name) { - return impl_->create_op(config, name, input.ref_, output.ref_) +Tensor Model::copy(Tensor input, Tensor output, const std::string &name) { + return impl_->create_op(name, input.ref_, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_embedding.cpp b/ark/ops/ops_embedding.cpp index 1169c47c3..2e2626d4c 100644 --- a/ark/ops/ops_embedding.cpp +++ b/ark/ops/ops_embedding.cpp @@ -70,9 +70,9 @@ Json ModelOpEmbedding::default_config([ } Tensor Model::embedding(Tensor input, Tensor weight, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, weight.ref_, + ->create_op(name, input.ref_, weight.ref_, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_identity.cpp b/ark/ops/ops_identity.cpp index dd398d8a5..065cd9a52 100644 --- a/ark/ops/ops_identity.cpp +++ b/ark/ops/ops_identity.cpp @@ -31,7 +31,7 @@ Tensor Model::identity(Tensor input, const std::vector &deps, for (auto &dep : deps) { deps_ref.emplace_back(dep.ref_); } - return impl_->create_op("", name, input.ref_, deps_ref) + return impl_->create_op(name, input.ref_, deps_ref) ->result_tensors()[0]; } diff --git a/ark/ops/ops_math.cpp b/ark/ops/ops_math.cpp index b2833dcca..1067c561a 100644 --- a/ark/ops/ops_math.cpp +++ b/ark/ops/ops_math.cpp @@ -24,55 +24,48 @@ ModelOpMath::ModelOpMath(const std::string &type_name, ModelTensorRef input, ModelOpExp::ModelOpExp(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Exp", input, output) {} -Tensor Model::exp(Tensor input, Tensor output, const std::string &config, - const std::string &name) { - return impl_->create_op(config, name, input.ref_, output.ref_) +Tensor Model::exp(Tensor input, Tensor output, const std::string &name) { + return impl_->create_op(name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpGelu::ModelOpGelu(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Gelu", input, output) {} -Tensor Model::gelu(Tensor input, Tensor output, const std::string &config, - const std::string &name) { - return impl_->create_op(config, name, input.ref_, output.ref_) +Tensor Model::gelu(Tensor input, Tensor output, const std::string &name) { + return impl_->create_op(name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpRelu::ModelOpRelu(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Relu", input, output) {} -Tensor Model::relu(Tensor input, Tensor output, const std::string &config, - const std::string &name) { - return impl_->create_op(config, name, input.ref_, output.ref_) +Tensor Model::relu(Tensor input, Tensor output, const std::string &name) { + return impl_->create_op(name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpRsqrt::ModelOpRsqrt(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Rsqrt", input, output) {} -Tensor Model::rsqrt(Tensor input, Tensor output, const std::string &config, - const std::string &name) { - return impl_->create_op(config, name, input.ref_, output.ref_) +Tensor Model::rsqrt(Tensor input, Tensor output, const std::string &name) { + return impl_->create_op(name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpSigmoid::ModelOpSigmoid(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Sigmoid", input, output) {} -Tensor Model::sigmoid(Tensor input, Tensor output, const std::string &config, - const std::string &name) { - return impl_ - ->create_op(config, name, input.ref_, output.ref_) +Tensor Model::sigmoid(Tensor input, Tensor output, const std::string &name) { + return impl_->create_op(name, input.ref_, output.ref_) ->result_tensors()[0]; } ModelOpSqrt::ModelOpSqrt(ModelTensorRef input, ModelTensorRef output) : ModelOpMath("Sqrt", input, output) {} -Tensor Model::sqrt(Tensor input, Tensor output, const std::string &config, - const std::string &name) { - return impl_->create_op(config, name, input.ref_, output.ref_) +Tensor Model::sqrt(Tensor input, Tensor output, const std::string &name) { + return impl_->create_op(name, input.ref_, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_matmul.cpp b/ark/ops/ops_matmul.cpp index bc94922fc..dca349f44 100644 --- a/ark/ops/ops_matmul.cpp +++ b/ark/ops/ops_matmul.cpp @@ -244,10 +244,10 @@ Json ModelOpMatmul::default_config(const ArchRef arch) const { Tensor Model::matmul(Tensor input, Tensor other, Tensor output, bool trans_input, bool trans_other, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref(), other.ref(), - output.ref(), trans_input, trans_other) + ->create_op(name, input.ref(), other.ref(), output.ref(), + trans_input, trans_other) ->result_tensors()[0]; } diff --git a/ark/ops/ops_noop.cpp b/ark/ops/ops_noop.cpp index 42fe5fdf5..894ab29be 100644 --- a/ark/ops/ops_noop.cpp +++ b/ark/ops/ops_noop.cpp @@ -30,7 +30,7 @@ Json ModelOpNoop::default_config([[maybe_unused]] const ArchRef arch) const { } void Model::noop(Tensor input, const std::string &name) { - impl_->create_op("", name, input.ref_); + impl_->create_op(name, input.ref_); } } // namespace ark diff --git a/ark/ops/ops_reduce.cpp b/ark/ops/ops_reduce.cpp index 19f70385b..78dd9d7e6 100644 --- a/ark/ops/ops_reduce.cpp +++ b/ark/ops/ops_reduce.cpp @@ -127,25 +127,25 @@ Json ModelOpReduce::default_config([[maybe_unused]] const ArchRef arch) const { } Tensor Model::reduce_max(Tensor input, int axis, bool keepdims, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, axis, keepdims, + ->create_op(name, input.ref_, axis, keepdims, output.ref_) ->result_tensors()[0]; } Tensor Model::reduce_mean(Tensor input, int axis, bool keepdims, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, axis, keepdims, + ->create_op(name, input.ref_, axis, keepdims, output.ref_) ->result_tensors()[0]; } Tensor Model::reduce_sum(Tensor input, int axis, bool keepdims, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, axis, keepdims, + ->create_op(name, input.ref_, axis, keepdims, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_refer.cpp b/ark/ops/ops_refer.cpp index 68c61b30f..782d6708c 100644 --- a/ark/ops/ops_refer.cpp +++ b/ark/ops/ops_refer.cpp @@ -20,7 +20,7 @@ Tensor Model::refer(Tensor input, const Dims &shape, const Dims &strides, const Dims &offsets, const Dims &padded_shape, const std::string &name) { return impl_ - ->create_op("", name, input.ref_, shape, strides, offsets, + ->create_op(name, input.ref_, shape, strides, offsets, padded_shape) ->result_tensors()[0]; } diff --git a/ark/ops/ops_reshape.cpp b/ark/ops/ops_reshape.cpp index 8ed3ac247..aac22b71a 100644 --- a/ark/ops/ops_reshape.cpp +++ b/ark/ops/ops_reshape.cpp @@ -199,8 +199,8 @@ Tensor Model::reshape(Tensor input, const Dims &shape, bool allowzero, reshape_helper(input.ref_, Dims{inferred_shape}, allowzero, new_shape, new_strides, new_offs); return impl_ - ->create_op("", name, input.ref_, new_shape, - new_strides, new_offs) + ->create_op(name, input.ref_, new_shape, new_strides, + new_offs) ->result_tensors()[0]; } diff --git a/ark/ops/ops_rope.cpp b/ark/ops/ops_rope.cpp index 36015aae5..06c1c915e 100644 --- a/ark/ops/ops_rope.cpp +++ b/ark/ops/ops_rope.cpp @@ -12,10 +12,9 @@ ModelOpRope::ModelOpRope(ModelTensorRef input, ModelTensorRef other, : ModelOpBroadcast2("Rope", input, other, output) {} Tensor Model::rope(Tensor input, Tensor other, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, other.ref_, - output.ref_) + ->create_op(name, input.ref_, other.ref_, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_scalar.cpp b/ark/ops/ops_scalar.cpp index b5c10f1c3..944a7247c 100644 --- a/ark/ops/ops_scalar.cpp +++ b/ark/ops/ops_scalar.cpp @@ -115,21 +115,20 @@ std::vector ModelOpScalarMul::impl_args([ Tensor Model::constant(float val, const Dims &shape, DataType data_type, const std::string &name) { return impl_ - ->create_op("", name, val, shape, data_type.ref(), + ->create_op(name, val, shape, data_type.ref(), nullptr) ->result_tensors()[0]; } -Tensor Model::copy(float val, Tensor output, const std::string &config, - const std::string &name) { +Tensor Model::copy(float val, Tensor output, const std::string &name) { if (output == NullTensor) { return impl_ - ->create_op(config, name, val, Dims{1}, - FP32.ref(), output.ref()) + ->create_op(name, val, Dims{1}, FP32.ref(), + output.ref()) ->result_tensors()[0]; } else { return impl_ - ->create_op(config, name, val, output.shape(), + ->create_op(name, val, output.shape(), output.data_type().ref(), output.ref()) ->result_tensors()[0]; @@ -137,34 +136,30 @@ Tensor Model::copy(float val, Tensor output, const std::string &config, } Tensor Model::add(Tensor input, float value, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, value, - output.ref_) + ->create_op(name, input.ref_, value, output.ref_) ->result_tensors()[0]; } Tensor Model::sub(Tensor input, float value, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, -value, - output.ref_) + ->create_op(name, input.ref_, -value, output.ref_) ->result_tensors()[0]; } Tensor Model::mul(Tensor input, float value, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, value, - output.ref_) + ->create_op(name, input.ref_, value, output.ref_) ->result_tensors()[0]; } Tensor Model::div(Tensor input, float value, Tensor output, - const std::string &config, const std::string &name) { + const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, 1 / value, - output.ref_) + ->create_op(name, input.ref_, 1 / value, output.ref_) ->result_tensors()[0]; } diff --git a/ark/ops/ops_transpose.cpp b/ark/ops/ops_transpose.cpp index c659761d9..d0f7581cc 100644 --- a/ark/ops/ops_transpose.cpp +++ b/ark/ops/ops_transpose.cpp @@ -124,10 +124,9 @@ Json ModelOpTranspose::default_config([ } Tensor Model::transpose(Tensor input, const std::vector &permutation, - Tensor output, const std::string &config, - const std::string &name) { + Tensor output, const std::string &name) { return impl_ - ->create_op(config, name, input.ref_, permutation, + ->create_op(name, input.ref_, permutation, output.ref_) ->result_tensors()[0]; } diff --git a/examples/llama/model_7b_b1_s2048.py b/examples/llama/model_7b_b1_s2048.py index f41304e85..d4a080c84 100644 --- a/examples/llama/model_7b_b1_s2048.py +++ b/examples/llama/model_7b_b1_s2048.py @@ -90,7 +90,7 @@ def __init__( self.weight = ark.parameter([1, 1, dim], ark.fp32) def forward(self, x): - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 8], sync=False, config={ @@ -100,12 +100,12 @@ def forward(self, x): "Granularity": 7, }, ): - with ark.PlanManager(config={"Tile": [1, 4096]}): + with ark.PlannerContext(config={"Tile": [1, 4096]}): x = ark.cast(x, ark.fp32) x2 = ark.mul(x, x) - with ark.PlanManager(config={"ImplType": "WarpWise"}): + with ark.PlannerContext(config={"ImplType": "WarpWise"}): mean = ark.reduce_mean(x2, axis=-1) - with ark.PlanManager( + with ark.PlannerContext( config={ "NumWarps": 1, "SramBytes": 0, @@ -114,7 +114,7 @@ def forward(self, x): } ): rrms = ark.rsqrt(mean) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 8], sync=False, config={ @@ -356,7 +356,7 @@ def __init__( def forward(self, x): # self.w2(F.silu(self.w1(x)) * self.w3(x)) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 8], sram_range=[0, 49344], sync=False, @@ -365,13 +365,13 @@ def forward(self, x): "NumTasks": 688, }, ): - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} ): x1 = self.w1(x) - with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + with ark.PlannerContext(config={"SramBytes": 0, "Tile": [256, 128]}): x1 = Silu()(x1) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 8], sram_range=[0, 49344], sync=False, @@ -380,11 +380,11 @@ def forward(self, x): "NumTasks": 688, }, ): - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} ): x2 = self.w3(x) - with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + with ark.PlannerContext(config={"SramBytes": 0, "Tile": [256, 128]}): x3 = ark.mul(x1, x2) x4 = self.w2(x3) return x4 @@ -404,7 +404,7 @@ def __init__(self): super(Softmax, self).__init__() def forward(self, input): - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 8], sram_range=[0, 0], sync=False, @@ -414,14 +414,14 @@ def forward(self, input): "NumTasks": 65536, }, ): - with ark.PlanManager(config={"ImplType": "WarpWise"}): + with ark.PlannerContext(config={"ImplType": "WarpWise"}): max = ark.reduce_max(input, axis=-1) - with ark.PlanManager(config={"Tile": [1, 2048]}): + with ark.PlannerContext(config={"Tile": [1, 2048]}): output = ark.sub(input, max) output = ark.exp(output) - with ark.PlanManager(config={"ImplType": "WarpWise"}): + with ark.PlannerContext(config={"ImplType": "WarpWise"}): sum = ark.reduce_sum(output, axis=-1) - with ark.PlanManager(config={"Tile": [1, 2048]}): + with ark.PlannerContext(config={"Tile": [1, 2048]}): output = ark.div(output, sum) return output @@ -486,50 +486,50 @@ def forward( ): bsz, seqlen, _ = x.shape() - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 4], sram_range=[0, 24672], sync=False, config={"NumWarps": 4, "NumTasks": 256}, ): - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} ): xq = self.wq(x) xq = ark.reshape( xq, [bsz, seqlen, self.n_local_heads, self.head_dim] ) - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 0, "Tile": [256, 1, 128]} ): if freqs_cis is not None: xq = ark.rope(xq, freqs_cis) - with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + with ark.PlannerContext(config={"SramBytes": 0, "Tile": [256, 128]}): xq = ark.transpose(xq, [0, 2, 1, 3]) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 4], sram_range=[0, 24672], sync=False, config={"NumWarps": 4, "NumTasks": 256}, ): - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} ): xk = self.wk(x) xk = ark.reshape( xk, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] ) - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 0, "Tile": [256, 1, 128]} ): if freqs_cis is not None: xk = ark.rope(xk, freqs_cis) keys = xk - with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + with ark.PlannerContext(config={"SramBytes": 0, "Tile": [256, 128]}): keys = ark.transpose(keys, [0, 2, 1, 3]) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 4], sram_range=[0, 24672], sync=False, @@ -540,7 +540,7 @@ def forward( "TileShapeMNK": [256, 128, 32], }, ): - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} ): xv = self.wv(x) @@ -548,12 +548,12 @@ def forward( xv, [bsz, seqlen, self.n_local_kv_heads, self.head_dim] ) values = xv - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 0, "Tile": [256, 1, 128]} ): values = ark.transpose(values, [0, 2, 1, 3]) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 8], sram_range=[0, 49344], sync=False, @@ -563,11 +563,11 @@ def forward( "Granularity": 2, }, ): - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} ): scores = ark.matmul(xq, keys, transpose_other=True) - with ark.PlanManager(config={"SramBytes": 0, "Tile": [256, 128]}): + with ark.PlannerContext(config={"SramBytes": 0, "Tile": [256, 128]}): scores = ark.mul(scores, 1.0 / math.sqrt(self.head_dim)) if mask is not None: @@ -575,7 +575,7 @@ def forward( scores = Softmax()(scores) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 4], sram_range=[0, 24672], sync=False, @@ -584,11 +584,11 @@ def forward( "NumTasks": 256, }, ): - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 24672, "TileShapeMNK": [256, 128, 32]} ): output = ark.matmul(scores, values) - with ark.PlanManager( + with ark.PlannerContext( config={"SramBytes": 0, "Tile": [256, 1, 128]} ): output = ark.transpose(output, [0, 2, 1, 3]) @@ -634,7 +634,7 @@ def forward( ): attention_norm_x = self.attention_norm(x) h = self.attention.forward(attention_norm_x, start_pos, freqs_cis, mask) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 4], config={ "NumWarps": 4, @@ -645,7 +645,7 @@ def forward( ): h = ark.add(x, h) ff = self.feed_forward(self.ffn_norm(h)) - with ark.PlanManager( + with ark.PlannerContext( warp_range=[0, 4], config={ "NumWarps": 4, diff --git a/examples/tutorial/plan_manager_tutorial.py b/examples/tutorial/plan_manager_tutorial.py deleted file mode 100644 index c840ce0c0..000000000 --- a/examples/tutorial/plan_manager_tutorial.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import ark -import time -import torch -import torch.nn.functional as F - - -class VanillaSoftmax(ark.Module): - def __init__(self): - super(Softmax, self).__init__() - - def forward(self, input): - max = ark.reduce_max(input, axis=-1) - output = ark.sub(input, max) - output = ark.exp(output) - sum = ark.reduce_sum(output, axis=-1) - output = ark.div(output, sum) - return output - - -class Softmax(ark.Module): - def __init__(self): - super(Softmax, self).__init__() - - def forward(self, input): - with ark.PlanManager( - warp_range=[0, 8], - sram_range=[0, 0], - sync=False, - config={ - "NumWarps": 1, - "SramBytes": 0, - "NumTasks": 65536, - }, - ): - with ark.PlanManager(config={"ImplType": "WarpWise"}): - max = ark.reduce_max(input, axis=-1) - with ark.PlanManager(config={"Tile": [1, 2048]}): - output = ark.sub(input, max) - output = ark.exp(output) - with ark.PlanManager(config={"ImplType": "WarpWise"}): - sum = ark.reduce_sum(output, axis=-1) - with ark.PlanManager(config={"Tile": [1, 2048]}): - output = ark.div(output, sum) - return output - - -def eval(tensor: ark.Tensor): - with ark.Runtime() as rt: - rt.launch() - rt.run() - return tensor.to_torch() - - -def perf(): - with ark.Runtime() as rt: - rt.launch() - - start = time.time() - rt.run(iter=1000) - end = time.time() - return (end - start) / 1000 - - -if __name__ == "__main__": - ark.init() - - shape = (32, 2048, 2048) - - input = torch.randn(*shape).to("cuda:0") - - output = Softmax()(ark.Tensor.from_torch(input)) - - if torch.allclose(eval(output), F.softmax(input, dim=-1), atol=1e-5): - print("Correct result") - else: - print("Incorrect result") - - print(f"Performance: {(perf() * 1e3):.3f} ms/iter") diff --git a/python/ark/plan_manager.py b/python/ark/plan_manager.py deleted file mode 100644 index 80e615ab8..000000000 --- a/python/ark/plan_manager.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import json -from typing import List, Dict, Any -from .model import Model -from ._ark_core import _PlanManager - - -class PlanManager(_PlanManager): - def __init__(self, **kwargs): - """ - Plan manager for specifying the parallelization and tiling configuration of the operators in the context. - - Args: - processor_range (List[int], optional): The range of processors to be used. Defaults to None. - warp_range (List[int], optional): The range of warps to be used. Defaults to None. - sram_range (List[int], optional): The range of SRAMs to be used. Defaults to None. - sync (bool, optional): Whether to synchronize the execution. Defaults to True. - config (Dict[str, Any], optional): The configuration for the operators. Defaults to None. - """ - super().__init__(Model.get_model(), json.dumps(kwargs)) - - def __enter__(self) -> "PlanManager": - """ - Enter the plan manager. - """ - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - """ - Exit the plan manager. - """ - del self diff --git a/python/ark/runtime.py b/python/ark/runtime.py index f064a5988..960223c64 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -3,6 +3,7 @@ import logging from enum import Enum +from typing import Dict, List from _ark_core import _Executor from .planner import Planner, Plan diff --git a/python/model_py.cpp b/python/model_py.cpp index 5a22d6a18..c224a3d5b 100644 --- a/python/model_py.cpp +++ b/python/model_py.cpp @@ -19,100 +19,89 @@ void register_model(py::module &m) { .def("compress", &ark::Model::compress) .def("add", py::overload_cast( - &ark::Model::add), + const std::string &>(&ark::Model::add), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("add", py::overload_cast( - &ark::Model::add), + const std::string &>(&ark::Model::add), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("cast", &ark::Model::cast, py::arg("input"), py::arg("data_type"), - py::arg("output"), py::arg("config"), py::arg("name")) + py::arg("output"), py::arg("name")) .def("constant", &ark::Model::constant, py::arg("value"), py::arg("shape"), py::arg("data_type"), py::arg("name")) .def("copy", - py::overload_cast(&ark::Model::copy), - py::arg("input"), py::arg("output"), py::arg("config"), - py::arg("name")) + py::overload_cast( + &ark::Model::copy), + py::arg("input"), py::arg("output"), py::arg("name")) .def("copy", - py::overload_cast(&ark::Model::copy), - py::arg("input"), py::arg("output"), py::arg("config"), - py::arg("name")) + py::overload_cast( + &ark::Model::copy), + py::arg("input"), py::arg("output"), py::arg("name")) .def("div", py::overload_cast( - &ark::Model::div), + const std::string &>(&ark::Model::div), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("div", py::overload_cast( - &ark::Model::div), + const std::string &>(&ark::Model::div), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("config"), py::arg("name")) - .def("embedding", &ark::Model::embedding, py::arg("input"), - py::arg("weight"), py::arg("output"), py::arg("config"), py::arg("name")) + .def("embedding", &ark::Model::embedding, py::arg("input"), + py::arg("weight"), py::arg("output"), py::arg("name")) .def("exp", &ark::Model::exp, py::arg("input"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("gelu", &ark::Model::gelu, py::arg("input"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("identity", &ark::Model::identity, py::arg("input"), py::arg("deps"), py::arg("name")) .def("matmul", &ark::Model::matmul, py::arg("input"), py::arg("other"), py::arg("output"), py::arg("trans_input"), py::arg("trans_other"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("mul", py::overload_cast( - &ark::Model::mul), + const std::string &>(&ark::Model::mul), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("mul", py::overload_cast( - &ark::Model::mul), + const std::string &>(&ark::Model::mul), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("noop", &ark::Model::noop, py::arg("input"), py::arg("name")) .def("reduce_max", &ark::Model::reduce_max, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("reduce_mean", &ark::Model::reduce_mean, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("reduce_sum", &ark::Model::reduce_sum, py::arg("input"), py::arg("axis"), py::arg("keepdims"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("relu", &ark::Model::relu, py::arg("input"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("reshape", &ark::Model::reshape, py::arg("input"), py::arg("shape"), py::arg("allowzero"), py::arg("name")) .def("rope", &ark::Model::rope, py::arg("input"), py::arg("other"), - py::arg("output"), py::arg("config"), py::arg("name")) + py::arg("output"), py::arg("name")) .def("rsqrt", &ark::Model::rsqrt, py::arg("input"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("sharding", &ark::Model::sharding, py::arg("input"), py::arg("axis"), py::arg("dim_per_shard"), py::arg("name")) .def("sigmoid", &ark::Model::sigmoid, py::arg("input"), - py::arg("output"), py::arg("config"), py::arg("name")) + py::arg("output"), py::arg("name")) .def("sqrt", &ark::Model::sqrt, py::arg("input"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("sub", py::overload_cast( - &ark::Model::sub), + const std::string &>(&ark::Model::sub), py::arg("input"), py::arg("other"), py::arg("output"), - py::arg("config"), py::arg("name")) + py::arg("name")) .def("sub", py::overload_cast( - &ark::Model::sub), + const std::string &>(&ark::Model::sub), py::arg("input"), py::arg("other"), py::arg("output"), py::arg("name")) .def("tensor", diff --git a/python/plan_manager_py.cpp b/python/plan_manager_py.cpp deleted file mode 100644 index 34aa0b77c..000000000 --- a/python/plan_manager_py.cpp +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -#include -#include -#include - -#include - -namespace py = pybind11; - -void register_plan_manager(py::module &m) { - py::class_(m, "_PlanManager") - .def(py::init()); -} From 28b83953ae26b8554fc8b822df8e96dd8bf04091 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 6 Aug 2024 14:33:23 -0700 Subject: [PATCH 53/54] Update runtime.py --- python/ark/runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ark/runtime.py b/python/ark/runtime.py index 96c6f470a..e40750260 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -98,7 +98,7 @@ def launch( _RuntimeState.executor.destroy() _RuntimeState.executor = Executor( - gpu_id, + device_id, stream, "ArkRuntime", plan, From 11901c4a3f49469ede51e992b8b1d2fc1f2c1e3b Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 7 Aug 2024 09:36:45 +0000 Subject: [PATCH 54/54] fix --- python/ark/runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ark/runtime.py b/python/ark/runtime.py index e40750260..495fc1c24 100644 --- a/python/ark/runtime.py +++ b/python/ark/runtime.py @@ -101,7 +101,7 @@ def launch( device_id, stream, "ArkRuntime", - plan, + str(plan), loop_mode, ) self.executor = _RuntimeState.executor