From abc26e7c522ff2e3dc56cbef8b474a64133224ec Mon Sep 17 00:00:00 2001 From: corlfj Date: Thu, 11 Sep 2025 19:27:04 +0800 Subject: [PATCH 01/12] [symbol_structure]: add mir op_class, add mir symbol_pass --- python/mrt/mir/opclass.py | 188 ++++++++++++++++++++++++++++++++++ python/mrt/mir/opns.py | 31 ++++++ python/mrt/mir/symbol.py | 31 +----- python/mrt/mir/symbolpass.py | 84 +++++++++++++++ tests/mir/test.op_create.py | 60 +++++++++++ tests/mir/test.symbol_pass.py | 88 ++++++++++++++++ 6 files changed, 452 insertions(+), 30 deletions(-) create mode 100644 python/mrt/mir/opclass.py create mode 100644 python/mrt/mir/symbolpass.py create mode 100644 tests/mir/test.op_create.py create mode 100644 tests/mir/test.symbol_pass.py diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py new file mode 100644 index 0000000..3d5433a --- /dev/null +++ b/python/mrt/mir/opclass.py @@ -0,0 +1,188 @@ +import typing +from dataclasses import dataclass +from . import opns +from . import symbol + +MRT_OP_MAP: typing.Dict[str, typing.Any] = {} + +#def _register_op_map_(op_name: str, clss:typing.Any=None): +# if len(op_name)>0 and clss!=None: +# if op_name not in MRT_OP_MAP: +# MRT_OP_MAP[op_name] = clss +# return MRT_OP_MAP + +def _register_op_map(op_name: str): #, clss:typing.Any=None): + def _wrapper(clss: typing.Any=None): + if len(op_name)>0 and clss!=None: + if op_name not in MRT_OP_MAP: + MRT_OP_MAP[op_name] = clss + return clss + return _wrapper + +@_register_op_map(opns.CONV2D) +@dataclass(init=False) +class Conv2D(symbol.Symbol): + + op_name = opns.CONV2D + + @property + def strides(self) -> typing.Tuple[int, int]: + default_val = (1,1) + return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val + + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + default_val = (0,0,0,0) + return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val + + @property + def dilation(self) -> typing.Tuple[int, int]: + default_val = (1,1) + return self._ if self._ else self.attrs[''] if '' in self.attrs else default_val + + @property + def kernel_size(self) -> typing.Tuple[int, int]: + default_val = (3,3) + return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + # TODO: what if strides not in attrs? 
+ self._strides = self.attrs['strides'] + if 'padding' in self.attrs: + self._padding = self.attrs['padding'] + if 'dilation' in self.attrs: + self._dilation = self.attrs['dilation'] + if 'kernel_size' in self.attrs: + self._kernel_size = self.attrs['kernel_size'] + + +@_register_op_map(opns.DROP_OUT) +@dataclass(init=False) +class Dropout(symbol.Symbol): + + op_name = opns.DROP_OUT + + @property + def rate(self) -> float: + default_val = 0.0 + return self._rate if self._rate else self.attrs['rate'] if 'rate' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + self._rate = self.attrs['rate'] + +@_register_op_map(opns.CLIP) +@dataclass(init=False) +class Clip(symbol.Symbol): + + op_name = opns.CLIP + + @property + def min(self) -> float: + default_val = None + return self._min if self._min else self.attrs['min'] if 'min' in self.attrs else default_val + + @property + def max(self) -> float: + default_val = None + return self._max if self._max else self.attrs['max'] if 'max' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + self._min = self.attrs['min'] + self._max = self.attrs['max'] + + +@_register_op_map(opns.BATCH_NORM) +@dataclass(init=False) +class BatchNorm(symbol.Symbol): + + op_name = opns.BATCH_NORM + + @property + def axis(self) -> float: + default_val = 1 + return self._axis if self._axis else self.attrs['axis'] if 'axis' in self.attrs else default_val + + @property + def epsilon(self) -> float: + default_val = 1e-5 + return self._epsilon if self._epsilon else self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val + + @property + def center(self) -> float: + default_val = True + return self._center if self._center else self.attrs['center'] if 'center' in self.attrs else default_val + + @property + def scale(self) -> float: + default_val = True + return self._scale if self._scale else self.attrs['scale'] if 'scale' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + self._axis = self.attrs['axis'] + self._epsilon = self.attrs['epsilon'] + self._center = self.attrs['center'] + self._scale = self.attrs['scale'] + +@_register_op_map(opns.DENSE) +@dataclass(init=False) +class Dense(symbol.Symbol): + + op_name = opns.DENSE + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + +@_register_op_map(opns.TUPLE_GET_ITEM) +@dataclass(init=False) +class TupleGetItem(symbol.Symbol): + + op_name = opns.TUPLE_GET_ITEM + + @property + def index(self) -> float: + default_val = 0 + return self._index if self._index else self.attrs['index'] if 'index' in self.attrs else default_val + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + + self._index = self.attrs['index'] + +@_register_op_map(opns.MUL) +@dataclass(init=False) +class Multiply(symbol.Symbol): + + op_name = opns.MUL + + def __init__(self, name:str, **kwargs): + self.name = name + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + 
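A minimal usage sketch of the registry this new file introduces (illustrative only, not part of the patch): it assumes the patch-01 constructor signature Conv2D(name, args=..., attrs=...) and resolves the class through MRT_OP_MAP, much as tests/mir/test.op_create.py does below; the symbol name 'conv2d_x' is a made-up example.

    from mrt.mir import opns, opclass

    # look up the class registered for "nn.conv2d" and build a symbol from it
    Conv2DCls = opclass.MRT_OP_MAP[opns.CONV2D]            # -> opclass.Conv2D
    conv2d_x = Conv2DCls('conv2d_x', args=[[], [], []],
                         attrs={'strides': (1, 1), 'padding': (0, 0, 0, 0)})
    assert isinstance(conv2d_x, opclass.Conv2D)            # registered subclass of symbol.Symbol
    print(conv2d_x.strides, conv2d_x.padding)              # attrs fall back to defaults when absent

Later commits in this series replace this keyword-dict constructor with per-attribute arguments, so the sketch only applies to the state of opclass.py after patch 01.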
diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py index 5b92822..be2f823 100644 --- a/python/mrt/mir/opns.py +++ b/python/mrt/mir/opns.py @@ -1,6 +1,14 @@ """ MRT operator names """ +import typing + +MRT_OP_SET = set() +def _register_op_list(*op_names: typing.List[str]): + for op_name in op_names: + if op_name not in MRT_OP_SET: + MRT_OP_SET.add(op_name) VAR = "var" +_register_op_list(VAR) DROP_OUT = "nn.dropout" CONV2D = "nn.conv2d" @@ -14,22 +22,29 @@ ADAPTIVE_AVG_POOL2D = "nn.adaptive_avg_pool2d" AVG_POOL2D = "nn.avg_pool2d" MAX_POOL2D = "nn.max_pool2d" +_register_op_list(DROP_OUT, CONV2D, DENSE, BATCH_NORM, RELU, + HARDTANH, SILU, LEAKY_RELU, ADAPTIVE_AVG_POOL2D, + AVG_POOL2D, MAX_POOL2D) SOFTMAX = "nn.softmax" LOG_SOFTMAX = "nn.log_softmax" +_register_op_list(SOFTMAX, LOG_SOFTMAX) EXP = "exp" SIGMOID = "sigmoid" +_register_op_list(EXP, SIGMOID) SUM = "sum" MEAN = "mean" MAX_AXIS = "max" MAXIMUM = "maximum" MINIMUM = "minimum" +_register_op_list(SUM, MEAN, MAX_AXIS, MAXIMUM, MINIMUM) # =========== NON-CALC ops =============== TUPLE = "Tuple" TUPLE_GET_ITEM = "TupleGetItem" +_register_op_list(TUPLE, TUPLE_GET_ITEM) REPEAT = "repeat" SQUEEZE = "squeeze" @@ -40,9 +55,12 @@ SPLIT = "split" TRANSPOSE = "transpose" BROADCAST_TO = "broadcast_to" +_register_op_list(REPEAT, SQUEEZE, FLATTEN, BATCH_FLATTEN, RESHAPE, + CONCAT, SPLIT, TRANSPOSE, BROADCAST_TO, ) EXPAND_DIMS = "expand_dims" TILE = "tile" +_register_op_list(EXPAND_DIMS, TILE) WHERE = "where" GREATER = "greater" @@ -50,6 +68,7 @@ SLICE_LIKE = "slice_like" GET_VALID_COUNT = "vision.get_valid_counts" NON_MAX_SUPRESSION = "vision.non_max_suppression" +_register_op_list(WHERE, GREATER, STRIDED_SLICE, SLICE_LIKE, GET_VALID_COUNT, NON_MAX_SUPRESSION) # relax clip attrs from a_min/a_max to min/max CLIP = "clip" @@ -58,11 +77,14 @@ # relax support astype instead of cast AS_TYPE = "astype" # CAST = "cast" +_register_op_list(CLIP, CEIL, RIGHT_SHIFT, AS_TYPE) ADV_INDEX = "adv_index" +_register_op_list(ADV_INDEX) CALL_TIR = "call_tir" CALL_DPS_PACKED = "call_dps_packed" +_register_op_list(CALL_TIR, CALL_DPS_PACKED) # ======= binary ops ============= @@ -71,6 +93,7 @@ MUL = "multiply" MATMUL = "matmul" DIV = "divide" +_register_op_list(ADD, SUB, MUL, MATMUL, DIV) # ======= unary ops ============== @@ -81,14 +104,17 @@ POW = "pow" PASS = "pass" +_register_op_list(NEGATIVE, ABS, LOG, SQRT, POW, PASS) # ======= auto generate op ========= ARANGE = "arange" ZEROS_LIKE = "zeros_like" ONES_LIKE = "ones_like" +_register_op_list(ARANGE, ZEROS_LIKE, ONES_LIKE) # ======= control flow op =========== IF = "if" ARGWHERE = "argwhere" +_register_op_list(IF, ARGWHERE) # ======= mrt requant op ========== REQUANT = "mrt.requant" @@ -98,4 +124,9 @@ """ right shift precision clip """ LUT = "mrt.lut" """ look up table, equals adv_index in tvm """ +_register_op_list(REQUANT, PCLIP, RS_PCLIP, LUT) + +def Opname2Funcname(op_name: str): + return op_name.replace('.', '_') +#print('MRT_OP_SET:', MRT_OP_SET) diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 5c97cee..e27adcd 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -11,8 +11,7 @@ # from . import config # from .utils import * -# from .types import * -from .opns import * +from . 
import opns __ALL__ = [ "Symbol", @@ -277,34 +276,6 @@ def __hash__(self) -> int: def hash(self) -> int: return hash(str(self)) -# class Convolution2D(Symbol): -# strides: typing.Tuple[int, int] - -# class Dropout(Symbol): -# eps: float = 1e-5 - -# class Pass: -# symbol: Symbol - -# def visit(self, op: Symbol): -# env: typing.Dict[Symbol, Symbol] = {} -# for sym in sym2list(self.symbol): -# out = getattr(self, f"visit_{op.op_name}")(op) or op -# assert isinstance(sym, Symbol) -# env[sym] = out -# return env[op] - -# def _default_visit_op(op): -# return op - -# for op in op_list: -# setattr(Pass, f"visit_{op.op_name}", _default_visit_op) - -# class FuseDropoutPass(Pass): -# def visit_dropout(self, op: Dropout): -# op.eps -# return op.args[0] - def _topo_sort(symbol: Symbol, sym_list: typing.List[Symbol]): assert isinstance(symbol, Symbol), \ f"({type(symbol).__name__}){str(symbol)}" diff --git a/python/mrt/mir/symbolpass.py b/python/mrt/mir/symbolpass.py new file mode 100644 index 0000000..883af38 --- /dev/null +++ b/python/mrt/mir/symbolpass.py @@ -0,0 +1,84 @@ +from __future__ import annotations +import typing + +from functools import wraps +from dataclasses import dataclass, fields + +import mrt +from mrt.common import config +from mrt.common.utils import * +from mrt.common.types import * + +from . import opns, opclass +from . import symbol as _symbol + + +# mrt op visits +class SymbolPass: + symbol: _symbol.Symbol + params: ParametersT + + def __init__(self, symbol: _symbol.Symbol, params: ParametersT): + self.symbol = symbol + self.params = params + + def is_param(self, symbol: _symbol.Symbol) -> bool: + return symbol.op_name == opns.VAR and symbol.name in self.params + + def visit(self) -> _symbol.Symbol: + env: typing.Dict[str, _symbol.Symbol] = {} + for sym in _symbol.sym2list(self.symbol): + assert sym.name not in env, f'{sym.name} NotIn env!' + + # Updating args as passed symbol in env_dict + sym = sym.copy(args = [env[arg_sym.name] for arg_sym in sym.args]) + assert isinstance(sym, _symbol.Symbol), sym + + if sym.op_name == opns.DROP_OUT: + #print('ddrroopped_out', getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) or sym) + pass + out = getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) or sym + assert isinstance(out, _symbol.Symbol), out + env[sym.name] = out + return env[self.symbol.name] + + def _default_visit_op(self, op: _symbol.Symbol) -> _symbol.Symbol: + return op + + +# register mrt op default_visit +for op_name in opns.MRT_OP_SET: + funcSuffix = opns.Opname2Funcname(op_name) + setattr(SymbolPass, f"visit_{funcSuffix}", SymbolPass._default_visit_op) + #print(f"visit_, {op_name} => {funcSuffix}", getattr(SymbolPass, f"visit_{funcSuffix}")) + + +# mrt symbol pass +class FuseDropoutPass(SymbolPass): + def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: + # make sure op fit again + if sym.op_name == opns.DROP_OUT: + return sym.args[0] + return sym + + +class FuseDividePass(SymbolPass): + def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.DIVIDE: + argA = self.args[0] + argB = self.args[1] + assert self.is_param(argB), f'NotParam: {argB}' + # TODO: fixit + #argB = argB.from_np_data(1. 
/ argB.numpy()) + return opclass.Multiply(sym.name, {'args':[argA, argB]}) + return sym + + +class FuseTupleGetItemPass(SymbolPass): + def visit_TupleGetItem(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.TUPLE_GET_ITEM: + sym_ : opclass.TupleGetItem = sym + assert sym_.index == 0 + return sym_.args[0] + return sym + diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py new file mode 100644 index 0000000..a36e16c --- /dev/null +++ b/tests/mir/test.op_create.py @@ -0,0 +1,60 @@ +""" +Test script for Alexnet PyTorch to MRT conversion. +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass + + +def test_create_conv2d_op(): + #class CONV2D(Symbol): + # strides: typing.Tuple[int, int] = (1,1) + # padding: typing.Optional[typing.Tuple[int, int, int, int]] = (0,0,0,0) + # create mrt op symbol, def func + print('mrt Conv2D Op Class:', opclass.Conv2D) + conv2d_b = opclass.MRT_OP_MAP[opns.CONV2D]('conv2d_b',args=[[],[],[]], attrs={'strides':(1,1), 'padding':None}) + assert isinstance(conv2d_b, sx.Symbol), 'not!con2d_b symbol' + assert isinstance(conv2d_b, opclass.Conv2D), 'not!2 -con2d_b' + + # attrs hint + assert conv2d_b.args != None + assert conv2d_b.attrs != None + assert conv2d_b.strides != None + + print(f'Got {conv2d_b.name} strides: {conv2d_b.strides}') + print(f'Got {conv2d_b.name} padding: {conv2d_b.padding}') + print(f'Show {conv2d_b.name} {conv2d_b}') + return True + + +# TODO: +#def test_create_symbol_graph(): + +if __name__ == "__main__": + print('MRT_OP_SET as:', opns.MRT_OP_SET) + assert len(opns.MRT_OP_SET) > 0 + + print('MRT_OP_MAP Class as:', opclass.MRT_OP_MAP) + assert len(opclass.MRT_OP_MAP) > 0 + assert opns.CONV2D in opclass.MRT_OP_MAP + + rltflag = test_create_conv2d_op() + print("\n" + "="*60 + "\n") + print('Passed Test!' if rltflag else 'Test Failed!') + print("\n" + "="*60 + "\n") + diff --git a/tests/mir/test.symbol_pass.py b/tests/mir/test.symbol_pass.py new file mode 100644 index 0000000..7d109ef --- /dev/null +++ b/tests/mir/test.symbol_pass.py @@ -0,0 +1,88 @@ +""" +Test script for MRT Alexnet FuseDropoutPass. 
+""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import symbolpass + +def _get_alexnet_model(): + """Get Alexnet MRT Model""" + + # Load pre-trained Alexnet + model = models.alexnet(pretrained=True) + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Alexnet to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + +def test_SymbolPass_FuseDropout(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseDropout Pass ===') + symlist = sx.sym2list(symbol) + dropout_op_cnt = 0 + for sym in symlist: + # print(sym) + dropout_op_cnt += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt>0, f'original model dropout op cnt {dropout_op_cnt} == zero!' + + # init FuseDropout Passer and execute visit + tfs : symbolpass.FuseDropoutPass = symbolpass.FuseDropoutPass(symbol, {}) + #print(getattr(tfs, f"visit_{opns.Opname2Funcname(opns.DROP_OUT)}")) + symbol_passed = tfs.visit() + + print('\n=== After FuseDropout Pass ===') + rlts = sx.sym2list(symbol_passed) + dropout_op_cnt_af = 0 + for sym in rlts: + # print(sym) + dropout_op_cnt_af += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt_af==0, f'passed model dropout op cnt {dropout_op_cnt_af} != zero!' + + #for sym in symdict: + # print(sym, symdict[sym]) + + #print('\n=== Back To SymList ===') + #rltlist = sx.sym2list(symdict[symbol.name]) + + return True + +if __name__ == "__main__": + + print("=== Testing SymbolPass ===") + mrt_graph, mrt_params = _get_alexnet_model() + + print("Testing FuseDropoutPass for Model AlexNet") + rltflag = test_SymbolPass_FuseDropout(mrt_graph, mrt_params) + + print("\n" + "="*60 + "\n") + print('Passed Test!' if rltflag else 'Test Failed!') + print("\n" + "="*60 + "\n") + From 47d606288d6a991479dfa0aee8df5a011eba48f6 Mon Sep 17 00:00:00 2001 From: corlfj Date: Fri, 12 Sep 2025 10:24:56 +0800 Subject: [PATCH 02/12] [mir]: fix last commit --- python/mrt/mir/opclass.py | 50 +++++++++++++++++++++--------------- python/mrt/mir/opns.py | 6 ++--- python/mrt/mir/symbol.py | 2 +- python/mrt/mir/symbolpass.py | 6 ++--- tests/mir/test.op_create.py | 25 +++++++++++------- 5 files changed, 53 insertions(+), 36 deletions(-) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index 3d5433a..a8fb932 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -1,4 +1,5 @@ import typing +import numpy as np from dataclasses import dataclass from . import opns from . 
import symbol @@ -19,7 +20,6 @@ def _wrapper(clss: typing.Any=None): return clss return _wrapper -@_register_op_map(opns.CONV2D) @dataclass(init=False) class Conv2D(symbol.Symbol): @@ -28,31 +28,40 @@ class Conv2D(symbol.Symbol): @property def strides(self) -> typing.Tuple[int, int]: default_val = (1,1) - return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val + return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val @property def padding(self) -> typing.Tuple[int, int, int, int]: default_val = (0,0,0,0) - return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val + return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val @property def dilation(self) -> typing.Tuple[int, int]: default_val = (1,1) - return self._ if self._ else self.attrs[''] if '' in self.attrs else default_val + return self._dilation if self._dilation else self.attrs['dilation'] if 'dilation' in self.attrs else default_val @property def kernel_size(self) -> typing.Tuple[int, int]: default_val = (3,3) - return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} + return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val + + def __init__(self, name_or_inst: typing.Union[str, symbol.Symbol], **kwargs): + assert isinstance(name_or_inst, str) or isinstance(name_or_inst, symbol.Symbol) + if isinstance(name_or_inst, str): + self.name = name_or_inst + self.args = kwargs.pop('args', []) + self.attrs = kwargs.pop('attrs', {}) + self.extra_attrs = {} + else: + # clone mode + self.name = name_or_inst.name + self.args = [a for a in name_or_inst.args] + self.attrs = {k: v for k, v in name_or_inst.attrs.items()} + self.extra_attrs = {k: v for k, v in name_or_inst.extra_attrs.items()} # TODO: what if strides not in attrs? 
- self._strides = self.attrs['strides'] + if 'strides' in self.attrs: + self._strides = self.attrs['strides'] if 'padding' in self.attrs: self._padding = self.attrs['padding'] if 'dilation' in self.attrs: @@ -61,7 +70,6 @@ def __init__(self, name:str, **kwargs): self._kernel_size = self.attrs['kernel_size'] -@_register_op_map(opns.DROP_OUT) @dataclass(init=False) class Dropout(symbol.Symbol): @@ -80,7 +88,6 @@ def __init__(self, name:str, **kwargs): self._rate = self.attrs['rate'] -@_register_op_map(opns.CLIP) @dataclass(init=False) class Clip(symbol.Symbol): @@ -88,12 +95,12 @@ class Clip(symbol.Symbol): @property def min(self) -> float: - default_val = None + default_val = np.nan return self._min if self._min else self.attrs['min'] if 'min' in self.attrs else default_val @property def max(self) -> float: - default_val = None + default_val = np.nan return self._max if self._max else self.attrs['max'] if 'max' in self.attrs else default_val def __init__(self, name:str, **kwargs): @@ -106,7 +113,6 @@ def __init__(self, name:str, **kwargs): self._max = self.attrs['max'] -@_register_op_map(opns.BATCH_NORM) @dataclass(init=False) class BatchNorm(symbol.Symbol): @@ -143,7 +149,6 @@ def __init__(self, name:str, **kwargs): self._center = self.attrs['center'] self._scale = self.attrs['scale'] -@_register_op_map(opns.DENSE) @dataclass(init=False) class Dense(symbol.Symbol): @@ -155,7 +160,6 @@ def __init__(self, name:str, **kwargs): self.attrs = kwargs.pop('attrs', {}) self.extra_attrs = {} -@_register_op_map(opns.TUPLE_GET_ITEM) @dataclass(init=False) class TupleGetItem(symbol.Symbol): @@ -174,7 +178,6 @@ def __init__(self, name:str, **kwargs): self._index = self.attrs['index'] -@_register_op_map(opns.MUL) @dataclass(init=False) class Multiply(symbol.Symbol): @@ -186,3 +189,10 @@ def __init__(self, name:str, **kwargs): self.attrs = kwargs.pop('attrs', {}) self.extra_attrs = {} +_register_op_map(opns.CONV2D)(Conv2D) +_register_op_map(opns.DROP_OUT)(Dropout) +_register_op_map(opns.CLIP)(Clip) +_register_op_map(opns.BATCH_NORM)(BatchNorm) +_register_op_map(opns.DENSE)(Dense) +_register_op_map(opns.TUPLE_GET_ITEM)(TupleGetItem) +_register_op_map(opns.MUL)(Multiply) diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py index be2f823..ed9ac2a 100644 --- a/python/mrt/mir/opns.py +++ b/python/mrt/mir/opns.py @@ -1,8 +1,8 @@ """ MRT operator names """ import typing -MRT_OP_SET = set() -def _register_op_list(*op_names: typing.List[str]): +MRT_OP_SET: typing.Set[str] = set() +def _register_op_list(*op_names: str): for op_name in op_names: if op_name not in MRT_OP_SET: MRT_OP_SET.add(op_name) @@ -127,6 +127,6 @@ def _register_op_list(*op_names: typing.List[str]): _register_op_list(REQUANT, PCLIP, RS_PCLIP, LUT) -def Opname2Funcname(op_name: str): +def Opname2Funcname(op_name: str) -> str: return op_name.replace('.', '_') #print('MRT_OP_SET:', MRT_OP_SET) diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index e27adcd..1832d87 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -462,7 +462,7 @@ def as_tuple(self) -> typing.Tuple[typing.List[str], Symbol]: @classmethod def from_tuple(cls, tuple_names, symbol): - assert symbol.is_op(TUPLE), symbol + assert symbol.is_op(opns.TUPLE), symbol mhs = cls(zip(tuple_names, symbol.args)) mhs.origin = symbol return mhs diff --git a/python/mrt/mir/symbolpass.py b/python/mrt/mir/symbolpass.py index 883af38..e9d65ee 100644 --- a/python/mrt/mir/symbolpass.py +++ b/python/mrt/mir/symbolpass.py @@ -64,9 +64,9 @@ def 
visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: class FuseDividePass(SymbolPass): def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.DIVIDE: - argA = self.args[0] - argB = self.args[1] + if sym.op_name == opns.DIV: + argA = sym.args[0] + argB = sym.args[1] assert self.is_param(argB), f'NotParam: {argB}' # TODO: fixit #argB = argB.from_np_data(1. / argB.numpy()) diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index a36e16c..cf7e61c 100644 --- a/tests/mir/test.op_create.py +++ b/tests/mir/test.op_create.py @@ -27,18 +27,25 @@ def test_create_conv2d_op(): # padding: typing.Optional[typing.Tuple[int, int, int, int]] = (0,0,0,0) # create mrt op symbol, def func print('mrt Conv2D Op Class:', opclass.Conv2D) - conv2d_b = opclass.MRT_OP_MAP[opns.CONV2D]('conv2d_b',args=[[],[],[]], attrs={'strides':(1,1), 'padding':None}) - assert isinstance(conv2d_b, sx.Symbol), 'not!con2d_b symbol' - assert isinstance(conv2d_b, opclass.Conv2D), 'not!2 -con2d_b' + conv2d_a = opclass.MRT_OP_MAP[opns.CONV2D]('conv2d_a', args=[[],[],[]], attrs={'strides':(1,1), 'padding':None}) + assert isinstance(conv2d_a, sx.Symbol), 'conv2d_a isnot a symbol' + assert isinstance(conv2d_a, opclass.Conv2D), 'conv2d_a isnot a Conv2D' # attrs hint - assert conv2d_b.args != None - assert conv2d_b.attrs != None - assert conv2d_b.strides != None + assert conv2d_a.args != None + assert conv2d_a.attrs != None + assert conv2d_a.strides != None - print(f'Got {conv2d_b.name} strides: {conv2d_b.strides}') - print(f'Got {conv2d_b.name} padding: {conv2d_b.padding}') - print(f'Show {conv2d_b.name} {conv2d_b}') + print(f'Got {conv2d_a.name} strides: {conv2d_a.strides}') + print(f'Got {conv2d_a.name} padding: {conv2d_a.padding}') + print(f'Show {conv2d_a.name} {conv2d_a}') + + # test Conv2D clone mode + conv2d_b = opclass.MRT_OP_MAP[opns.CONV2D](conv2d_a) + assert isinstance(conv2d_b, sx.Symbol), 'conv2d_b isnot a symbol' + assert isinstance(conv2d_b, opclass.Conv2D), 'conv2d_b isnot a Conv2D' + + assert conv2d_b.attrs == conv2d_a.attrs return True From 2e2cfde6916e0e44bc21d90c80e82091bfa17330 Mon Sep 17 00:00:00 2001 From: corlfj Date: Wed, 17 Sep 2025 16:53:04 +0800 Subject: [PATCH 03/12] [mir]: opclass compatible --- python/mrt/mir/opclass.py | 338 ++++++++++++------ python/mrt/mir/opns.py | 28 -- .../mrt/mir/{symbolpass.py => simple_pass.py} | 102 ++++-- python/mrt/mir/symbol.py | 2 + tests/mir/test.op_create.py | 100 +++++- ...est.symbol_pass.py => test.simple_pass.py} | 71 +++- 6 files changed, 443 insertions(+), 198 deletions(-) rename python/mrt/mir/{symbolpass.py => simple_pass.py} (50%) rename tests/mir/{test.symbol_pass.py => test.simple_pass.py} (50%) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index a8fb932..0ded6ba 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -1,198 +1,306 @@ import typing import numpy as np -from dataclasses import dataclass +from dataclasses import dataclass, fields + +from mrt.common.utils import N from . import opns from . 
import symbol +from .symbol import SelfSymbol -MRT_OP_MAP: typing.Dict[str, typing.Any] = {} - -#def _register_op_map_(op_name: str, clss:typing.Any=None): -# if len(op_name)>0 and clss!=None: -# if op_name not in MRT_OP_MAP: -# MRT_OP_MAP[op_name] = clss -# return MRT_OP_MAP +#SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") +MRT_OP_MAP: typing.Dict[str, SelfSymbol] = {} -def _register_op_map(op_name: str): #, clss:typing.Any=None): - def _wrapper(clss: typing.Any=None): - if len(op_name)>0 and clss!=None: +def _register_op_map(op_name: str): + def _wrapper(clss: SelfSymbol = None) -> SelfSymbol: + if len(op_name) > 0 and clss != None: if op_name not in MRT_OP_MAP: MRT_OP_MAP[op_name] = clss + else: + print(f'Warning: "{op_name}" Alreary Registered In MRT_OP_MAP, IsBeing Overrided!') + MRT_OP_MAP[op_name] = clss return clss return _wrapper + @dataclass(init=False) -class Conv2D(symbol.Symbol): +class Variable(symbol.Symbol): + op_name = opns.VAR + + def __init__(self, name=None, op_name=None, shape:typing.Tuple = (), dtype=None, extra_attrs=None): + op_name = op_name or opns.VAR + assert op_name == opns.VAR + super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={}, extra_attrs=extra_attrs or {}) + self.shape = shape # will also update extra_attrs + self.dtype = dtype # will also update extra_attrs + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {k: data['extra_attrs'][k] for k in data['extra_attrs'] if k in ['shape', 'dtype']} + try: + out = cls(**attrsdata, **basedata) + except Exception as e: + raise e + return out + +@dataclass(init=False) +class Conv2D(symbol.Symbol): op_name = opns.CONV2D @property def strides(self) -> typing.Tuple[int, int]: default_val = (1,1) - return self._strides if self._strides else self.attrs['strides'] if 'strides' in self.attrs else default_val + return self.attrs['strides'] if 'strides' in self.attrs else default_val @property def padding(self) -> typing.Tuple[int, int, int, int]: default_val = (0,0,0,0) - return self._padding if self._padding else self.attrs['padding'] if 'padding' in self.attrs else default_val + return self.attrs['padding'] if 'padding' in self.attrs else default_val + + @property + def groups(self) -> int: + default_val = 1 + return self.attrs['groups'] if 'groups' in self.attrs else default_val @property def dilation(self) -> typing.Tuple[int, int]: default_val = (1,1) - return self._dilation if self._dilation else self.attrs['dilation'] if 'dilation' in self.attrs else default_val + return self.attrs['dilation'] if 'dilation' in self.attrs else default_val @property def kernel_size(self) -> typing.Tuple[int, int]: default_val = (3,3) - return self._kernel_size if self._kernel_size else self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val - - def __init__(self, name_or_inst: typing.Union[str, symbol.Symbol], **kwargs): - assert isinstance(name_or_inst, str) or isinstance(name_or_inst, symbol.Symbol) - if isinstance(name_or_inst, str): - self.name = name_or_inst - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - else: - # clone mode - self.name = name_or_inst.name - self.args = [a for a in name_or_inst.args] - self.attrs = {k: v for k, v in name_or_inst.attrs.items()} - self.extra_attrs = {k: v for k, v in name_or_inst.extra_attrs.items()} 
- - # TODO: what if strides not in attrs? - if 'strides' in self.attrs: - self._strides = self.attrs['strides'] - if 'padding' in self.attrs: - self._padding = self.attrs['padding'] - if 'dilation' in self.attrs: - self._dilation = self.attrs['dilation'] - if 'kernel_size' in self.attrs: - self._kernel_size = self.attrs['kernel_size'] - + return self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val + + + # Follows (*args, name, **attrs) + def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_size=(3,3), extra_attrs=None): + op_name = op_name or opns.CONV2D + assert op_name == opns.CONV2D + super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size}, extra_attrs=extra_attrs or {}) + + + # Copy from other instance of same opclass, must have specific attrs (or with default value) + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in ['strides', 'padding', 'groups', 'dilation', 'kernel_size']} + try: + out = cls(data['args'][0], data['args'][1], **attrsdata, **basedata) + except Exception as e: + raise e + return out @dataclass(init=False) class Dropout(symbol.Symbol): - op_name = opns.DROP_OUT @property def rate(self) -> float: default_val = 0.0 - return self._rate if self._rate else self.attrs['rate'] if 'rate' in self.attrs else default_val + return self.attrs['rate'] if 'rate' in self.attrs else default_val - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - - self._rate = self.attrs['rate'] + def __init__(self, X, name=None, op_name=None, rate:float = 0, extra_attrs=None): + op_name = op_name or opns.DROP_OUT + assert op_name == opns.DROP_OUT + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'rate': rate}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {'rate': data['attrs']['rate']} + try: + out = cls(data['args'][0], **attrsdata, **basedata) + except Exception as e: + raise e + return out @dataclass(init=False) class Clip(symbol.Symbol): - op_name = opns.CLIP @property def min(self) -> float: default_val = np.nan - return self._min if self._min else self.attrs['min'] if 'min' in self.attrs else default_val + return self.attrs['min'] if 'min' in self.attrs else default_val @property def max(self) -> float: default_val = np.nan - return self._max if self._max else self.attrs['max'] if 'max' in self.attrs else default_val - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - - self._min = self.attrs['min'] - self._max = self.attrs['max'] + return self.attrs['max'] if 'max' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None): + op_name = op_name or opns.CLIP + assert op_name == opns.CLIP + 
super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min': min_, 'max': max_}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {'min': data['attrs']['min'], 'max': data['attrs']['max']} + try: + out = cls(data['args'][0], **attrsdata, **basedata) + except Exception as e: + raise e + return out @dataclass(init=False) class BatchNorm(symbol.Symbol): - op_name = opns.BATCH_NORM @property - def axis(self) -> float: + def axis(self) -> int: default_val = 1 - return self._axis if self._axis else self.attrs['axis'] if 'axis' in self.attrs else default_val + return self.attrs['axis'] if 'axis' in self.attrs else default_val @property def epsilon(self) -> float: default_val = 1e-5 - return self._epsilon if self._epsilon else self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val + return self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val @property - def center(self) -> float: + def center(self) -> bool: default_val = True - return self._center if self._center else self.attrs['center'] if 'center' in self.attrs else default_val + return self.attrs['center'] if 'center' in self.attrs else default_val @property - def scale(self) -> float: + def scale(self) -> bool: default_val = True - return self._scale if self._scale else self.attrs['scale'] if 'scale' in self.attrs else default_val - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - - self._axis = self.attrs['axis'] - self._epsilon = self.attrs['epsilon'] - self._center = self.attrs['center'] - self._scale = self.attrs['scale'] + return self.attrs['scale'] if 'scale' in self.attrs else default_val + + def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, center:bool = True, scale:bool = True, extra_attrs=None): + op_name = op_name or opns.BATCH_NORM + assert op_name == opns.BATCH_NORM + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'center': center, 'scale': scale}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in ['axis', 'epsilon', 'center', 'scale']} + try: + out = cls(*data['args'], **attrsdata, **basedata) + except Exception as e: + raise e + return out -@dataclass(init=False) -class Dense(symbol.Symbol): - - op_name = opns.DENSE - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} @dataclass(init=False) class TupleGetItem(symbol.Symbol): - op_name = opns.TUPLE_GET_ITEM @property def index(self) -> float: default_val = 0 - return self._index if self._index else self.attrs['index'] if 'index' in self.attrs else default_val - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - - self._index = self.attrs['index'] - 
-@dataclass(init=False) -class Multiply(symbol.Symbol): - - op_name = opns.MUL - - def __init__(self, name:str, **kwargs): - self.name = name - self.args = kwargs.pop('args', []) - self.attrs = kwargs.pop('attrs', {}) - self.extra_attrs = {} - + return self.attrs['index'] if 'index' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, index:int = 0, extra_attrs=None): + op_name = op_name or opns.TUPLE_GET_ITEM + assert op_name == opns.TUPLE_GET_ITEM + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'index': index}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {'index': data['attrs']['index']} + try: + out = cls(data['args'][0], **attrsdata, **basedata) + except Exception as e: + raise e + return out + + +_register_op_map(opns.VAR)(Variable) _register_op_map(opns.CONV2D)(Conv2D) _register_op_map(opns.DROP_OUT)(Dropout) _register_op_map(opns.CLIP)(Clip) _register_op_map(opns.BATCH_NORM)(BatchNorm) -_register_op_map(opns.DENSE)(Dense) _register_op_map(opns.TUPLE_GET_ITEM)(TupleGetItem) -_register_op_map(opns.MUL)(Multiply) + +# Add default register Class for MRT OP Not Implemented! +_register_op_map(opns.MUL)(symbol.Symbol) +_register_op_map(opns.DENSE)(symbol.Symbol) +_register_op_map(opns.RELU)(symbol.Symbol) +_register_op_map(opns.HARDTANH)(symbol.Symbol) +_register_op_map(opns.SILU)(symbol.Symbol) +_register_op_map(opns.LEAKY_RELU)(symbol.Symbol) +_register_op_map(opns.ADAPTIVE_AVG_POOL2D)(symbol.Symbol) +_register_op_map(opns.AVG_POOL2D)(symbol.Symbol) +_register_op_map(opns.MAX_POOL2D)(symbol.Symbol) +_register_op_map(opns.SOFTMAX)(symbol.Symbol) +_register_op_map(opns.LOG_SOFTMAX)(symbol.Symbol) +_register_op_map(opns.EXP)(symbol.Symbol) +_register_op_map(opns.SIGMOID)(symbol.Symbol) +_register_op_map(opns.SUM)(symbol.Symbol) +_register_op_map(opns.MEAN)(symbol.Symbol) +_register_op_map(opns.MAX_AXIS)(symbol.Symbol) +_register_op_map(opns.MAXIMUM)(symbol.Symbol) +_register_op_map(opns.MINIMUM)(symbol.Symbol) +_register_op_map(opns.TUPLE)(symbol.Symbol) +_register_op_map(opns.REPEAT)(symbol.Symbol) +_register_op_map(opns.SQUEEZE)(symbol.Symbol) +_register_op_map(opns.FLATTEN)(symbol.Symbol) +_register_op_map(opns.BATCH_FLATTEN)(symbol.Symbol) +_register_op_map(opns.RESHAPE)(symbol.Symbol) +_register_op_map(opns.CONCAT)(symbol.Symbol) +_register_op_map(opns.SPLIT)(symbol.Symbol) +_register_op_map(opns.TRANSPOSE)(symbol.Symbol) +_register_op_map(opns.BROADCAST_TO)(symbol.Symbol) +_register_op_map(opns.EXPAND_DIMS)(symbol.Symbol) +_register_op_map(opns.TILE)(symbol.Symbol) +_register_op_map(opns.WHERE)(symbol.Symbol) +_register_op_map(opns.GREATER)(symbol.Symbol) +_register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) +_register_op_map(opns.SLICE_LIKE)(symbol.Symbol) +_register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) +_register_op_map(opns.NON_MAX_SUPRESSION)(symbol.Symbol) +_register_op_map(opns.CEIL)(symbol.Symbol) +_register_op_map(opns.RIGHT_SHIFT)(symbol.Symbol) +_register_op_map(opns.AS_TYPE)(symbol.Symbol) +_register_op_map(opns.ADV_INDEX)(symbol.Symbol) +_register_op_map(opns.CALL_TIR)(symbol.Symbol) +_register_op_map(opns.CALL_DPS_PACKED)(symbol.Symbol) +_register_op_map(opns.ADD)(symbol.Symbol) +_register_op_map(opns.SUB)(symbol.Symbol) +_register_op_map(opns.MATMUL)(symbol.Symbol) 
+_register_op_map(opns.DIV)(symbol.Symbol) +_register_op_map(opns.NEGATIVE)(symbol.Symbol) +_register_op_map(opns.ABS)(symbol.Symbol) +_register_op_map(opns.LOG)(symbol.Symbol) +_register_op_map(opns.SQRT)(symbol.Symbol) +_register_op_map(opns.POW)(symbol.Symbol) +_register_op_map(opns.PASS)(symbol.Symbol) +_register_op_map(opns.ARANGE)(symbol.Symbol) +_register_op_map(opns.ZEROS_LIKE)(symbol.Symbol) +_register_op_map(opns.ONES_LIKE)(symbol.Symbol) +_register_op_map(opns.IF)(symbol.Symbol) +_register_op_map(opns.ARGWHERE)(symbol.Symbol) +_register_op_map(opns.REQUANT)(symbol.Symbol) +_register_op_map(opns.PCLIP)(symbol.Symbol) +_register_op_map(opns.RS_PCLIP)(symbol.Symbol) +_register_op_map(opns.LUT)(symbol.Symbol) diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py index ed9ac2a..31da253 100644 --- a/python/mrt/mir/opns.py +++ b/python/mrt/mir/opns.py @@ -1,14 +1,6 @@ """ MRT operator names """ -import typing - -MRT_OP_SET: typing.Set[str] = set() -def _register_op_list(*op_names: str): - for op_name in op_names: - if op_name not in MRT_OP_SET: - MRT_OP_SET.add(op_name) VAR = "var" -_register_op_list(VAR) DROP_OUT = "nn.dropout" CONV2D = "nn.conv2d" @@ -22,29 +14,22 @@ def _register_op_list(*op_names: str): ADAPTIVE_AVG_POOL2D = "nn.adaptive_avg_pool2d" AVG_POOL2D = "nn.avg_pool2d" MAX_POOL2D = "nn.max_pool2d" -_register_op_list(DROP_OUT, CONV2D, DENSE, BATCH_NORM, RELU, - HARDTANH, SILU, LEAKY_RELU, ADAPTIVE_AVG_POOL2D, - AVG_POOL2D, MAX_POOL2D) SOFTMAX = "nn.softmax" LOG_SOFTMAX = "nn.log_softmax" -_register_op_list(SOFTMAX, LOG_SOFTMAX) EXP = "exp" SIGMOID = "sigmoid" -_register_op_list(EXP, SIGMOID) SUM = "sum" MEAN = "mean" MAX_AXIS = "max" MAXIMUM = "maximum" MINIMUM = "minimum" -_register_op_list(SUM, MEAN, MAX_AXIS, MAXIMUM, MINIMUM) # =========== NON-CALC ops =============== TUPLE = "Tuple" TUPLE_GET_ITEM = "TupleGetItem" -_register_op_list(TUPLE, TUPLE_GET_ITEM) REPEAT = "repeat" SQUEEZE = "squeeze" @@ -55,12 +40,9 @@ def _register_op_list(*op_names: str): SPLIT = "split" TRANSPOSE = "transpose" BROADCAST_TO = "broadcast_to" -_register_op_list(REPEAT, SQUEEZE, FLATTEN, BATCH_FLATTEN, RESHAPE, - CONCAT, SPLIT, TRANSPOSE, BROADCAST_TO, ) EXPAND_DIMS = "expand_dims" TILE = "tile" -_register_op_list(EXPAND_DIMS, TILE) WHERE = "where" GREATER = "greater" @@ -68,7 +50,6 @@ def _register_op_list(*op_names: str): SLICE_LIKE = "slice_like" GET_VALID_COUNT = "vision.get_valid_counts" NON_MAX_SUPRESSION = "vision.non_max_suppression" -_register_op_list(WHERE, GREATER, STRIDED_SLICE, SLICE_LIKE, GET_VALID_COUNT, NON_MAX_SUPRESSION) # relax clip attrs from a_min/a_max to min/max CLIP = "clip" @@ -77,14 +58,11 @@ def _register_op_list(*op_names: str): # relax support astype instead of cast AS_TYPE = "astype" # CAST = "cast" -_register_op_list(CLIP, CEIL, RIGHT_SHIFT, AS_TYPE) ADV_INDEX = "adv_index" -_register_op_list(ADV_INDEX) CALL_TIR = "call_tir" CALL_DPS_PACKED = "call_dps_packed" -_register_op_list(CALL_TIR, CALL_DPS_PACKED) # ======= binary ops ============= @@ -93,7 +71,6 @@ def _register_op_list(*op_names: str): MUL = "multiply" MATMUL = "matmul" DIV = "divide" -_register_op_list(ADD, SUB, MUL, MATMUL, DIV) # ======= unary ops ============== @@ -104,17 +81,14 @@ def _register_op_list(*op_names: str): POW = "pow" PASS = "pass" -_register_op_list(NEGATIVE, ABS, LOG, SQRT, POW, PASS) # ======= auto generate op ========= ARANGE = "arange" ZEROS_LIKE = "zeros_like" ONES_LIKE = "ones_like" -_register_op_list(ARANGE, ZEROS_LIKE, ONES_LIKE) # ======= control flow op =========== IF 
= "if" ARGWHERE = "argwhere" -_register_op_list(IF, ARGWHERE) # ======= mrt requant op ========== REQUANT = "mrt.requant" @@ -124,9 +98,7 @@ def _register_op_list(*op_names: str): """ right shift precision clip """ LUT = "mrt.lut" """ look up table, equals adv_index in tvm """ -_register_op_list(REQUANT, PCLIP, RS_PCLIP, LUT) def Opname2Funcname(op_name: str) -> str: return op_name.replace('.', '_') -#print('MRT_OP_SET:', MRT_OP_SET) diff --git a/python/mrt/mir/symbolpass.py b/python/mrt/mir/simple_pass.py similarity index 50% rename from python/mrt/mir/symbolpass.py rename to python/mrt/mir/simple_pass.py index e9d65ee..2fc362e 100644 --- a/python/mrt/mir/symbolpass.py +++ b/python/mrt/mir/simple_pass.py @@ -14,18 +14,13 @@ # mrt op visits -class SymbolPass: +class SimplePass: symbol: _symbol.Symbol - params: ParametersT - - def __init__(self, symbol: _symbol.Symbol, params: ParametersT): - self.symbol = symbol - self.params = params - def is_param(self, symbol: _symbol.Symbol) -> bool: - return symbol.op_name == opns.VAR and symbol.name in self.params + def __init__(self, symbol: _symbol.Symbol): + self.symbol = symbol - def visit(self) -> _symbol.Symbol: + def visit(self, custom_func: typing.Callable[[Symbol], typing.Optional[Symbol]] = None) -> _symbol.Symbol: env: typing.Dict[str, _symbol.Symbol] = {} for sym in _symbol.sym2list(self.symbol): assert sym.name not in env, f'{sym.name} NotIn env!' @@ -33,11 +28,8 @@ def visit(self) -> _symbol.Symbol: # Updating args as passed symbol in env_dict sym = sym.copy(args = [env[arg_sym.name] for arg_sym in sym.args]) assert isinstance(sym, _symbol.Symbol), sym - - if sym.op_name == opns.DROP_OUT: - #print('ddrroopped_out', getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) or sym) - pass - out = getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) or sym + out = custom_func(sym) if custom_func else getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) + out = out or sym assert isinstance(out, _symbol.Symbol), out env[sym.name] = out return env[self.symbol.name] @@ -46,15 +38,31 @@ def _default_visit_op(self, op: _symbol.Symbol) -> _symbol.Symbol: return op -# register mrt op default_visit -for op_name in opns.MRT_OP_SET: +# mrt op visits with params, variables +class InferPass(SimplePass): + params: ParametersT + + def is_param(self, symbol: _symbol.Symbol) -> bool: + return symbol.op_name == opns.VAR and symbol.name in self.params + + def get_param(self, symbol: _symbol.Symbol) -> OpNumpyT: + assert self.is_param(symbol) + return self.params[symbol.name] if self.is_param(symbol) else [] + + def __init__(self, symbol: _symbol.Symbol, params: ParametersT): + self.symbol = symbol + self.params = params + + +# Register MRT all op's default_visit_op function +for op_name in opclass.MRT_OP_MAP.keys(): funcSuffix = opns.Opname2Funcname(op_name) - setattr(SymbolPass, f"visit_{funcSuffix}", SymbolPass._default_visit_op) - #print(f"visit_, {op_name} => {funcSuffix}", getattr(SymbolPass, f"visit_{funcSuffix}")) + setattr(SimplePass, f"visit_{funcSuffix}", SimplePass._default_visit_op) + #print(f"visit_, {op_name} => {funcSuffix}", getattr(SimplePass, f"visit_{funcSuffix}")) -# mrt symbol pass -class FuseDropoutPass(SymbolPass): +# mrt symbol simple pass +class FuseDropoutPass(SimplePass): def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: # make sure op fit again if sym.op_name == opns.DROP_OUT: @@ -62,7 +70,48 @@ def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: return sym -class 
FuseDividePass(SymbolPass): +class FuseTupleGetItemPass(SimplePass): + def visit_TupleGetItem(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.TUPLE_GET_ITEM: + return sym + sym_ : opclass.TupleGetItem = sym + assert sym_.index == 0 + return sym_.args[0] + return sym + + +class FuseBatchNormPass(InferPass): + def visit_nn_batch_norm(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.BATCH_NORM: + X, Gamma, Beta, Mean, Var = sym.args + Gamma = self.get_param(Gamma) + Beta = self.get_param(Beta) + Mean = self.get_param(Mean) + Var = self.get_param(Var) + return sym + return sym + + +class FuseSoftmaxPass(SimplePass): + def visit_nn_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.SOFTMAX: + return self.args[0] + return sym + + def visit_nn_log_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.LOG_SOFTMAX: + return self.args[0] + return sym + + +class FuseMeanPass(SimplePass): + def visit_mean(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.MEAN: + return sym + return sym + + +class FuseDividePass(InferPass): def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: if sym.op_name == opns.DIV: argA = sym.args[0] @@ -70,15 +119,6 @@ def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: assert self.is_param(argB), f'NotParam: {argB}' # TODO: fixit #argB = argB.from_np_data(1. / argB.numpy()) - return opclass.Multiply(sym.name, {'args':[argA, argB]}) - return sym - - -class FuseTupleGetItemPass(SymbolPass): - def visit_TupleGetItem(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.TUPLE_GET_ITEM: - sym_ : opclass.TupleGetItem = sym - assert sym_.index == 0 - return sym_.args[0] + return opclass.MRT_OP_MAP[opns.MUL](argA, argB) return sym diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 1832d87..c1e6bd7 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -19,6 +19,8 @@ "filter_operators", ] +SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") + def _format_printer(data): if isinstance(data, dict): data = ["{}={}".format(k, _format_printer(v)) \ diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index cf7e61c..eab2d90 100644 --- a/tests/mir/test.op_create.py +++ b/tests/mir/test.op_create.py @@ -22,12 +22,14 @@ def test_create_conv2d_op(): - #class CONV2D(Symbol): - # strides: typing.Tuple[int, int] = (1,1) - # padding: typing.Optional[typing.Tuple[int, int, int, int]] = (0,0,0,0) - # create mrt op symbol, def func - print('mrt Conv2D Op Class:', opclass.Conv2D) - conv2d_a = opclass.MRT_OP_MAP[opns.CONV2D]('conv2d_a', args=[[],[],[]], attrs={'strides':(1,1), 'padding':None}) + + X = opclass.Variable(name="x", shape=(1, 3, 224, 224,), dtype="float") + W = opclass.Variable(name="w", shape=(32, 3, 10, 10,), dtype="float") + assert [shp for shp in X.shape] == [shp for shp in (1, 3, 224, 224)], f'Wrong X shape {X.shape}' + assert X.dtype == "float", f'Wrong X dtype {X.dtype}' + + # Symbol Init using opclass OP + conv2d_a = opclass.Conv2D(X, W, name='conv2d_a', strides=(2,2)) assert isinstance(conv2d_a, sx.Symbol), 'conv2d_a isnot a symbol' assert isinstance(conv2d_a, opclass.Conv2D), 'conv2d_a isnot a Conv2D' @@ -41,27 +43,87 @@ def test_create_conv2d_op(): print(f'Show {conv2d_a.name} {conv2d_a}') # test Conv2D clone mode - conv2d_b = opclass.MRT_OP_MAP[opns.CONV2D](conv2d_a) + conv2d_b = conv2d_a.copy() assert isinstance(conv2d_b, sx.Symbol), 'conv2d_b isnot a symbol' assert 
isinstance(conv2d_b, opclass.Conv2D), 'conv2d_b isnot a Conv2D' - assert conv2d_b.attrs == conv2d_a.attrs + assert conv2d_b.attrs == conv2d_a.attrs, f'a: {conv2d_b.attrs} != b: {conv2d_a.attrs}' + + # test Dict to Find Class and Init + conv2d_c = opclass.MRT_OP_MAP[opns.CONV2D](X, W, strides=(2,2)) + assert isinstance(conv2d_c, opclass.Conv2D), 'conv2d_c isnot a Conv2D' + + # test Variable clone mode + X1 = X.copy() + assert X1.shape == X.shape + assert X1.dtype == X.dtype + + # test: Symbol Compatible Mode + args = [X1, W] + attrs = {'strides':(3,3)} + + # Symbol Compatible Init + conv2d_d = opclass.Conv2D(*args, name='conv2d_d', **attrs) + conv2d_e = opclass.Conv2D(*args, **attrs) + assert isinstance(conv2d_d, opclass.Conv2D), 'conv2d_d isnot a Conv2D' + assert isinstance(conv2d_e, opclass.Conv2D), 'conv2d_e isnot a Conv2D' + + return True + + +def test_create_symbol_graph(): + X0 = opclass.Variable(name="x", shape=(1, 3, 224, 224,), dtype="float") + W0 = opclass.Variable(name="w", shape=(32, 3, 10, 10,), dtype="float") + conv2d_a = opclass.Conv2D(X0, W0, name='conv2d_a', strides=(1,1)) + + W1 = opclass.Variable(shape=(16, 3, 12, 12,), dtype="float") + conv2d_b = opclass.Conv2D(conv2d_a, W1, name='conv2d_b', strides=(1,1)) + symlist = sx.sym2list(conv2d_b) + + assert symlist[0] == X0 + assert symlist[1] == W0 + + for id_ in range(len(symlist)): + print(id_, symlist[id_]) + return True -# TODO: -#def test_create_symbol_graph(): +def test_create_batch_norm_op(): + X = opclass.Variable(name="x", shape=(1, 32, 128, 128,), dtype="float") + Gamma = opclass.Variable(name="gamma", shape=(32,), dtype="float") + Beta = opclass.Variable(name="beta", shape=(32,), dtype="float") + Mean = opclass.Variable(name="mean", shape=(32,), dtype="float") + Var = opclass.Variable(name="var", shape=(32,), dtype="float") + batch_norm_a = opclass.BatchNorm(X, Gamma, Beta, Mean, Var, axis=1, epsilon=1e-4) + + # attrs hint + assert batch_norm_a.args != None + assert batch_norm_a.attrs != None + assert batch_norm_a.axis != 0 + + # test clone mode + batch_norm_b = batch_norm_a.copy() + assert isinstance(batch_norm_b , opclass.BatchNorm) + + assert batch_norm_a.attrs == batch_norm_b.attrs, f'a: {batch_norm_a.attrs} != b: {batch_norm_b.attrs}' + assert len(batch_norm_a.args) == len(batch_norm_b.args), f'a: {len(batch_norm_a.args)} != b: {len(batch_norm_b.args)}' + + return True + if __name__ == "__main__": - print('MRT_OP_SET as:', opns.MRT_OP_SET) - assert len(opns.MRT_OP_SET) > 0 + print('MRT_OP_SET as:', opclass.MRT_OP_MAP.keys()) + assert len(opclass.MRT_OP_MAP.keys()) > 0 - print('MRT_OP_MAP Class as:', opclass.MRT_OP_MAP) - assert len(opclass.MRT_OP_MAP) > 0 assert opns.CONV2D in opclass.MRT_OP_MAP - - rltflag = test_create_conv2d_op() - print("\n" + "="*60 + "\n") - print('Passed Test!' if rltflag else 'Test Failed!') - print("\n" + "="*60 + "\n") + print('MRT_OP_MAP Conv2D Class as:', opclass.MRT_OP_MAP[opns.CONV2D]) + + test_id = 0 + for func_ in [test_create_conv2d_op, test_create_symbol_graph, test_create_batch_norm_op]: + rltflag = func_() + test_id += 1 + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id}!' 
if rltflag else f'Test{test_id} Failed!') + print("\n" + "="*60 + "\n") diff --git a/tests/mir/test.symbol_pass.py b/tests/mir/test.simple_pass.py similarity index 50% rename from tests/mir/test.symbol_pass.py rename to tests/mir/test.simple_pass.py index 7d109ef..9d20ee0 100644 --- a/tests/mir/test.symbol_pass.py +++ b/tests/mir/test.simple_pass.py @@ -19,7 +19,7 @@ from mrt.mir import helper, symbol as sx from mrt.mir import opns from mrt.mir import opclass -from mrt.mir import symbolpass +from mrt.mir import simple_pass def _get_alexnet_model(): """Get Alexnet MRT Model""" @@ -41,7 +41,7 @@ def _get_alexnet_model(): mrt_graph, mrt_params = pytorch_to_mrt(ep) return mrt_graph, mrt_params -def test_SymbolPass_FuseDropout(mrt_graph, mrt_params): +def test_SimplePass_FuseDropout(mrt_graph, mrt_params): symbol = mrt_graph['main'] #print(symbol) @@ -54,7 +54,7 @@ def test_SymbolPass_FuseDropout(mrt_graph, mrt_params): assert dropout_op_cnt>0, f'original model dropout op cnt {dropout_op_cnt} == zero!' # init FuseDropout Passer and execute visit - tfs : symbolpass.FuseDropoutPass = symbolpass.FuseDropoutPass(symbol, {}) + tfs : simple_pass.FuseDropoutPass = simple_pass.FuseDropoutPass(symbol) #print(getattr(tfs, f"visit_{opns.Opname2Funcname(opns.DROP_OUT)}")) symbol_passed = tfs.visit() @@ -74,15 +74,76 @@ def test_SymbolPass_FuseDropout(mrt_graph, mrt_params): return True + +def test_SimplePass_CustomFunc(mrt_graph): + symbol = mrt_graph['main'] + + print('\n=== Before CustomFunc Pass ===') + symlist = sx.sym2list(symbol) + + tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) + conv2d_name_list = [] + def _filter_op(sym: sx.Symbol) -> sx.Symbol: + if sym.op_name == opns.CONV2D: + conv2d_name_list.append(sym.name) + return sym + + symbol_passed = tfs.visit(_filter_op) + + print('\n=== After CustomFunc Pass ===') + assert len(conv2d_name_list) > 0 + print(conv2d_name_list) + rlts = sx.sym2list(symbol_passed) + + return True + + +def test_SimplePass_FuseDropout_CustomFunc(mrt_graph): + symbol = mrt_graph['main'] + + print('\n=== Before FuseDropout CustomFunc Pass ===') + symlist = sx.sym2list(symbol) + dropout_op_cnt = 0 + for sym in symlist: + dropout_op_cnt += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt > 0, f'ori model dropout op cnt {dropout_op_cnt} == zero!' + + tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) + def _nn_dropout(sym: sx.Symbol) -> sx.Symbol: + if sym.op_name == opns.DROP_OUT: + return sym.args[0] + return sym + symbol_passed = tfs.visit(_nn_dropout) + + print('\n=== After FuseDropout CustomFunc Pass ===') + rlts = sx.sym2list(symbol_passed) + dropout_op_cnt_af = 0 + for sym in rlts: + dropout_op_cnt_af += 1 if sym.op_name == opns.DROP_OUT else 0 + assert dropout_op_cnt_af == 0, f'passed model dropout op cnt {dropout_op_cnt_af} != zero!' + + return True + + if __name__ == "__main__": print("=== Testing SymbolPass ===") mrt_graph, mrt_params = _get_alexnet_model() print("Testing FuseDropoutPass for Model AlexNet") - rltflag = test_SymbolPass_FuseDropout(mrt_graph, mrt_params) + rltflag = test_SimplePass_FuseDropout(mrt_graph, mrt_params) + print("\n" + "="*60 + "\n") + print('Passed Test1!' if rltflag else 'Test1 Failed!') + print("\n" + "="*60 + "\n") + + rltflag = test_SimplePass_CustomFunc(mrt_graph) + print("\n" + "="*60 + "\n") + print('Passed Test2!' 
if rltflag else 'Test2 Failed!') + print("\n" + "="*60 + "\n") + print("Testing FuseDropout CustomFunc for Model AlexNet") + rltflag = test_SimplePass_FuseDropout_CustomFunc(mrt_graph) print("\n" + "="*60 + "\n") - print('Passed Test!' if rltflag else 'Test Failed!') + print('Passed Test3!' if rltflag else 'Test3 Failed!') print("\n" + "="*60 + "\n") From ab36f5fb39d55b1a52e8c90cc898ac7963436a93 Mon Sep 17 00:00:00 2001 From: corlfj Date: Mon, 22 Sep 2025 10:26:14 +0800 Subject: [PATCH 04/12] [mir]: opclass opfunc, more op --- python/mrt/mir/opclass.py | 904 +++++++++++++++++++++++++++++------- tests/mir/test.op_create.py | 20 +- 2 files changed, 758 insertions(+), 166 deletions(-) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index 0ded6ba..2e4cdf9 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -1,6 +1,6 @@ import typing import numpy as np -from dataclasses import dataclass, fields +from dataclasses import dataclass from mrt.common.utils import N from . import opns @@ -8,10 +8,14 @@ from .symbol import SelfSymbol #SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") -MRT_OP_MAP: typing.Dict[str, SelfSymbol] = {} + +SymbolCreator = typing.Union[typing.Callable[[typing.Any, typing.Any], typing.Type[symbol.Symbol]], SelfSymbol] +#SymbolCreator = typing.Union[typing.Callable[[...], symbol.Symbol], SelfSymbol] + +MRT_OP_MAP: typing.Dict[str, SymbolCreator] = {} def _register_op_map(op_name: str): - def _wrapper(clss: SelfSymbol = None) -> SelfSymbol: + def _wrapper(clss: SymbolCreator = None) -> SymbolCreator: if len(op_name) > 0 and clss != None: if op_name not in MRT_OP_MAP: MRT_OP_MAP[op_name] = clss @@ -22,30 +26,43 @@ def _wrapper(clss: SelfSymbol = None) -> SelfSymbol: return _wrapper -@dataclass(init=False) -class Variable(symbol.Symbol): - op_name = opns.VAR - - def __init__(self, name=None, op_name=None, shape:typing.Tuple = (), dtype=None, extra_attrs=None): - op_name = op_name or opns.VAR - assert op_name == opns.VAR - super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={}, extra_attrs=extra_attrs or {}) - self.shape = shape # will also update extra_attrs - self.dtype = dtype # will also update extra_attrs - - @classmethod - def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {k: data['extra_attrs'][k] for k in data['extra_attrs'] if k in ['shape', 'dtype']} - try: - out = cls(**attrsdata, **basedata) - except Exception as e: - raise e - return out +# OPs from external (not in MRT op), using custom op_name with default op_func +#y = extern_opfunc("tanh")(X) +def extern_opfunc(op_name: str): + def op_func(*args, **attrs): + return symbol.Symbol(*args, op_name=op_name, **attrs) + return op_func + + +def _from_dict_attrs(cls, d: dict, attr_keys:typing.List[str]=[], **kwargs): + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in attr_keys} + try: + out = cls(*data['args'], **attrsdata, **basedata) + except Exception as e: + raise e + return out + +# OPs without attrs, just register function (funcName should be lower case) +def var(name=None, op_name=None, shape=(), dtype=float) -> symbol.Symbol: + op_name = op_name or opns.VAR + assert op_name == 
opns.VAR + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[], attrs={}, extra_attrs={'shape': shape or (), 'dtype': dtype or float}) + +#def _return_func_single_arg(op_name: op_name): +def relu(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.RELU + assert op_name == opns.RELU + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def silu(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.SILU + assert op_name == opns.SILU + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) @dataclass(init=False) @@ -84,49 +101,27 @@ def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0, assert op_name == opns.CONV2D super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size}, extra_attrs=extra_attrs or {}) - - # Copy from other instance of same opclass, must have specific attrs (or with default value) @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in ['strides', 'padding', 'groups', 'dilation', 'kernel_size']} - try: - out = cls(data['args'][0], data['args'][1], **attrsdata, **basedata) - except Exception as e: - raise e - return out + return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation', 'kernel_size'], **kwargs) @dataclass(init=False) class Dropout(symbol.Symbol): op_name = opns.DROP_OUT @property - def rate(self) -> float: - default_val = 0.0 - return self.attrs['rate'] if 'rate' in self.attrs else default_val + def p(self) -> float: + default_val = 0.5 + return self.attrs['p'] if 'p' in self.attrs else default_val - def __init__(self, X, name=None, op_name=None, rate:float = 0, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, p:float = 0.5, extra_attrs=None): op_name = op_name or opns.DROP_OUT assert op_name == opns.DROP_OUT - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'rate': rate}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'p': p}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {'rate': data['attrs']['rate']} - try: - out = cls(data['args'][0], **attrsdata, **basedata) - except Exception as e: - raise e - return out + return _from_dict_attrs(cls, d, ['p'], **kwargs) @dataclass(init=False) class Clip(symbol.Symbol): @@ -149,17 +144,7 @@ def __init__(self, X, name=None, op_name=None, min_:float = np.nan, max_:float = @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {'min': data['attrs']['min'], 'max': data['attrs']['max']} - try: - out = cls(data['args'][0], **attrsdata, **basedata) - except Exception as e: - raise e - return out + return 
_from_dict_attrs(cls, d, ['min', 'max'], **kwargs) @dataclass(init=False) @@ -177,33 +162,18 @@ def epsilon(self) -> float: return self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val @property - def center(self) -> bool: - default_val = True + def momentum(self) -> float: + default_val = 0.1 return self.attrs['center'] if 'center' in self.attrs else default_val - @property - def scale(self) -> bool: - default_val = True - return self.attrs['scale'] if 'scale' in self.attrs else default_val - - def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, center:bool = True, scale:bool = True, extra_attrs=None): + def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, extra_attrs=None): op_name = op_name or opns.BATCH_NORM assert op_name == opns.BATCH_NORM - super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'center': center, 'scale': scale}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'momentum': momentum}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in ['axis', 'epsilon', 'center', 'scale']} - try: - out = cls(*data['args'], **attrsdata, **basedata) - except Exception as e: - raise e - return out + return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum'], **kwargs) @dataclass(init=False) @@ -222,85 +192,707 @@ def __init__(self, X, name=None, op_name=None, index:int = 0, extra_attrs=None): @classmethod def from_dict(cls, d: dict, **kwargs): - data = cls.default_dict() - data.update(d) - data.update(kwargs) - data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} - attrsdata = {'index': data['attrs']['index']} - try: - out = cls(data['args'][0], **attrsdata, **basedata) - except Exception as e: - raise e - return out - - -_register_op_map(opns.VAR)(Variable) + return _from_dict_attrs(cls, d, ['index'], **kwargs) + + +@dataclass(init=False) +class LeakyRelu(symbol.Symbol): + op_name = opns.LEAKY_RELU + + @property + def negative_slope(self) -> float: + default_val = 1e-2 + return self.attrs['negative_slope'] if 'negative_slope' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, negative_slope:float = 1e-2, extra_attrs=None): + op_name = op_name or opns.LEAKY_RELU + assert op_name == opns.LEAKY_RELU + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'negative_slope': negative_slope}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['negative_slope'], **kwargs) + + +def dense(X, W, B, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.DENSE + assert op_name == opns.DENSE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, W, B], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Hardtanh(symbol.Symbol): + op_name = opns.HARDTANH + + @property + def min_val(self) -> float: + default_val = -1.0 + return self.attrs['min_val'] if 
'min_val' in self.attrs else default_val + + @property + def max_val(self) -> float: + default_val = 1.0 + return self.attrs['max_val'] if 'max_val' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): + op_name = op_name or opns.HARDTANH + assert op_name == opns.HARDTANH + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min_val': min_val, 'max_val':max_val}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['min_val', 'max_val'], **kwargs) + + +@dataclass(init=False) +class AdaptiveAvgPool2D(symbol.Symbol): + op_name = opns.ADAPTIVE_AVG_POOL2D + + @property + def output_size(self) -> typing.Union[int, typing.Tuple[int, int]]: + default_val = 0 + return self.attrs['output_size'] if 'output_size' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=0, extra_attrs=None): + op_name = op_name or opns.ADAPTIVE_AVG_POOL2D + assert op_name == opns.ADAPTIVE_AVG_POOL2D + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'output_size': output_size}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['output_size'], **kwargs) + +@dataclass(init=False) +class AvgPool2D(symbol.Symbol): + op_name = opns.AVG_POOL2D + + @property + def pool_size(self) -> typing.Tuple[int, int]: + default_val = (2, 2) + return self.attrs['pool_size'] if 'pool_size' in self.attrs else default_val + @property + def strides(self): + default_val = None + return self.attrs['strides'] if 'strides' in self.attrs else default_val + @property + def padding(self) -> int: + default_val = 0 + return self.attrs['padding'] if 'padding' in self.attrs else default_val + @property + def ceil_mode(self) -> bool: + default_val = False + return self.attrs['ceil_mode'] if 'ceil_mode' in self.attrs else default_val + @property + def layout(self) -> str: + default_val = 'NCHW' + return self.attrs['layout'] if 'layout' in self.attrs else default_val + @property + def count_include_pad(self) -> bool: + default_val = True + return self.attrs['count_include_pad'] if 'count_include_pad' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, pool_size=(2,2), strides=None, padding=0, ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + op_name = op_name or opns.AVG_POOL2D + assert op_name == opns.AVG_POOL2D + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['pool_size', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) + +@dataclass(init=False) +class MaxPool2D(symbol.Symbol): + op_name = opns.MAX_POOL2D + + @property + def pool_size(self) -> typing.Tuple[int, int]: + default_val = (2, 2) + return self.attrs['pool_size'] if 'pool_size' in self.attrs else default_val + @property + def layout(self) -> str: + default_val = 'NCHW' + return self.attrs['layout'] if 'layout' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, pool_size=(2,2), layout='NCHW', extra_attrs=None): + op_name = op_name or 
opns.MAX_POOL2D + assert op_name == opns.MAX_POOL2D + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'layout':layout}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['pool_size', 'layout'], **kwargs) + + +@dataclass(init=False) +class Softmax(symbol.Symbol): + op_name = opns.SOFTMAX + + @property + def axis(self) -> typing.Optional[int]: + default_val = None + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.SOFTMAX + assert op_name == opns.SOFTMAX + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis':axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + + +@dataclass(init=False) +class LogSoftmax(symbol.Symbol): + op_name = opns.LOG_SOFTMAX + + @property + def axis(self) -> typing.Optional[int]: + default_val = None + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.LOG_SOFTMAX + assert op_name == opns.LOG_SOFTMAX + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis':axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + + +def exp(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.EXP + assert op_name == opns.EXP + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def sigmoid(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.SIGMOID + assert op_name == opns.SIGMOID + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Sum(symbol.Symbol): + op_name = opns.SUM + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.SUM + assert op_name == opns.SUM + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + + +@dataclass(init=False) +class Mean(symbol.Symbol): + op_name = opns.MEAN + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.MEAN + assert op_name == opns.MEAN + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, 
extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + + +@dataclass(init=False) +class MaxAxis(symbol.Symbol): + op_name = opns.MAX_AXIS + + @property + def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + @property + def keepdim(self) -> typing.Optional[bool]: + default_val = None + return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + op_name = op_name or opns.MAX_AXIS + assert op_name == opns.MAX_AXIS + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) + +def maximum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MAXIMUM + assert op_name == opns.MAXIMUM + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def minimum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MINIMUM + assert op_name == opns.MINIMUM + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def repeat(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.REPEAT + assert op_name == opns.REPEAT + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Squeeze(symbol.Symbol): + op_name = opns.SQUEEZE + + @property + def dim(self) -> typing.Optional[int]: + default_val = None + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim=None, extra_attrs=None): + op_name = op_name or opns.SQUEEZE + assert op_name == opns.SQUEEZE + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim'], **kwargs) + +@dataclass(init=False) +class Flatten(symbol.Symbol): + op_name = opns.FLATTEN + + @property + def start_dim(self) -> int: + default_val = 0 + return self.attrs['start_dim'] if 'start_dim' in self.attrs else default_val + + @property + def end_dim(self) -> int: + default_val = -1 + return self.attrs['end_dim'] if 'end_dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_attrs=None): + op_name = op_name or opns.FLATTEN + assert op_name == opns.FLATTEN + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'start_dim': start_dim, 'end_dim':end_dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['start_dim', 'end_dim'], **kwargs) + + +@dataclass(init=False) +class Reshape(symbol.Symbol): + op_name = opns.RESHAPE + + @property + def newshape(self) -> typing.Tuple[int,...]: + default_val = None + return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.RESHAPE + assert op_name == opns.RESHAPE + 
super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + + +@dataclass(init=False) +class Concat(symbol.Symbol): + op_name = opns.CONCAT + + @property + def axis(self) -> int: + default_val = 0 + return self.attrs['axis'] if 'axis' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): + op_name = op_name or opns.CONCAT + assert op_name == opns.CONCAT + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis': axis}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['axis'], **kwargs) + + +@dataclass(init=False) +class Split(symbol.Symbol): + op_name = opns.SPLIT + + @property + def split_size(self) -> typing.List[int]: + default_val = [] + return self.attrs['split_size'] if 'split_size' in self.attrs else default_val + + @property + def dim(self) -> int: + default_val = 0 + return self.attrs['dim'] if 'dim' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, split_size=[], dim=0, extra_attrs=None): + op_name = op_name or opns.SPLIT + assert op_name == opns.SPLIT + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'split_size': split_size, 'dim': dim}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['split_size', 'dim'], **kwargs) + + +@dataclass(init=False) +class Transpose(symbol.Symbol): + op_name = opns.TRANSPOSE + + @property + def dim0(self) -> int: + default_val = 0 + return self.attrs['dim0'] if 'dim0' in self.attrs else default_val + + @property + def dim1(self) -> int: + default_val = 0 + return self.attrs['dim1'] if 'dim1' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dim0=0, dim1=0, extra_attrs=None): + op_name = op_name or opns.TRANSPOSE + assert op_name == opns.TRANSPOSE + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim0': dim0, 'dim1': dim1}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dim0', 'dim1'], **kwargs) + + +@dataclass(init=False) +class BroadcastTo(symbol.Symbol): + op_name = opns.BROADCAST_TO + + @property + def newshape(self) -> typing.Tuple[int,...]: + default_val = None + return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.BROADCAST_TO + assert op_name == opns.BROADCAST_TO + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + + +@dataclass(init=False) +class ExpandDims(symbol.Symbol): + op_name = opns.EXPAND_DIMS + + @property + def newshape(self) -> typing.Tuple[int,...]: + default_val = None + return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): + op_name = op_name or opns.EXPAND_DIMS + assert op_name == opns.EXPAND_DIMS + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, 
extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['newshape'], **kwargs) + + +@dataclass(init=False) +class Tile(symbol.Symbol): + op_name = opns.TILE + + @property + def dims(self) -> typing.Tuple[int,...]: + default_val = None + return self.attrs['dims'] if 'dims' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, dims=None, extra_attrs=None): + op_name = op_name or opns.TILE + assert op_name == opns.TILE + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dims': dims}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dims'], **kwargs) + +def where(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.WHERE + assert op_name == opns.WHERE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def greater(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.GREATER + assert op_name == opns.GREATER + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class NonMaxSuppression(symbol.Symbol): + op_name = opns.NON_MAX_SUPRESSION + + @property + def iou_threshold(self) -> float: + default_val = 0.5 + return self.attrs['iou_threshold'] if 'iou_threshold' in self.attrs else default_val + @property + def score_threshold(self) -> typing.Optional[float]: + default_val = None + return self.attrs['score_threshold'] if 'score_threshold' in self.attrs else default_val + + def __init__(self, X, name=None, op_name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None): + op_name = op_name or opns.NON_MAX_SUPRESSION + assert op_name == opns.NON_MAX_SUPRESSION + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'iou_threshold': iou_threshold,'score_threshold':score_threshold}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['dims'], **kwargs) + + +def ceil(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.CEIL + assert op_name == opns.CEIL + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def rightShift(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.RIGHT_SHIFT + assert op_name == opns.RIGHT_SHIFT + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Add(symbol.Symbol): + op_name = opns.ADD + + @property + def alpha(self) -> int: + default_val = 1 + return self.attrs['alpha'] if 'alpha' in self.attrs else default_val + + def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): + op_name = op_name or opns.ADD + assert op_name == opns.ADD + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'alpha': alpha}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['alpha'], **kwargs) + +@dataclass(init=False) +class Sub(symbol.Symbol): + op_name = opns.SUB + + @property + def alpha(self) -> int: + default_val = 1 + return self.attrs['alpha'] if 'alpha' in self.attrs else default_val + + def __init__(self, X, Y, name=None, 
op_name=None, alpha=1, extra_attrs=None): + op_name = op_name or opns.SUB + assert op_name == opns.SUB + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'alpha': alpha}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['alpha'], **kwargs) + +def mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MUL + assert op_name == opns.MUL + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +def matMul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.MATMUL + assert op_name == opns.MATMUL + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Div(symbol.Symbol): + op_name = opns.DIV + + @property + def rounding_mode(self) -> typing.Optional[str]: + default_val = None + return self.attrs['rounding_mode'] if 'rounding_mode' in self.attrs else default_val + + def __init__(self, X, Y, name=None, op_name=None, rounding_mode=None, extra_attrs=None): + op_name = op_name or opns.DIV + assert op_name == opns.DIV + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'rounding_mode': rounding_mode}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['rounding_mode'], **kwargs) + +def negative(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.NEGATIVE + assert op_name == opns.NEGATIVE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def abs(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ABS + assert op_name == opns.ABS + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def log(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.LOG + assert op_name == opns.LOG + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def sqrt(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.SQRT + assert op_name == opns.SQRT + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def pow(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.POW + assert op_name == opns.POW + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) + +def pass_(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.PASS + assert op_name == opns.PASS + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +@dataclass(init=False) +class Arange(symbol.Symbol): + op_name = opns.ARANGE + + @property + def end(self) -> int: + default_val = 0 + return self.attrs['end'] if 'end' in self.attrs else default_val + + @property + def start(self) -> int: + default_val = 0 + return self.attrs['start'] if 'start' in self.attrs else default_val + + @property + def step(self) -> int: + default_val = 1 + return self.attrs['step'] if 'step' in self.attrs else default_val + + def __init__(self, name=None, op_name=None, end=0, start=0, 
step=1, extra_attrs=None): + op_name = op_name or opns.ARANGE + assert op_name == opns.ARANGE + super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={'end': end, 'start': start, 'step': step}, extra_attrs=extra_attrs or {}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['end', 'start', 'step'], **kwargs) + +def zerosLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ZEROS_LIKE + assert op_name == opns.ZEROS_LIKE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + +def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: + op_name = op_name or opns.ONES_LIKE + assert op_name == opns.ONES_LIKE + return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) + + +_register_op_map(opns.VAR)(var) +_register_op_map(opns.RELU)(relu) + _register_op_map(opns.CONV2D)(Conv2D) _register_op_map(opns.DROP_OUT)(Dropout) _register_op_map(opns.CLIP)(Clip) _register_op_map(opns.BATCH_NORM)(BatchNorm) _register_op_map(opns.TUPLE_GET_ITEM)(TupleGetItem) +_register_op_map(opns.LEAKY_RELU)(LeakyRelu) + +_register_op_map(opns.MUL)(mul) +_register_op_map(opns.DENSE)(dense) +_register_op_map(opns.HARDTANH)(Hardtanh) +_register_op_map(opns.SILU)(silu) +_register_op_map(opns.ADAPTIVE_AVG_POOL2D)(AdaptiveAvgPool2D) +_register_op_map(opns.AVG_POOL2D)(AvgPool2D) +_register_op_map(opns.MAX_POOL2D)(MaxPool2D) +_register_op_map(opns.SOFTMAX)(Softmax) +_register_op_map(opns.LOG_SOFTMAX)(LogSoftmax) +_register_op_map(opns.EXP)(exp) +_register_op_map(opns.SIGMOID)(sigmoid) +_register_op_map(opns.SUM)(Sum) +_register_op_map(opns.MEAN)(Mean) +_register_op_map(opns.MAX_AXIS)(MaxAxis) +_register_op_map(opns.MAXIMUM)(maximum) +_register_op_map(opns.MINIMUM)(minimum) + + +_register_op_map(opns.REPEAT)(repeat) +_register_op_map(opns.SQUEEZE)(Squeeze) +_register_op_map(opns.FLATTEN)(Flatten) +_register_op_map(opns.RESHAPE)(Reshape) +_register_op_map(opns.CONCAT)(Concat) +_register_op_map(opns.SPLIT)(Split) +_register_op_map(opns.TRANSPOSE)(Transpose) +_register_op_map(opns.BROADCAST_TO)(BroadcastTo) +_register_op_map(opns.EXPAND_DIMS)(ExpandDims) +_register_op_map(opns.TILE)(Tile) +_register_op_map(opns.WHERE)(where) +_register_op_map(opns.GREATER)(greater) +_register_op_map(opns.NON_MAX_SUPRESSION)(NonMaxSuppression) + +_register_op_map(opns.CEIL)(ceil) +_register_op_map(opns.RIGHT_SHIFT)(rightShift) + +_register_op_map(opns.ADD)(Add) +_register_op_map(opns.SUB)(Sub) +_register_op_map(opns.MATMUL)(matMul) +_register_op_map(opns.DIV)(Div) +_register_op_map(opns.NEGATIVE)(negative) +_register_op_map(opns.ABS)(abs) +_register_op_map(opns.LOG)(log) +_register_op_map(opns.SQRT)(sqrt) +_register_op_map(opns.POW)(pow) +_register_op_map(opns.PASS)(pass_) +_register_op_map(opns.ARANGE)(Arange) +_register_op_map(opns.ZEROS_LIKE)(zerosLike) +_register_op_map(opns.ONES_LIKE)(onesLike) + + # Add default register Class for MRT OP Not Implemented! 
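For reference, the registry filled in above is consumed by looking a creator up by its op name and calling it like the class constructor, which is how tests/mir/test.op_create.py exercises it. A minimal sketch, with illustrative variable names and shapes:

    X = opclass.var(name="x", shape=(1, 3, 224, 224), dtype="float")
    W = opclass.var(name="w", shape=(32, 3, 10, 10), dtype="float")
    # registry lookup builds the same symbol type as calling opclass.Conv2D directly
    conv = opclass.MRT_OP_MAP[opns.CONV2D](X, W, strides=(2, 2))
    assert isinstance(conv, opclass.Conv2D)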
-_register_op_map(opns.MUL)(symbol.Symbol) -_register_op_map(opns.DENSE)(symbol.Symbol) -_register_op_map(opns.RELU)(symbol.Symbol) -_register_op_map(opns.HARDTANH)(symbol.Symbol) -_register_op_map(opns.SILU)(symbol.Symbol) -_register_op_map(opns.LEAKY_RELU)(symbol.Symbol) -_register_op_map(opns.ADAPTIVE_AVG_POOL2D)(symbol.Symbol) -_register_op_map(opns.AVG_POOL2D)(symbol.Symbol) -_register_op_map(opns.MAX_POOL2D)(symbol.Symbol) -_register_op_map(opns.SOFTMAX)(symbol.Symbol) -_register_op_map(opns.LOG_SOFTMAX)(symbol.Symbol) -_register_op_map(opns.EXP)(symbol.Symbol) -_register_op_map(opns.SIGMOID)(symbol.Symbol) -_register_op_map(opns.SUM)(symbol.Symbol) -_register_op_map(opns.MEAN)(symbol.Symbol) -_register_op_map(opns.MAX_AXIS)(symbol.Symbol) -_register_op_map(opns.MAXIMUM)(symbol.Symbol) -_register_op_map(opns.MINIMUM)(symbol.Symbol) -_register_op_map(opns.TUPLE)(symbol.Symbol) -_register_op_map(opns.REPEAT)(symbol.Symbol) -_register_op_map(opns.SQUEEZE)(symbol.Symbol) -_register_op_map(opns.FLATTEN)(symbol.Symbol) -_register_op_map(opns.BATCH_FLATTEN)(symbol.Symbol) -_register_op_map(opns.RESHAPE)(symbol.Symbol) -_register_op_map(opns.CONCAT)(symbol.Symbol) -_register_op_map(opns.SPLIT)(symbol.Symbol) -_register_op_map(opns.TRANSPOSE)(symbol.Symbol) -_register_op_map(opns.BROADCAST_TO)(symbol.Symbol) -_register_op_map(opns.EXPAND_DIMS)(symbol.Symbol) -_register_op_map(opns.TILE)(symbol.Symbol) -_register_op_map(opns.WHERE)(symbol.Symbol) -_register_op_map(opns.GREATER)(symbol.Symbol) -_register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) -_register_op_map(opns.SLICE_LIKE)(symbol.Symbol) -_register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) -_register_op_map(opns.NON_MAX_SUPRESSION)(symbol.Symbol) -_register_op_map(opns.CEIL)(symbol.Symbol) -_register_op_map(opns.RIGHT_SHIFT)(symbol.Symbol) -_register_op_map(opns.AS_TYPE)(symbol.Symbol) -_register_op_map(opns.ADV_INDEX)(symbol.Symbol) -_register_op_map(opns.CALL_TIR)(symbol.Symbol) -_register_op_map(opns.CALL_DPS_PACKED)(symbol.Symbol) -_register_op_map(opns.ADD)(symbol.Symbol) -_register_op_map(opns.SUB)(symbol.Symbol) -_register_op_map(opns.MATMUL)(symbol.Symbol) -_register_op_map(opns.DIV)(symbol.Symbol) -_register_op_map(opns.NEGATIVE)(symbol.Symbol) -_register_op_map(opns.ABS)(symbol.Symbol) -_register_op_map(opns.LOG)(symbol.Symbol) -_register_op_map(opns.SQRT)(symbol.Symbol) -_register_op_map(opns.POW)(symbol.Symbol) -_register_op_map(opns.PASS)(symbol.Symbol) -_register_op_map(opns.ARANGE)(symbol.Symbol) -_register_op_map(opns.ZEROS_LIKE)(symbol.Symbol) -_register_op_map(opns.ONES_LIKE)(symbol.Symbol) +_register_op_map(opns.TUPLE)(extern_opfunc(opns.TUPLE)) +_register_op_map(opns.AS_TYPE)(extern_opfunc(opns.AS_TYPE)) +_register_op_map(opns.ADV_INDEX)(extern_opfunc(opns.ADV_INDEX)) +_register_op_map(opns.CALL_TIR)(extern_opfunc(opns.CALL_TIR)) +_register_op_map(opns.CALL_DPS_PACKED)(extern_opfunc(opns.CALL_DPS_PACKED)) + _register_op_map(opns.IF)(symbol.Symbol) _register_op_map(opns.ARGWHERE)(symbol.Symbol) _register_op_map(opns.REQUANT)(symbol.Symbol) _register_op_map(opns.PCLIP)(symbol.Symbol) _register_op_map(opns.RS_PCLIP)(symbol.Symbol) _register_op_map(opns.LUT)(symbol.Symbol) + +_register_op_map(opns.BATCH_FLATTEN)(symbol.Symbol) +_register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) +_register_op_map(opns.SLICE_LIKE)(symbol.Symbol) +_register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) + diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index eab2d90..2e1136e 100644 --- a/tests/mir/test.op_create.py +++ 
b/tests/mir/test.op_create.py @@ -23,8 +23,8 @@ def test_create_conv2d_op(): - X = opclass.Variable(name="x", shape=(1, 3, 224, 224,), dtype="float") - W = opclass.Variable(name="w", shape=(32, 3, 10, 10,), dtype="float") + X = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float") + W = opclass.var(name="w", shape=(32, 3, 10, 10,), dtype="float") assert [shp for shp in X.shape] == [shp for shp in (1, 3, 224, 224)], f'Wrong X shape {X.shape}' assert X.dtype == "float", f'Wrong X dtype {X.dtype}' @@ -72,11 +72,11 @@ def test_create_conv2d_op(): def test_create_symbol_graph(): - X0 = opclass.Variable(name="x", shape=(1, 3, 224, 224,), dtype="float") - W0 = opclass.Variable(name="w", shape=(32, 3, 10, 10,), dtype="float") + X0 = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float") + W0 = opclass.var(name="w", shape=(32, 3, 10, 10,), dtype="float") conv2d_a = opclass.Conv2D(X0, W0, name='conv2d_a', strides=(1,1)) - W1 = opclass.Variable(shape=(16, 3, 12, 12,), dtype="float") + W1 = opclass.var(shape=(16, 3, 12, 12,), dtype="float") conv2d_b = opclass.Conv2D(conv2d_a, W1, name='conv2d_b', strides=(1,1)) symlist = sx.sym2list(conv2d_b) @@ -90,11 +90,11 @@ def test_create_symbol_graph(): def test_create_batch_norm_op(): - X = opclass.Variable(name="x", shape=(1, 32, 128, 128,), dtype="float") - Gamma = opclass.Variable(name="gamma", shape=(32,), dtype="float") - Beta = opclass.Variable(name="beta", shape=(32,), dtype="float") - Mean = opclass.Variable(name="mean", shape=(32,), dtype="float") - Var = opclass.Variable(name="var", shape=(32,), dtype="float") + X = opclass.var(name="x", shape=(1, 32, 128, 128,), dtype="float") + Gamma = opclass.var(name="gamma", shape=(32,), dtype="float") + Beta = opclass.var(name="beta", shape=(32,), dtype="float") + Mean = opclass.var(name="mean", shape=(32,), dtype="float") + Var = opclass.var(name="var", shape=(32,), dtype="float") batch_norm_a = opclass.BatchNorm(X, Gamma, Beta, Mean, Var, axis=1, epsilon=1e-4) # attrs hint From d036f17888db66c51880e9098683823c0fef9fc8 Mon Sep 17 00:00:00 2001 From: corlfj Date: Tue, 23 Sep 2025 14:36:44 +0800 Subject: [PATCH 05/12] [mir]: add inferpass(with params) --- python/mrt/mir/simple_pass.py | 177 ++++++++++++++++++++++++++-------- python/mrt/mir/symbol.py | 5 +- tests/mir/test.simple_pass.py | 8 +- 3 files changed, 142 insertions(+), 48 deletions(-) diff --git a/python/mrt/mir/simple_pass.py b/python/mrt/mir/simple_pass.py index 2fc362e..a876b95 100644 --- a/python/mrt/mir/simple_pass.py +++ b/python/mrt/mir/simple_pass.py @@ -2,25 +2,27 @@ import typing from functools import wraps -from dataclasses import dataclass, fields +from dataclasses import dataclass -import mrt from mrt.common import config +#from mrt.runtime import inference from mrt.common.utils import * from mrt.common.types import * -from . import opns, opclass +from . import op, opns, opclass from . import symbol as _symbol # mrt op visits +@dataclass class SimplePass: symbol: _symbol.Symbol - def __init__(self, symbol: _symbol.Symbol): - self.symbol = symbol - - def visit(self, custom_func: typing.Callable[[Symbol], typing.Optional[Symbol]] = None) -> _symbol.Symbol: + """op-level visit of graph + infer different visit function with different op_name + return: head symbol processed + """ + def graph_visits(self) -> _symbol.Symbol: env: typing.Dict[str, _symbol.Symbol] = {} for sym in _symbol.sym2list(self.symbol): assert sym.name not in env, f'{sym.name} NotIn env!' 
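graph_visits walks the symbols in topological order and dispatches each one to its visit_<op> handler, so a concrete pass only overrides the handlers it needs, as FuseDropoutPass does further down in this file. A minimal sketch of such a pass; the class name is hypothetical, the handler name follows the nn_dropout mapping used by FuseDropoutPass, and mrt_graph is the AlexNet graph loaded in tests/mir/test.simple_pass.py:

    class StripDropoutPass(simple_pass.SimplePass):
        def visit_nn_dropout(self, sym: sx.Symbol) -> sx.Symbol:
            # forward the dropout input, removing the op from the graph
            return sym.args[0] if sym.op_name == opns.DROP_OUT else sym

    symbol_passed = StripDropoutPass(mrt_graph['main']).graph_visits()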
@@ -28,7 +30,7 @@ def visit(self, custom_func: typing.Callable[[Symbol], typing.Optional[Symbol]] # Updating args as passed symbol in env_dict sym = sym.copy(args = [env[arg_sym.name] for arg_sym in sym.args]) assert isinstance(sym, _symbol.Symbol), sym - out = custom_func(sym) if custom_func else getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) + out = getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) out = out or sym assert isinstance(out, _symbol.Symbol), out env[sym.name] = out @@ -37,21 +39,72 @@ def visit(self, custom_func: typing.Callable[[Symbol], typing.Optional[Symbol]] def _default_visit_op(self, op: _symbol.Symbol) -> _symbol.Symbol: return op + """custom visit of graph + calling custom_func for all op_name + return: head symbol processed + """ + def custom_visits(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol: + with N(name): + if once: + return custom_run(self.symbol) + return _symbol.transform(self.symbol, custom_run) + # mrt op visits with params, variables +@dataclass class InferPass(SimplePass): params: ParametersT - def is_param(self, symbol: _symbol.Symbol) -> bool: - return symbol.op_name == opns.VAR and symbol.name in self.params - - def get_param(self, symbol: _symbol.Symbol) -> OpNumpyT: - assert self.is_param(symbol) - return self.params[symbol.name] if self.is_param(symbol) else [] - - def __init__(self, symbol: _symbol.Symbol, params: ParametersT): - self.symbol = symbol - self.params = params + def is_input(self, op_: _symbol.Symbol) -> bool: + return op.is_input(op_, self.params) + def is_variable(self, op_: _symbol.Symbol) -> bool: + return op.is_variable(op_, self.params) + def is_operator(self, op_: _symbol.Symbol) -> bool: + return op.is_operator(op_, self.params) + def is_param(self, op_: _symbol.Symbol) -> bool: + return op_.op_name == opns.VAR and op_.name in self.params + + def get_param(self, op_: _symbol.Symbol) -> OpNumpyT: + return self.params[op_.name] if self.is_param(op_) else [] + def get_as_numpy(self, op_: _symbol.Symbol) -> OpNumpyT: + assert self.is_param(op_), f"{op_.name} is not parameter." + data = self.params[op_.name] + assert isinstance(data, (tuple, list, np.ndarray)), \ + f"param:{op_.name} not OpNumpyT, get {type(data)}" + return data + + """custom visit of graph + calling custom_func for all op_name + return: head symbol processed + """ + def custom_visits_with_params(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol: + with N(name): + if once: + return custom_run(self.symbol, self.params) + return _symbol.transform(self.symbol, custom_run, params=self.params) + + # From original quantization.Transformer + def as_parameter(self, data: OpNumpyT, name:str, dtype): + def _f(data, dtype): + if isinstance(data, list): + assert len(data) == len(dtype) + return [_f(d, t) for d, t in zip(data, dtype)] + assert isinstance(data, np.ndarray), type(data) + return data.astype(dtype) + array = _f(data, dtype) + shape = np.array(array).shape + self.params[name] = array + return opclass.var(array, shape=shape, dtype=dtype) + + def from_np_data(self, data: np.ndarray, dtype, prefix=None) -> _symbol.Symbol: + name = N.n(prefix=prefix) + # some data is np.float/int type, use np.array to wrap it. 
+ data = np.array(data) + self.params[name] = data.astype(dtype) + return opclass.var(name, shape=data.shape, dtype=dtype)#.like(self) + + def from_const_data(self, data: typing.Union[int, float], dtype) -> _symbol.Symbol: + return self.from_np_data(data, dtype) # Register MRT all op's default_visit_op function @@ -71,42 +124,83 @@ def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: class FuseTupleGetItemPass(SimplePass): - def visit_TupleGetItem(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.TUPLE_GET_ITEM: - return sym - sym_ : opclass.TupleGetItem = sym - assert sym_.index == 0 - return sym_.args[0] - return sym - - -class FuseBatchNormPass(InferPass): - def visit_nn_batch_norm(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.BATCH_NORM: - X, Gamma, Beta, Mean, Var = sym.args - Gamma = self.get_param(Gamma) - Beta = self.get_param(Beta) - Mean = self.get_param(Mean) - Var = self.get_param(Var) - return sym + def visit_TupleGetItem(self, sym: opclass.TupleGetItem) -> _symbol.Symbol: + #if sym.op_name == opns.TUPLE_GET_ITEM: + # assert sym.index == 0 + # return sym.args[0] return sym -class FuseSoftmaxPass(SimplePass): +class FuseNaiveSoftmaxPass(SimplePass): def visit_nn_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: if sym.op_name == opns.SOFTMAX: - return self.args[0] + return sym.args[0] return sym def visit_nn_log_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: if sym.op_name == opns.LOG_SOFTMAX: - return self.args[0] + return sym.args[0] return sym -class FuseMeanPass(SimplePass): +class FuseMeanPass(InferPass): def visit_mean(self, sym: _symbol.Symbol) -> _symbol.Symbol: if sym.op_name == opns.MEAN: + X = sym.args[0] + out = opclass.Sum(X, **sym.attrs) + scale = self.from_np_data(np.array( + 1. 
* product(out.shape) / product(X.shape)), dtype=out.dtype) + out = opclass.mul(out, scale) + return out + return sym + + +class FuseConstantPass(InferPass): + threshold: typing.ClassVar[float] = 1e-5 + + def np_is_zero(self, data) -> float: + return np.abs(data).max() < self.threshold + + + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol:#: _symbol._TransformerParamT + if self.is_operator(sym) and all([self.is_param(arg) for arg in sym.args]): + data = inference.run_single_params( + sym, [self.get_as_numpy(a) for a in sym.args]) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + elif sym.is_op(opns.ADD, opns.SUB): # , BIAS_ADD): + strips = [] + for arg in sym.args: + if self.is_param(arg) and self.np_is_zero(self.get_as_numpy(arg)): + strips.append(arg) + args = [a for a in sym.args if a not in strips] + if len(args) == 1: + return args[0] + elif sym.is_op(opns.SLICE_LIKE): + if not self.is_param(sym.args[0]): + return None + a, b = sym.args + data = inference.run_single_params( + sym, [self.get_as_numpy(a), np.zeros(b.shape, b.dtype)]) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + elif sym.is_op(opns.REQUANT): + if sym.rescale == 1: + return sym.args[0] + elif sym.is_op(opns.ZEROS_LIKE, opns.ONES_LIKE): + data = inference.run_single_params(sym, []) + return self.as_parameter(data, name=sym.name, dtype=sym.dtype) + return sym + return custom_run + + +class FuseBatchNormPass(InferPass): + def visit_nn_batch_norm(self, sym: _symbol.Symbol) -> _symbol.Symbol: + if sym.op_name == opns.BATCH_NORM: + X, Gamma, Beta, Mean, Var = sym.args + Gamma = self.get_param(Gamma) + Beta = self.get_param(Beta) + Mean = self.get_param(Mean) + Var = self.get_param(Var) return sym return sym @@ -117,8 +211,7 @@ def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: argA = sym.args[0] argB = sym.args[1] assert self.is_param(argB), f'NotParam: {argB}' - # TODO: fixit - #argB = argB.from_np_data(1. / argB.numpy()) + argB = self.from_np_data(1. / self.get_as_numpy(argB), dtype=argB.dtype) return opclass.MRT_OP_MAP[opns.MUL](argA, argB) return sym diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index c1e6bd7..315b70d 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -322,6 +322,7 @@ def load_json(data: _SymbolJsonT, **extra_attrs) -> Symbol: _VisitorT = typing.Callable[[Symbol], None] _TransformerT = typing.Callable[[Symbol], typing.Optional[Symbol]] +_TransformerParamT = typing.Callable[[Symbol, typing.Optional[ParametersT]], Symbol] """ Symbol Transformer Return new symbol to transform old symbol into updated one, @@ -338,7 +339,7 @@ def visit(symbol: Symbol, callback: _VisitorT): if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f">> {sym}") -def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: +def transform(symbol: Symbol, callback: _TransformerParamT, params:typing.Optional[ParametersT] = None) -> Symbol: """ Transform symbol from old to new, with inputs updated. 
Only the return value indicates mutation, while changing @@ -355,7 +356,7 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f"<< {sym}") - out = callback(sym) or sym + out = (callback(sym, params) if params else callback(sym)) or sym assert isinstance(out, Symbol), out # default const_ prefix symbol means parameters assert sym.name not in sym_map, sym.name diff --git a/tests/mir/test.simple_pass.py b/tests/mir/test.simple_pass.py index 9d20ee0..33139d4 100644 --- a/tests/mir/test.simple_pass.py +++ b/tests/mir/test.simple_pass.py @@ -56,7 +56,7 @@ def test_SimplePass_FuseDropout(mrt_graph, mrt_params): # init FuseDropout Passer and execute visit tfs : simple_pass.FuseDropoutPass = simple_pass.FuseDropoutPass(symbol) #print(getattr(tfs, f"visit_{opns.Opname2Funcname(opns.DROP_OUT)}")) - symbol_passed = tfs.visit() + symbol_passed = tfs.graph_visits() print('\n=== After FuseDropout Pass ===') rlts = sx.sym2list(symbol_passed) @@ -83,12 +83,12 @@ def test_SimplePass_CustomFunc(mrt_graph): tfs : simple_pass.SimplePass = simple_pass.SimplePass(symbol) conv2d_name_list = [] - def _filter_op(sym: sx.Symbol) -> sx.Symbol: + def _filter_op(sym: sx.Symbol, params=None) -> sx.Symbol: if sym.op_name == opns.CONV2D: conv2d_name_list.append(sym.name) return sym - symbol_passed = tfs.visit(_filter_op) + symbol_passed = tfs.custom_visits(_filter_op) print('\n=== After CustomFunc Pass ===') assert len(conv2d_name_list) > 0 @@ -113,7 +113,7 @@ def _nn_dropout(sym: sx.Symbol) -> sx.Symbol: if sym.op_name == opns.DROP_OUT: return sym.args[0] return sym - symbol_passed = tfs.visit(_nn_dropout) + symbol_passed = tfs.custom_visits(_nn_dropout) print('\n=== After FuseDropout CustomFunc Pass ===') rlts = sx.sym2list(symbol_passed) From 4cde5903ed541c8d689d9b1a5bb0ff2a9ef3b137 Mon Sep 17 00:00:00 2001 From: corlfj Date: Mon, 29 Sep 2025 17:46:02 +0800 Subject: [PATCH 06/12] [mir]: opclass redefine names, op compulsory attrs check --- python/mrt/mir/opclass.py | 237 +++++++++++++++++++++++++++--------- tests/mir/test.op_create.py | 54 +++++++- 2 files changed, 232 insertions(+), 59 deletions(-) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index 2e4cdf9..e3b40f0 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -9,7 +9,7 @@ #SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") -SymbolCreator = typing.Union[typing.Callable[[typing.Any, typing.Any], typing.Type[symbol.Symbol]], SelfSymbol] +SymbolCreator = typing.Union[typing.Callable[[typing.Any, ...], typing.Type[symbol.Symbol]], SelfSymbol] #SymbolCreator = typing.Union[typing.Callable[[...], symbol.Symbol], SelfSymbol] MRT_OP_MAP: typing.Dict[str, SymbolCreator] = {} @@ -29,8 +29,9 @@ def _wrapper(clss: SymbolCreator = None) -> SymbolCreator: # OPs from external (not in MRT op), using custom op_name with default op_func #y = extern_opfunc("tanh")(X) def extern_opfunc(op_name: str): - def op_func(*args, **attrs): - return symbol.Symbol(*args, op_name=op_name, **attrs) + def op_func(name, args, attrs, extra_attrs): + #return symbol.Symbol(op_name=op_name, *args, **attrs) + return symbol.Symbol(name, op_name, args, attrs, extra_attrs) return op_func @@ -91,19 +92,26 @@ def dilation(self) -> typing.Tuple[int, int]: @property def kernel_size(self) -> typing.Tuple[int, int]: - default_val = (3,3) - return self.attrs['kernel_size'] if 'kernel_size' in self.attrs else default_val + assert 'kernel_size' in self.attrs + return 
self.attrs['kernel_size'] # Follows (*args, name, **attrs) - def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_size=(3,3), extra_attrs=None): + def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), extra_attrs=None): op_name = op_name or opns.CONV2D assert op_name == opns.CONV2D + assert len(W.shape) == 4, f'Wrong Weight Shape for Conv2D: {W.shape}' + kernel_size = (W.shape[2], W.shape[3]) super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation', 'kernel_size'], **kwargs) + # Auto inferred 'kernel_size' + return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation'], **kwargs) + +def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), extra_attrs=None): + return Conv2D(X, W, name, op_name, strides, padding, groups, dilation, extra_attrs) + @dataclass(init=False) class Dropout(symbol.Symbol): @@ -123,29 +131,38 @@ def __init__(self, X, name=None, op_name=None, p:float = 0.5, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['p'], **kwargs) +def dropout(X, name=None, op_name=None, p:float = 0.5, extra_attrs=None): + return Dropout(X, name, op_name, p, extra_attrs) + + @dataclass(init=False) class Clip(symbol.Symbol): op_name = opns.CLIP @property def min(self) -> float: - default_val = np.nan - return self.attrs['min'] if 'min' in self.attrs else default_val + assert 'min' in self.attrs + return self.attrs['min'] @property def max(self) -> float: - default_val = np.nan - return self.attrs['max'] if 'max' in self.attrs else default_val + assert 'max' in self.attrs + return self.attrs['max'] def __init__(self, X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None): op_name = op_name or opns.CLIP assert op_name == opns.CLIP + assert min_ != np.nan + assert max_ != np.nan super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min': min_, 'max': max_}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['min', 'max'], **kwargs) +def clip(X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None): + return Clip(X, name, op_name, min_, max_, extra_attrs) + @dataclass(init=False) class BatchNorm(symbol.Symbol): @@ -175,6 +192,9 @@ def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum'], **kwargs) +def batch_norm(X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, extra_attrs=None): + return BatchNorm(X, Gamma, Beta, Mean, Var, name, op_name, axis, epsilon, momentum, extra_attrs) + @dataclass(init=False) class TupleGetItem(symbol.Symbol): @@ -194,6 +214,9 @@ def __init__(self, X, name=None, op_name=None, index:int = 0, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['index'], **kwargs) +def tuple_get_item(X, name=None, op_name=None, index:int = 0, extra_attrs=None): + return TupleGetItem(X, name, op_name, index, extra_attrs) + 
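With the compulsory-attribute checks introduced in this patch, kernel_size is no longer passed to Conv2D; it is derived from the weight shape, and each class gains a lowercase functional wrapper. A small sketch of the wrapper-style construction, with illustrative shapes and values:

    X = opclass.var(name="x", shape=(1, 3, 224, 224), dtype="float")
    W = opclass.var(name="w", shape=(32, 3, 10, 10), dtype="float")
    y = opclass.conv2d(X, W, strides=(2, 2))   # kernel_size inferred as (10, 10) from W.shape
    y = opclass.clip(y, min_=0.0, max_=6.0)    # min_/max_ are compulsory attrs here
    y = opclass.dropout(y, p=0.5)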
@dataclass(init=False) class LeakyRelu(symbol.Symbol): @@ -213,6 +236,9 @@ def __init__(self, X, name=None, op_name=None, negative_slope:float = 1e-2, extr def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['negative_slope'], **kwargs) +def leaky_relu(X, name=None, op_name=None, negative_slope:float = 1e-2, extra_attrs=None): + return LeakyRelu(X, name, op_name, negative_slope, extra_attrs) + def dense(X, W, B, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.DENSE @@ -242,6 +268,8 @@ def __init__(self, X, name=None, op_name=None, min_val:float = -1.0, max_val:flo def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['min_val', 'max_val'], **kwargs) +def hard_tanh(X, name=None, op_name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): + return Hardtanh(X, name, op_name, min_val, max_val, extra_attrs) @dataclass(init=False) class AdaptiveAvgPool2D(symbol.Symbol): @@ -249,33 +277,41 @@ class AdaptiveAvgPool2D(symbol.Symbol): @property def output_size(self) -> typing.Union[int, typing.Tuple[int, int]]: - default_val = 0 - return self.attrs['output_size'] if 'output_size' in self.attrs else default_val + assert 'output_size' in self.attrs + return self.attrs['output_size'] - def __init__(self, X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=0, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=None, extra_attrs=None): op_name = op_name or opns.ADAPTIVE_AVG_POOL2D assert op_name == opns.ADAPTIVE_AVG_POOL2D + assert output_size != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'output_size': output_size}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['output_size'], **kwargs) +def adaptive_avg_pool2d(X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=0, extra_attrs=None): + return AdaptiveAvgPool2D(X, name, op_name, output_size, extra_attrs) + @dataclass(init=False) class AvgPool2D(symbol.Symbol): op_name = opns.AVG_POOL2D @property def pool_size(self) -> typing.Tuple[int, int]: - default_val = (2, 2) - return self.attrs['pool_size'] if 'pool_size' in self.attrs else default_val + assert 'pool_size' in self.attrs + return self.attrs['pool_size'] @property - def strides(self): - default_val = None + def strides(self) -> typing.Tuple[int, int]: + default_val = (0, 0) return self.attrs['strides'] if 'strides' in self.attrs else default_val @property - def padding(self) -> int: - default_val = 0 + def dilation(self) -> typing.Tuple[int, int]: + default_val = (1, 1) + return self.attrs['dilation'] if 'dilation' in self.attrs else default_val + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + default_val = (0, 0, 0, 0) return self.attrs['padding'] if 'padding' in self.attrs else default_val @property def ceil_mode(self) -> bool: @@ -290,14 +326,19 @@ def count_include_pad(self) -> bool: default_val = True return self.attrs['count_include_pad'] if 'count_include_pad' in self.attrs else default_val - def __init__(self, X, name=None, op_name=None, pool_size=(2,2), strides=None, padding=0, ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): op_name = 
op_name or opns.AVG_POOL2D assert op_name == opns.AVG_POOL2D - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}, extra_attrs=extra_attrs or {}) + assert pool_size != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - return _from_dict_attrs(cls, d, ['pool_size', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) + return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) + +def avg_pool2d(X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): + return AvgPool2D(X, name, op_name, pool_size, dilation, strides, padding, ceil_mode, layout, count_include_pad, extra_attrs) + @dataclass(init=False) class MaxPool2D(symbol.Symbol): @@ -305,21 +346,41 @@ class MaxPool2D(symbol.Symbol): @property def pool_size(self) -> typing.Tuple[int, int]: - default_val = (2, 2) - return self.attrs['pool_size'] if 'pool_size' in self.attrs else default_val + assert 'pool_size' in self.attrs + return self.attrs['pool_size'] + @property + def strides(self) -> typing.Tuple[int, int]: + default_val = (0, 0) + return self.attrs['strides'] if 'strides' in self.attrs else default_val + @property + def dilation(self) -> typing.Tuple[int, int]: + default_val = (1, 1) + return self.attrs['dilation'] if 'dilation' in self.attrs else default_val + @property + def padding(self) -> typing.Tuple[int, int, int, int]: + default_val = (0, 0, 0, 0) + return self.attrs['padding'] if 'padding' in self.attrs else default_val + @property + def ceil_mode(self) -> bool: + default_val = False + return self.attrs['ceil_mode'] if 'ceil_mode' in self.attrs else default_val @property def layout(self) -> str: default_val = 'NCHW' return self.attrs['layout'] if 'layout' in self.attrs else default_val - def __init__(self, X, name=None, op_name=None, pool_size=(2,2), layout='NCHW', extra_attrs=None): + def __init__(self, X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): op_name = op_name or opns.MAX_POOL2D assert op_name == opns.MAX_POOL2D - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'layout':layout}, extra_attrs=extra_attrs or {}) + assert pool_size != None + super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - return _from_dict_attrs(cls, d, ['pool_size', 'layout'], **kwargs) + return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout'], **kwargs) + +def max_pool2d(X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): + return MaxPool2D(X, name, op_name, pool_size, dilation, strides, padding, ceil_mode, 
layout, extra_attrs) @dataclass(init=False) @@ -340,6 +401,8 @@ def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) +def softmax(X, name=None, op_name=None, axis=None, extra_attrs=None): + return Softmax(X, name, op_name, axis, extra_attrs) @dataclass(init=False) class LogSoftmax(symbol.Symbol): @@ -359,6 +422,9 @@ def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) +def log_softmax(X, name=None, op_name=None, axis=None, extra_attrs=None): + return LogSoftmax(X, name, op_name, axis, extra_attrs) + def exp(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.EXP @@ -393,6 +459,9 @@ def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_att def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) +def sum(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return Sum(X, name, op_name, dim, keepdim, extra_attrs) + @dataclass(init=False) class Mean(symbol.Symbol): @@ -417,6 +486,9 @@ def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_att def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) +def mean(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return Mean(X, name, op_name, dim, keepdim, extra_attrs) + @dataclass(init=False) class MaxAxis(symbol.Symbol): @@ -441,6 +513,10 @@ def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_att def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) +def max_axis(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): + return MaxAxis(X, name, op_name, dim, keepdim, extra_attrs) + + def maximum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.MAXIMUM assert op_name == opns.MAXIMUM @@ -474,6 +550,9 @@ def __init__(self, X, name=None, op_name=None, dim=None, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim'], **kwargs) +def squeeze(X, name=None, op_name=None, dim=None, extra_attrs=None): + return Squeeze(X, name, op_name, dim, extra_attrs) + @dataclass(init=False) class Flatten(symbol.Symbol): op_name = opns.FLATTEN @@ -497,6 +576,9 @@ def __init__(self, X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_at def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['start_dim', 'end_dim'], **kwargs) +def flatten(X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_attrs=None): + return Flatten(X, name, op_name, start_dim, end_dim, extra_attrs) + @dataclass(init=False) class Reshape(symbol.Symbol): @@ -504,18 +586,21 @@ class Reshape(symbol.Symbol): @property def newshape(self) -> typing.Tuple[int,...]: - default_val = None - return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + assert 'newshape' in self.attrs + return self.attrs['newshape'] def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): op_name = op_name or opns.RESHAPE assert op_name == opns.RESHAPE + assert newshape != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return 
_from_dict_attrs(cls, d, ['newshape'], **kwargs) +def reshape(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return Reshape(X, name, op_name, newshape, extra_attrs) @dataclass(init=False) class Concat(symbol.Symbol): @@ -535,6 +620,8 @@ def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) +def concat(X, name=None, op_name=None, axis=None, extra_attrs=None): + return Concat(X, name, op_name, axis, extra_attrs) @dataclass(init=False) class Split(symbol.Symbol): @@ -542,23 +629,27 @@ class Split(symbol.Symbol): @property def split_size(self) -> typing.List[int]: - default_val = [] - return self.attrs['split_size'] if 'split_size' in self.attrs else default_val + assert 'split_size' in self.attrs + return self.attrs['split_size'] @property def dim(self) -> int: default_val = 0 return self.attrs['dim'] if 'dim' in self.attrs else default_val - def __init__(self, X, name=None, op_name=None, split_size=[], dim=0, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, split_size=None, dim=0, extra_attrs=None): op_name = op_name or opns.SPLIT assert op_name == opns.SPLIT + assert split_size != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'split_size': split_size, 'dim': dim}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['split_size', 'dim'], **kwargs) +def split(X, name=None, op_name=None, split_size=[], dim=0, extra_attrs=None): + return Split(X, name, op_name, split_size, dim, extra_attrs) + @dataclass(init=False) class Transpose(symbol.Symbol): @@ -566,23 +657,28 @@ class Transpose(symbol.Symbol): @property def dim0(self) -> int: - default_val = 0 - return self.attrs['dim0'] if 'dim0' in self.attrs else default_val + assert 'dim0' in self.attrs + return self.attrs['dim0'] @property def dim1(self) -> int: - default_val = 0 - return self.attrs['dim1'] if 'dim1' in self.attrs else default_val + assert 'dim1' in self.attrs + return self.attrs['dim1'] - def __init__(self, X, name=None, op_name=None, dim0=0, dim1=0, extra_attrs=None): + def __init__(self, X, name=None, op_name=None, dim0=None, dim1=None, extra_attrs=None): op_name = op_name or opns.TRANSPOSE assert op_name == opns.TRANSPOSE + assert dim0 != None + assert dim1 != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim0': dim0, 'dim1': dim1}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim0', 'dim1'], **kwargs) +def transpose(X, name=None, op_name=None, dim0=None, dim1=None, extra_attrs=None): + return Transpose(X, name, op_name, dim0, dim1, extra_attrs) + @dataclass(init=False) class BroadcastTo(symbol.Symbol): @@ -590,18 +686,21 @@ class BroadcastTo(symbol.Symbol): @property def newshape(self) -> typing.Tuple[int,...]: - default_val = None - return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + assert 'newshape' in self.attrs + return self.attrs['newshape'] def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): op_name = op_name or opns.BROADCAST_TO assert op_name == opns.BROADCAST_TO + assert newshape != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['newshape'], **kwargs) +def 
broadcast_to(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return BroadcastTo(X, name, op_name, newshape, extra_attrs) @dataclass(init=False) class ExpandDims(symbol.Symbol): @@ -609,18 +708,21 @@ class ExpandDims(symbol.Symbol): @property def newshape(self) -> typing.Tuple[int,...]: - default_val = None - return self.attrs['newshape'] if 'newshape' in self.attrs else default_val + assert 'newshape' in self.attrs + return self.attrs['newshape'] def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): op_name = op_name or opns.EXPAND_DIMS assert op_name == opns.EXPAND_DIMS + assert newshape != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['newshape'], **kwargs) +def expand_dims(X, name=None, op_name=None, newshape=None, extra_attrs=None): + return ExpandDims(X, name, op_name, newshape, extra_attrs) @dataclass(init=False) class Tile(symbol.Symbol): @@ -628,18 +730,23 @@ class Tile(symbol.Symbol): @property def dims(self) -> typing.Tuple[int,...]: - default_val = None - return self.attrs['dims'] if 'dims' in self.attrs else default_val + assert 'dims' in self.attrs + return self.attrs['dims'] def __init__(self, X, name=None, op_name=None, dims=None, extra_attrs=None): op_name = op_name or opns.TILE assert op_name == opns.TILE + assert dims != None super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dims': dims}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dims'], **kwargs) +def tile(X, name=None, op_name=None, dims=None, extra_attrs=None): + return Tile(X, name, op_name, dims, extra_attrs) + + def where(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.WHERE assert op_name == opns.WHERE @@ -672,13 +779,16 @@ def __init__(self, X, name=None, op_name=None, iou_threshold=0.5, score_threshol def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dims'], **kwargs) +def non_max_suppression(X, name=None, op_name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None): + return NonMaxSuppression(X, name, op_name, iou_threshold, score_threshold, extra_attrs) + def ceil(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.CEIL assert op_name == opns.CEIL return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) -def rightShift(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: +def right_shift(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.RIGHT_SHIFT assert op_name == opns.RIGHT_SHIFT return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) @@ -701,6 +811,9 @@ def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['alpha'], **kwargs) +def add(X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): + return Add(X, Y, name, op_name, alpha, extra_attrs) + @dataclass(init=False) class Sub(symbol.Symbol): op_name = opns.SUB @@ -719,12 +832,16 @@ def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['alpha'], **kwargs) +def sub(X, Y, name=None, 
op_name=None, alpha=1, extra_attrs=None): + return Sub(X, Y, name, op_name, alpha, extra_attrs) + + def mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.MUL assert op_name == opns.MUL return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) -def matMul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: +def mat_mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.MATMUL assert op_name == opns.MATMUL return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) @@ -747,6 +864,10 @@ def __init__(self, X, Y, name=None, op_name=None, rounding_mode=None, extra_attr def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['rounding_mode'], **kwargs) +def div(X, Y, name=None, op_name=None, rounding_mode=None, extra_attrs=None): + return Div(X, Y, name, op_name, rounding_mode, extra_attrs) + + def negative(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.NEGATIVE assert op_name == opns.NEGATIVE @@ -783,8 +904,8 @@ class Arange(symbol.Symbol): @property def end(self) -> int: - default_val = 0 - return self.attrs['end'] if 'end' in self.attrs else default_val + assert 'end' in self.attrs + return self.attrs['end'] @property def start(self) -> int: @@ -796,21 +917,26 @@ def step(self) -> int: default_val = 1 return self.attrs['step'] if 'step' in self.attrs else default_val - def __init__(self, name=None, op_name=None, end=0, start=0, step=1, extra_attrs=None): + def __init__(self, name=None, op_name=None, end=None, start=0, step=1, extra_attrs=None): op_name = op_name or opns.ARANGE assert op_name == opns.ARANGE + assert end != None super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={'end': end, 'start': start, 'step': step}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['end', 'start', 'step'], **kwargs) -def zerosLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: +def arange(name=None, op_name=None, end=None, start=0, step=1, extra_attrs=None): + return Arange(name, op_name, end, start, step, extra_attrs) + + +def zeros_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.ZEROS_LIKE assert op_name == opns.ZEROS_LIKE return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) -def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: +def ones_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: op_name = op_name or opns.ONES_LIKE assert op_name == opns.ONES_LIKE return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) @@ -860,11 +986,11 @@ def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: _register_op_map(opns.NON_MAX_SUPRESSION)(NonMaxSuppression) _register_op_map(opns.CEIL)(ceil) -_register_op_map(opns.RIGHT_SHIFT)(rightShift) +_register_op_map(opns.RIGHT_SHIFT)(right_shift) _register_op_map(opns.ADD)(Add) _register_op_map(opns.SUB)(Sub) -_register_op_map(opns.MATMUL)(matMul) +_register_op_map(opns.MATMUL)(mat_mul) _register_op_map(opns.DIV)(Div) _register_op_map(opns.NEGATIVE)(negative) _register_op_map(opns.ABS)(abs) @@ -873,8 +999,8 @@ def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: 
_register_op_map(opns.POW)(pow) _register_op_map(opns.PASS)(pass_) _register_op_map(opns.ARANGE)(Arange) -_register_op_map(opns.ZEROS_LIKE)(zerosLike) -_register_op_map(opns.ONES_LIKE)(onesLike) +_register_op_map(opns.ZEROS_LIKE)(zeros_like) +_register_op_map(opns.ONES_LIKE)(ones_like) # Add default register Class for MRT OP Not Implemented! @@ -895,4 +1021,3 @@ def onesLike(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: _register_op_map(opns.STRIDED_SLICE)(symbol.Symbol) _register_op_map(opns.SLICE_LIKE)(symbol.Symbol) _register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol) - diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index 2e1136e..7707afb 100644 --- a/tests/mir/test.op_create.py +++ b/tests/mir/test.op_create.py @@ -21,6 +21,20 @@ from mrt.mir import opclass +def test_op_func(): + X = opclass.var(name="var2", shape=(16, 128, 128), dtype="float") + ceil0 = opclass.ceil(X) + assert isinstance(ceil0, sx.Symbol), 'ceil0 isnot a symbol' + assert ceil0.op_name == opns.CEIL + assert len(ceil0.name) > 0 + + ceil1 = opclass.ceil(X, 'ceil_1') + assert ceil1.op_name == opns.CEIL + assert ceil1.name == 'ceil_1' + + return True + + def test_create_conv2d_op(): X = opclass.var(name="x", shape=(1, 3, 224, 224,), dtype="float") @@ -68,6 +82,10 @@ def test_create_conv2d_op(): assert isinstance(conv2d_d, opclass.Conv2D), 'conv2d_d isnot a Conv2D' assert isinstance(conv2d_e, opclass.Conv2D), 'conv2d_e isnot a Conv2D' + # alias function Init + conv2d_f = opclass.conv2d(*args, **attrs) + assert isinstance(conv2d_f, opclass.Conv2D), 'conv2d_f isnot a Conv2D' + return True @@ -104,7 +122,7 @@ def test_create_batch_norm_op(): # test clone mode batch_norm_b = batch_norm_a.copy() - assert isinstance(batch_norm_b , opclass.BatchNorm) + assert isinstance(batch_norm_b, opclass.BatchNorm) assert batch_norm_a.attrs == batch_norm_b.attrs, f'a: {batch_norm_a.attrs} != b: {batch_norm_b.attrs}' assert len(batch_norm_a.args) == len(batch_norm_b.args), f'a: {len(batch_norm_a.args)} != b: {len(batch_norm_b.args)}' @@ -112,6 +130,32 @@ def test_create_batch_norm_op(): return True +def test_create_reshape_op(): + X = opclass.var(name="x", shape=(16, 32, 64, 64,), dtype="float") + try: + reshape0 = opclass.Reshape(X, name="reshape_0") + assert False, "Reshape Must have attr 'newshape', Should already Fail!" 
+ except: + pass + + reshape1 = opclass.Reshape(X, name="reshape_1", newshape=(16, 8, 128, 128)) + assert isinstance(reshape1, opclass.Reshape) + + return True + + +def test_op_extern_func(): + + # extern_func Do not need to fill 'op_name' + args = [opclass.var(name="var2", shape=(16, 128, 128), dtype="float")] + attrs = {} + extra_attrs = {} + call_dps_packed = opclass.MRT_OP_MAP[opns.CALL_DPS_PACKED]('packed_0', args, attrs, extra_attrs) + assert isinstance(call_dps_packed, sx.Symbol), 'call_dps_packed isnot a symbol' + assert call_dps_packed.op_name == opns.CALL_DPS_PACKED + return True + + if __name__ == "__main__": print('MRT_OP_SET as:', opclass.MRT_OP_MAP.keys()) assert len(opclass.MRT_OP_MAP.keys()) > 0 @@ -120,10 +164,14 @@ def test_create_batch_norm_op(): print('MRT_OP_MAP Conv2D Class as:', opclass.MRT_OP_MAP[opns.CONV2D]) test_id = 0 - for func_ in [test_create_conv2d_op, test_create_symbol_graph, test_create_batch_norm_op]: + passed_cnt = 0 + test_funcs = [test_op_func, test_create_conv2d_op, test_create_symbol_graph, test_create_batch_norm_op, test_create_reshape_op, test_op_extern_func] + for func_ in test_funcs: rltflag = func_() test_id += 1 + passed_cnt += rltflag print("\n" + "="*60 + "\n") - print(f'Passed Test{test_id}!' if rltflag else f'Test{test_id} Failed!') + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') From 7fcde4e03b7f53cfdd5ebdd714aaf6b4f7279ca2 Mon Sep 17 00:00:00 2001 From: corlfj Date: Tue, 30 Sep 2025 14:57:07 +0800 Subject: [PATCH 07/12] [mir]: testing infer_pass, fix some opclass issue --- python/mrt/mir/opclass.py | 34 ++++-- python/mrt/mir/simple_pass.py | 192 +++++++++++++++++++++++++----- python/mrt/mir/symbol.py | 5 +- tests/mir/test.infer_pass.py | 103 ++++++++++++++++ tests/mir/test.infer_pass_div.py | 88 ++++++++++++++ tests/mir/test.infer_pass_mean.py | 89 ++++++++++++++ 6 files changed, 468 insertions(+), 43 deletions(-) create mode 100644 tests/mir/test.infer_pass.py create mode 100644 tests/mir/test.infer_pass_div.py create mode 100644 tests/mir/test.infer_pass_mean.py diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index e3b40f0..02fb929 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -95,22 +95,26 @@ def kernel_size(self) -> typing.Tuple[int, int]: assert 'kernel_size' in self.attrs return self.attrs['kernel_size'] + @property + def kernel_layout(self) -> str: + default_val = 'OIHW' + return self.attrs['kernel_layout'] if 'kernel_layout' in self.attrs else default_val # Follows (*args, name, **attrs) - def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), extra_attrs=None): + def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None): op_name = op_name or opns.CONV2D assert op_name == opns.CONV2D assert len(W.shape) == 4, f'Wrong Weight Shape for Conv2D: {W.shape}' kernel_size = (W.shape[2], W.shape[3]) - super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], 
attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size, 'kernel_layout': kernel_layout}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): # Auto inferred 'kernel_size' - return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation'], **kwargs) + return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation', 'kernel_layout'], **kwargs) -def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), extra_attrs=None): - return Conv2D(X, W, name, op_name, strides, padding, groups, dilation, extra_attrs) +def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None): + return Conv2D(X, W, name, op_name, strides, padding, groups, dilation, kernel_layout, extra_attrs) @dataclass(init=False) @@ -181,19 +185,29 @@ def epsilon(self) -> float: @property def momentum(self) -> float: default_val = 0.1 + return self.attrs['momentum'] if 'momentum' in self.attrs else default_val + + @property + def center(self) -> bool: + default_val = True return self.attrs['center'] if 'center' in self.attrs else default_val - def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, extra_attrs=None): + @property + def scale(self) -> bool: + default_val = True + return self.attrs['scale'] if 'scale' in self.attrs else default_val + + def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): op_name = op_name or opns.BATCH_NORM assert op_name == opns.BATCH_NORM - super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'momentum': momentum}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'momentum': momentum, 'center': center, 'scale': scale}, extra_attrs=extra_attrs or {}) @classmethod def from_dict(cls, d: dict, **kwargs): - return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum'], **kwargs) + return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum', 'center', 'scale'], **kwargs) -def batch_norm(X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, extra_attrs=None): - return BatchNorm(X, Gamma, Beta, Mean, Var, name, op_name, axis, epsilon, momentum, extra_attrs) +def batch_norm(X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): + return BatchNorm(X, Gamma, Beta, Mean, Var, name, op_name, axis, epsilon, momentum, center, scale, extra_attrs) @dataclass(init=False) diff --git a/python/mrt/mir/simple_pass.py b/python/mrt/mir/simple_pass.py index a876b95..302da1b 100644 --- a/python/mrt/mir/simple_pass.py +++ b/python/mrt/mir/simple_pass.py @@ -75,6 +75,7 @@ def get_as_numpy(self, op_: _symbol.Symbol) -> OpNumpyT: """custom visit of graph calling custom_func for all op_name + according to how custom_run implemented, params is from argument or class_property return: head symbol processed """ def custom_visits_with_params(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol: @@ -96,15 +97,15 @@ def _f(data, 
dtype): self.params[name] = array return opclass.var(array, shape=shape, dtype=dtype) - def from_np_data(self, data: np.ndarray, dtype, prefix=None) -> _symbol.Symbol: + def from_np_data(self, sym:_symbol.Symbol, data: np.ndarray, dtype, prefix=None) -> _symbol.Symbol: name = N.n(prefix=prefix) # some data is np.float/int type, use np.array to wrap it. data = np.array(data) self.params[name] = data.astype(dtype) - return opclass.var(name, shape=data.shape, dtype=dtype)#.like(self) + return opclass.var(name, shape=data.shape, dtype=dtype).like(sym) - def from_const_data(self, data: typing.Union[int, float], dtype) -> _symbol.Symbol: - return self.from_np_data(data, dtype) + def from_const_data(self, sym:_symbol.Symbol, data: typing.Union[int, float], dtype) -> _symbol.Symbol: + return self.from_np_data(sym, data, dtype) # Register MRT all op's default_visit_op function @@ -144,15 +145,17 @@ def visit_nn_log_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: class FuseMeanPass(InferPass): - def visit_mean(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.MEAN: - X = sym.args[0] - out = opclass.Sum(X, **sym.attrs) - scale = self.from_np_data(np.array( - 1. * product(out.shape) / product(X.shape)), dtype=out.dtype) - out = opclass.mul(out, scale) - return out - return sym + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.MEAN: + X = sym.args[0] + out = opclass.Sum(X, **sym.attrs).like(sym) + scale = self.from_np_data(sym, np.array( + 1. * product(out.shape) / product(X.shape)), dtype=out.dtype) + out = opclass.mul(out, scale) + return out + return sym + return custom_run class FuseConstantPass(InferPass): @@ -161,9 +164,8 @@ class FuseConstantPass(InferPass): def np_is_zero(self, data) -> float: return np.abs(data).max() < self.threshold - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol:#: _symbol._TransformerParamT + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: if self.is_operator(sym) and all([self.is_param(arg) for arg in sym.args]): data = inference.run_single_params( sym, [self.get_as_numpy(a) for a in sym.args]) @@ -178,7 +180,7 @@ def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) return args[0] elif sym.is_op(opns.SLICE_LIKE): if not self.is_param(sym.args[0]): - return None + return sym a, b = sym.args data = inference.run_single_params( sym, [self.get_as_numpy(a), np.zeros(b.shape, b.dtype)]) @@ -194,24 +196,150 @@ def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) class FuseBatchNormPass(InferPass): - def visit_nn_batch_norm(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.BATCH_NORM: - X, Gamma, Beta, Mean, Var = sym.args - Gamma = self.get_param(Gamma) - Beta = self.get_param(Beta) - Mean = self.get_param(Mean) - Var = self.get_param(Var) + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: opclass.BatchNorm, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.BATCH_NORM: + X, Gamma, Beta, Mean, Var = sym.args + Gamma = self.get_param(Gamma) + Beta = self.get_param(Beta) + Mean = self.get_param(Mean) + Var = self.get_param(Var) + + assert sym.axis == 1 + Beta = Beta if sym.center else 0 + Gamma = Gamma if sym.scale else 1 + + # (x - mean) / 
sqrt(var + epsilon) * gamma + beta + Gamma = Gamma / np.sqrt(Var + sym.epsilon) + # (x - mean) * gamma + beta + # x * gamma + (beta - mean * gamma) + bias: np.ndarray = (Beta - Mean * Gamma) + K = Gamma.shape[0] + + if X.is_op(opns.CONV2D): + A, W = X.args + assert X.kernel_layout == "OIHW" + assert W.shape[0] == K + # (A * W) * gamma + bias + # A * (W * gamma) + bias + W_data = self.get_as_numpy(W) * Gamma.reshape(K, 1, 1, 1) + W_sym = self.from_np_data(W, W_data, W.dtype) + out = op.nn_conv2d(A, W_sym, **X.attrs) + elif X.is_op(opns.DENSE): + A, W = X.args + # (A * W) * gamma + bias + # A * (W * gamma) + bias + W_data = self.get_as_numpy(W) * Gamma.reshape(K, 1) + W_sym = self.from_np_data(W, W_data, W.dtype) + out = op.nn_dense(A, W_sym, **X.attrs) + else: + reshp = [s if i == sym.axis else 1 \ + for i, s in enumerate(X.shape)] + W = self.from_np_data(X, Gamma.reshape(reshp), X.dtype) + out = opclass.mul(X, W) + + bias = bias.reshape([s if i == sym.axis else 1 \ + for i, s in enumerate(out.shape)]) + B = out.like(sym) + B = self.from_np_data(B, bias, dtype=B.dtype) + return opclass.add(out, B).like(sym) + return sym - return sym + return custom_run class FuseDividePass(InferPass): - def visit_divide(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.DIV: - argA = sym.args[0] - argB = sym.args[1] - assert self.is_param(argB), f'NotParam: {argB}' - argB = self.from_np_data(1. / self.get_as_numpy(argB), dtype=argB.dtype) - return opclass.MRT_OP_MAP[opns.MUL](argA, argB) - return sym + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.DIV: + argA = sym.args[0] + argB = sym.args[1] + assert self.is_param(argB), f'NotParam: {argB}' + argB = self.from_np_data(sym, 1. 
/ self.get_as_numpy(argB), dtype=argB.dtype) + out = opclass.mul(argA, argB) + return out.like(sym) + return sym + return custom_run + +class FuseLeakyReLU(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.LEAKY_RELU: + alpha = self.from_const_data(sym, sym.alpha, dtype=float) + X = sym.args[0] + out = opclass.relu(opclass.negative(X)) + out = opclass.mul(alpha, out) + return opclass.sub(opclass.relu(X), out) + return sym + return custom_run + +class FuseAdaptiveAvgPool2D(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + if sym.op_name == opns.ADAPTIVE_AVG_POOL2D: + X = sym.args[0] + assert sym.layout == "NCHW" + inp_shap = X.shape[2:] + out_size = sym.output_size or inp_shap + if not isinstance(out_size, (list, tuple)): + out_size = (out_size, out_size) + sym.output_size = out_size + + assert len(X.shape) == 4 + if all([s == 1 for s in sym.output_size]): + scale = np.array(1 / np.prod(X.shape[-2:])) + out = opclass.Sum(X, dim=list(range(4))[-2:], keepdims=True) + scale = self.from_np_data(sym, scale.astype(X.dtype)) + return opclass.mul(out, scale).like(self) + elif out_size[0] > inp_shap[0] or out_size[1] > inp_shap[1]: + assert all([s == 1 for s in inp_shap]) + # TODO: fix opclass repeat + out = opclass.repeat(X, repeats=out_size[0], axis=-2) + out = opclass.repeat(out, repeats=out_size[1], axis=-1) + return out.like(self) + + # calculate the attributes refers to: + # https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work + strides = [i // o for i, o in zip(inp_shap, out_size)] + kernel = [i-(o-1)*s for i, o, s in zip(inp_shap, out_size, strides)] + attrs = { + "kernel_size": kernel, + "strides": strides, + "padding": (0, 0), + "dilation": (1, 1), + "data_layout": sym.layout, + "groups": X.shape[1], + "channels": X.shape[1], + } + W_shape = (X.shape[1], 1, *kernel) + W = self.from_np_data(X, np.full(W_shape, 1 / product(kernel)), dtype=X.dtype) + out = opclass.Conv2D(X, W, **attrs) + return out.like(sym) + return sym + return custom_run + + +class FuseAvgPool2D(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Spliter(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Merger(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run + +class Calibrator(InferPass): + def get_run(self) -> _symbol._TransformerParamT: + def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: + return sym + return custom_run diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 315b70d..07bd8af 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -113,7 +113,10 @@ def like(self, other: Symbol, **kwargs) -> Symbol: # assert self.shape == other.shape, "%s vs.\n %s" % (self, other) # assert self.dtype == other.dtype , "%s vs.\n %s" % (self, other) data = other.to_dict() - data.update(self.to_dict()) + 
data_new = self.to_dict() + data.update(data_new) + + data["extra_attrs"] = other.extra_attrs if self.extra_attrs == {} else data["extra_attrs"] # copy extra attrs by default. # data["extra_attrs"] = other.extra_attrs return type(other).from_dict(data, **kwargs) diff --git a/tests/mir/test.infer_pass.py b/tests/mir/test.infer_pass.py new file mode 100644 index 0000000..3d94e93 --- /dev/null +++ b/tests/mir/test.infer_pass.py @@ -0,0 +1,103 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_resnet18_model(): + """Get Resnet18 MRT Model""" + + # Load pre-trained ResNet18 + model = models.resnet18(weights='IMAGENET1K_V1') + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseBatchNorm(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + print('\n=== Before FuseBatchNorm Pass ===') + symlist = sx.sym2list(symbol) + return True + + +def test_InferPass_FuseAdaptiveAvgPool2D(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + print('\n=== Before FuseAdaptiveAvgPool2D Pass ===') + symlist = sx.sym2list(symbol) + return True + + +def test_InferPass_FuseTupleGetItem(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseTuple Pass ===') + symlist = sx.sym2list(symbol) + #for x in symlist: + #print(x) + + op_cnt = 0 + for sym in symlist: + op_cnt += 1 if sym.op_name == opns.TUPLE_GET_ITEM else 0 + assert op_cnt > 0, f'ori model TupleGetItem op cnt {op_cnt} == zero!' + + # init Passer and execute visit + tfs : simple_pass.FuseTupleGetItemPass = simple_pass.FuseTupleGetItemPass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseTuple Pass ===') + rlts = sx.sym2list(symbol_passed) + op_cnt_af = 0 + for sym in rlts: + # print(sym) + op_cnt_af += 1 if sym.op_name == opns.TUPLE_GET_ITEM else 0 + assert op_cnt_af==0, f'passed model op cnt {op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass ===") + mrt_graph, mrt_params = _get_resnet18_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseBatchNorm, test_InferPass_FuseAdaptiveAvgPool2D, test_InferPass_FuseTupleGetItem] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! 
Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.infer_pass_div.py b/tests/mir/test.infer_pass_div.py new file mode 100644 index 0000000..547363c --- /dev/null +++ b/tests/mir/test.infer_pass_div.py @@ -0,0 +1,88 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_fasterrcnn_resnet50_fpn_model(): + """Get Fasterrcnn_resnet50_fpn MRT Model""" + + # Load pre-trained model + model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseDivide(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseDivide Pass ===') + symlist = sx.sym2list(symbol) + + divide_op_cnt = 0 + for sym in symlist: + divide_op_cnt += 1 if sym.op_name == opns.DIV else 0 + assert divide_op_cnt > 0, f'ori model divide op cnt {divide_op_cnt} == zero!' + + # init FuseDivide Passer and execute visit + tfs : simple_pass.FuseDividePass = simple_pass.FuseDividePass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseDivide Pass ===') + rlts = sx.sym2list(symbol_passed) + divide_op_cnt_af = 0 + for sym in rlts: + # print(sym) + divide_op_cnt_af += 1 if sym.op_name == opns.DIV else 0 + assert divide_op_cnt_af==0, f'passed model divide op cnt {divide_op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass Divide ===") + mrt_graph, mrt_params = _get_fasterrcnn_resnet50_fpn_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseDivide] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! 
Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + diff --git a/tests/mir/test.infer_pass_mean.py b/tests/mir/test.infer_pass_mean.py new file mode 100644 index 0000000..d8586ec --- /dev/null +++ b/tests/mir/test.infer_pass_mean.py @@ -0,0 +1,89 @@ +""" +Test script for MRT InferPass +""" + +from os import path +import sys, os + +ROOT = path.dirname(path.dirname(path.dirname( + path.realpath(__file__)))) +sys.path.insert(0, path.join(ROOT, "python")) + +import torch +import torchvision.models as models +import numpy as np +from collections import namedtuple + +from mrt.frontend.pytorch import pytorch_to_mrt, mrt_to_pytorch, type_infer +from mrt.frontend.pytorch import vm +from mrt.mir import helper, symbol as sx +from mrt.mir import opns +from mrt.mir import opclass +from mrt.mir import simple_pass + +def _get_shufflenet_model(): + """Get Shufflenet MRT Model""" + + # Load pre-trained + model = models.shufflenet_v2_x1_0(pretrained=True) + model.eval() + + # Create example input + example_inputs = torch.randn(1, 3, 224, 224) + + # Test inference with original model + with torch.no_grad(): + original_output = model(example_inputs) + + # Convert to MRT + print("\nConverting Model to MRT...") + ep = torch.export.export(model, (example_inputs,)) + mrt_graph, mrt_params = pytorch_to_mrt(ep) + return mrt_graph, mrt_params + + +def test_InferPass_FuseMean(mrt_graph, mrt_params): + symbol = mrt_graph['main'] + #print(symbol) + + print('\n=== Before FuseMean Pass ===') + symlist = sx.sym2list(symbol) + #for x in symlist: + #print(x) + + op_cnt = 0 + for sym in symlist: + op_cnt += 1 if sym.op_name == opns.MEAN else 0 + assert op_cnt > 0, f'ori model mean op cnt {op_cnt} == zero!' + + # init Passer and execute visit + tfs : simple_pass.FuseMeanPass = simple_pass.FuseMeanPass(symbol, mrt_params) + symbol_passed = tfs.custom_visits_with_params(tfs.get_run()) + + print('\n=== After FuseMean Pass ===') + rlts = sx.sym2list(symbol_passed) + op_cnt_af = 0 + for sym in rlts: + # print(sym) + op_cnt_af += 1 if sym.op_name == opns.MEAN else 0 + assert op_cnt_af==0, f'passed model op cnt {op_cnt_af} != zero!' + + return True + + +if __name__ == "__main__": + print("=== Testing InferPass Mean ===") + mrt_graph, mrt_params = _get_shufflenet_model() + + test_id = 0 + passed_cnt = 0 + test_funcs = [test_InferPass_FuseMean] + for func_ in test_funcs: + rltflag = func_(mrt_graph, mrt_params) + test_id += 1 + passed_cnt += rltflag + print("\n" + "="*60 + "\n") + print(f'Passed Test{test_id} Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!' if rltflag else f'Test{test_id} Failed! 
Processed({passed_cnt}/{len(test_funcs)}), Passed({passed_cnt}/{test_id})!') + print("\n" + "="*60 + "\n") + print(f'Summary_Passed {passed_cnt}/{len(test_funcs)}') + From 2843c2fef9b7ec8839d95351115c9236e25913a1 Mon Sep 17 00:00:00 2001 From: corlfj Date: Tue, 4 Nov 2025 14:28:42 +0800 Subject: [PATCH 08/12] [mir]: WithParameters not inherit from Symbol --- python/mrt/api.py | 17 ++- python/mrt/mir/symbol.py | 5 +- python/mrt/quantization/calibrate.py | 19 ++-- python/mrt/quantization/discrete.py | 57 +++++----- python/mrt/quantization/fixed_point.py | 30 ++--- python/mrt/quantization/fuse.py | 73 ++++++------ python/mrt/quantization/precision.py | 22 ++-- python/mrt/quantization/scaler.py | 11 +- python/mrt/quantization/segement.py | 32 ++++-- python/mrt/quantization/transform.py | 151 +++++++++++++++++++------ python/mrt/runtime/inference.py | 4 +- 11 files changed, 269 insertions(+), 152 deletions(-) diff --git a/python/mrt/api.py b/python/mrt/api.py index fe511cb..34d0355 100644 --- a/python/mrt/api.py +++ b/python/mrt/api.py @@ -223,7 +223,12 @@ def checkpoint_run(self, def discrete(self) -> Trace: fuse_tr = self.fuse() - seg_tr = fuse_tr.checkpoint_run(seg.Spliter.get_transformer()) + + """Must pass params inside a dict, + Cause it will be unfolded separately + """ + kwargs_seg = {"pointer": {"head": {}, "head_params": {}, "seg_names": []}} + seg_tr = fuse_tr.checkpoint_run(seg.Spliter.get_transformer(), **kwargs_seg) C = TraceConfig.G() calib_tr = seg_tr.calibrate( @@ -232,7 +237,8 @@ def discrete(self) -> Trace: quant_tr = calib_tr.quantize() quant_tr = quant_tr.checkpoint_run( seg.Merger.get_transformer(), - spliter=seg_tr.symbol) + spliter=seg_tr.symbol, + **kwargs_seg) return quant_tr def fuse(self, **kwargs) -> Trace: @@ -254,8 +260,13 @@ def fuse(self, **kwargs) -> Trace: def calibrate(self, repeats: int = 1, **kwargs) -> Trace: assert self._dataset is not None tr_name = kwargs.pop("tr_name", "calibrate") + raw_data: typing.Dict[str, OpOutputT] = {} + out_data: typing.List[OpNumpyT] = [] + out = self for i in range(repeats): + kwargs["raw_data"] = raw_data + kwargs["out_data"] = out_data data, _ = self._dataset.next() out = out.checkpoint_run( calib.Calibrator.get_transformer(), @@ -265,7 +276,7 @@ def calibrate(self, repeats: int = 1, **kwargs) -> Trace: **kwargs) out = out.checkpoint_run( calib.SymmetricMinMaxSampling.get_transformer(), - tr_name = "%s_sampling" % tr_name) + tr_name = "%s_sampling" % tr_name, **{'origin_data': raw_data}) return out def quantize(self, **kwargs): diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 07bd8af..973e92b 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -112,6 +112,7 @@ def like(self, other: Symbol, **kwargs) -> Symbol: """ cast current symbol to child class. """ # assert self.shape == other.shape, "%s vs.\n %s" % (self, other) # assert self.dtype == other.dtype , "%s vs.\n %s" % (self, other) + assert isinstance(other, Symbol) data = other.to_dict() data_new = self.to_dict() data.update(data_new) @@ -342,7 +343,7 @@ def visit(symbol: Symbol, callback: _VisitorT): if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f">> {sym}") -def transform(symbol: Symbol, callback: _TransformerParamT, params:typing.Optional[ParametersT] = None) -> Symbol: +def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: """ Transform symbol from old to new, with inputs updated. 
Only the return value indicates mutation, while changing @@ -359,7 +360,7 @@ def transform(symbol: Symbol, callback: _TransformerParamT, params:typing.Option if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f"<< {sym}") - out = (callback(sym, params) if params else callback(sym)) or sym + out = callback(sym) or sym assert isinstance(out, Symbol), out # default const_ prefix symbol means parameters assert sym.name not in sym_map, sym.name diff --git a/python/mrt/quantization/calibrate.py b/python/mrt/quantization/calibrate.py index 0cf8a0e..7961ce7 100644 --- a/python/mrt/quantization/calibrate.py +++ b/python/mrt/quantization/calibrate.py @@ -18,7 +18,7 @@ @dataclass(repr=False) class Calibrator(Transformer): """ skip dump, and restore from np_data. """ - raw_data: OpOutputT | None = field(repr=False, default=None) + raw_data: typing.Dict[str, OpOutputT] = field(repr=False, default_factory=dict) """ calibrate may be processed multi-times """ data: typing.List[OpNumpyT] = field(default_factory=list) @@ -43,6 +43,8 @@ def __call__(self, sampling_func: SamplingFuncT = None, **kwargs): kwargs.pop("origin", None) + self.raw_data = kwargs.pop("raw_data", {}) + self.data = kwargs.pop("out_data", []) if self.is_input(): out = data_dict.get(self.name, data) @@ -51,10 +53,10 @@ def __call__(self, elif self.is_param(): out = self.params[self.name] else: - single_op = op.retrieve_operator(self) + single_op = op.retrieve_operator(self.graph) out = inference.run_single( single_op, - [a.raw_data for a in self.args], + [self.raw_data[a.name] for a in self.args], **kwargs) assert isinstance(out, (np.ndarray, list)), type(out) @@ -65,7 +67,7 @@ def __call__(self, self._assert([o.dtype.name for o in out], self.dtype) self._assert([o.shape for o in out], self.shape) - self.raw_data = out + self.raw_data[self.name] = out if sampling_func is not None: out = sampling_func(out) self.data.append(out) @@ -103,15 +105,18 @@ def data(self, val): def sampling(cls, np_data: np.ndarray) -> typing.Any: raise NotImplementedError() - def __call__(self, origin: Calibrator, **kw): + def __call__(self, origin: Symbol, **kw): print(type(origin), origin) + origin_data = kw.pop('origin_data', []) + origin_data = origin_data[self.name] + if self.is_op(opns.CLIP): # TODO: remove clip if threshold is less than a_max a_min, a_max = self.parsed.a_min, self.parsed.a_max self.data = max(abs(a_min), abs(a_max)) else: - self.data = self.sampling(origin.data) - return self + self.data = self.sampling(origin_data) + return self.graph @dataclass(repr=False) class SymmetricMinMaxSampling(Sampling): diff --git a/python/mrt/quantization/discrete.py b/python/mrt/quantization/discrete.py index bcca5a2..5ca5574 100644 --- a/python/mrt/quantization/discrete.py +++ b/python/mrt/quantization/discrete.py @@ -35,7 +35,10 @@ def undefined(self) -> bool: @dataclass(repr=False) class QuantInfo(WithScale, WithPrecision, Sampling): - requant_ops: typing.Dict[DiscreteInfo, Symbol] = field(repr=False) + requant_ops: typing.Dict[DiscreteInfo, Symbol] = field(repr=False, default_factory=dict) + + def from_symbol(self, sym: Symbol) -> typing.Self: + return type(self)(sym, self.params) @classmethod def default_dict(cls, **kwargs) -> dict: @@ -62,7 +65,7 @@ def rescale(self, info: DiscreteInfo): """ scale, precision = info.scale, info.precision if info.undefined: - return self + return self.graph elif scale is not None: precision = self.scale_to_precision(scale) elif precision is not None: @@ -72,11 +75,10 @@ def rescale(self, info: 
DiscreteInfo): curr_scale = self.scale if self.scale_defined else 1 #TODO: add pass to check rescale=1 and duplicate requant out = op.requant( - self, + self.graph, rescale=scale/curr_scale, precision=precision, - ) - out = out.like(self) + ).like(self.graph) out.set_extra_attrs( data=self.data, scale=scale, precision=precision) self.requant_ops[info] = out @@ -137,12 +139,13 @@ def _rule(s: QuantInfo): register_rules_with_default(SUM, requant_rule=args_max_prec(10)) def uniform_args_scale(args: typing.List[QuantInfo], + params: ParametersT = {}, std_prec: int =15): # standard max precision for add/sub children. assert len(args) > 0 # raw_print(s) - assert any([c.is_operator() for c in args]), \ + assert any([op.is_operator(c.graph, params) for c in args]), \ "Need fuse constant for uniform_args_scale" scales = [] for arg in args: @@ -173,27 +176,28 @@ def uniform_args_scale(args: typing.List[QuantInfo], # scale = min(scaleA, scaleB) # return [DiscreteInfo(scale=scale) for c in s.args] def scale_like_index(s: WithScale, index: int = 0): - return s.args[index].scale + return s.args[index].extra_attrs.get("scale", -1) + register_rules_with_default( ADD, SUB, # BIAS_ADD, MAXIMUM, MINIMUM, - requant_rule=lambda s: uniform_args_scale(s.args), + requant_rule=lambda s: uniform_args_scale([s.from_symbol(a) for a in s.args], s.params), scale_rule=scale_like_index) def scale_concat(s: WithScale): - fscale = s.args[0].scale - if all([a.scale == fscale for a in s.args]): + fscale = s.args[0].extra_attrs.get("scale", -1) + if all([a.extra_attrs.get("scale", -1) == fscale for a in s.args]): return fscale - return [a.scale for a in s.args] + return [a.extra_attrs.get("scale", -1) for a in s.args] register_rules_with_default( CONCAT, TUPLE, - requant_rule=lambda s: uniform_args_scale(s.args), + requant_rule=lambda s: uniform_args_scale([s.from_symbol(a) for a in s.args], s.params), scale_rule=scale_concat) def uniform_first_scale(s: QuantInfo): - target_scale = s.args[0].scale + target_scale = s.args[0].extra_attrs.get("scale", -1) return [DiscreteInfo(scale=target_scale) for c in s.args] register_rules_with_default( @@ -202,7 +206,7 @@ def uniform_first_scale(s: QuantInfo): # register_rules_with_default( # WHERE, -# requant_rule=lambda s: uniform_args_scale(s.args[1:]), +# requant_rule=lambda s: uniform_args_scale([s.from_symbol(a) for a in s.args[1:]], s.params), # scale_rule=scale_like_index(s, -1), # ) @@ -217,7 +221,7 @@ def uniform_first_scale(s: QuantInfo): register_rules_with_default(NEGATIVE) def scale_tuple_get_item(s: WithScale): - ascale = s.args[0].scale + ascale = s.args[0].extra_attrs.get("scale", -1) if isinstance(ascale, (list, tuple)): return ascale[s.parsed.index] return ascale @@ -226,7 +230,7 @@ def scale_tuple_get_item(s: WithScale): scale_rule=scale_tuple_get_item) def op_clip_rules(s: QuantInfo): - scale = s.args[0].scale + scale = s.args[0].extra_attrs.get("scale", -1) s.set_extra_attrs( a_min=s.parsed.a_min * scale, a_max=s.parsed.a_max * scale) @@ -257,7 +261,7 @@ def op_lut_rules(s: QuantInfo): # if s.is_op(EXP): # arg_max = min(math.log(s.data), arg_max) - op_inp = np.arange(-alpha, alpha+1) / s.args[0].scale + op_inp = np.arange(-alpha, alpha+1) / s.args[0].extra_attrs.get("scale", -1) table = inference.run(s, [ tvm.nd.array(op_inp), ]) table = np.clip(table.numpy(), a_min=-s.data, a_max=s.data) # table = np.reshape(table, (-1, 1)) @@ -280,8 +284,8 @@ def softmax_scale_rules(s: QuantInfo): def op_softmax_rules(s: QuantInfo): lambd = 10 X = s.args[0] # get requant rule op - Xp 
= X.attrs["precision"] - Xs = X.scale #X.attrs["precision"] + Xp = X.extra_attrs["precision"] + Xs = X.extra_attrs["scale"] #X.attrs["precision"] axis = s.attrs["axis"] alpha = int(lambd * Xs) var = s.from_np_data(np.array(alpha, "int")) @@ -359,9 +363,9 @@ class Discretor(QuantInfo): """ def __call__(self, **kw): if self.is_variable(): - return + return self.graph elif self.is_op(TUPLE): - return + return self.graph orig_names = [a.name for a in self.args] @@ -378,14 +382,14 @@ def __call__(self, **kw): # requant input to specific precision arg_dts = _DISCRETE_REQUANT_RULES[self.op_name](self) for i, arg in enumerate(self.args): - self.args[i] = arg.rescale(arg_dts[i]) + self.args[i] = self.from_symbol(arg).rescale(arg_dts[i]) # calculate the F function - out = _DISCRETE_OP_RULES[self.op_name](self).like( - self, extra_attrs=self.extra_attrs) + out = _DISCRETE_OP_RULES[self.op_name](self.graph).like( + self.graph, extra_attrs=self.extra_attrs) # calculate the output data's scale - out.scale = INFER_SCALE_RULES[self.op_name](out) + out.set_extra_attrs(scale = INFER_SCALE_RULES[self.op_name](out)) new = op.subgraph(out, inames=[a.name for a in self.args]) # self.is_op(EXP) and raw_print(new) # out.scale = infer_scale(new) @@ -397,7 +401,8 @@ def __call__(self, **kw): # out = op.pclip(out, precision=target_precision).like( # out, extra_attrs=out.extra_attrs) # out.precision = target_precision - out.precision = self.scale_to_precision(out.scale) + out.set_extra_attrs(precision = self.scale_to_precision(out.extra_attrs.get("scale", -1))) + # TODO: add skip for some operators # same_scale = all([a.scale == out.scale for a in self.args]) diff --git a/python/mrt/quantization/fixed_point.py b/python/mrt/quantization/fixed_point.py index 087798a..20291f4 100644 --- a/python/mrt/quantization/fixed_point.py +++ b/python/mrt/quantization/fixed_point.py @@ -65,7 +65,7 @@ def map_int_requant(self): precision, which follows precision max bit limit. 
""" - X: FixPoint = self.args[0] + X: FixPoint = self.from_symbol(self.args[0]) rescale = self.parsed.rescale anno_bit = WithPrecision.MAX_BIT // 2 @@ -82,29 +82,31 @@ def map_int_requant(self): if X.precision > anno_bit: # recalculate exp + exp = exp + (X.precision - anno_bit) rs_bit = X.from_const_data(X.precision - anno_bit) - X = op.right_shift(X, rs_bit).like(self) + X_op = op.right_shift(X.graph, rs_bit).like(self.graph) + X = self.from_symbol(X_op) X.precision = anno_bit assert frac >= 1 assert exp <= 0 frac_sym = X.from_const_data(frac) - out = op.mul(X, frac_sym).like(self) + out = op.mul(X.graph, frac_sym).like(self.graph) - exp_sym = out.from_const_data(-exp) + exp_sym = self.from_symbol(out).from_const_data(-exp) if ExporterConfig.G().use_clip: if ExporterConfig.G().use_pclip: out = op.rs_pclip(out, exp_sym, precision=self.precision) else: pos = self.int_max() - out = op.right_shift(out, exp_sym).like(self) - out = op.clip(out, min=-pos, max=pos).like(self) + out = op.right_shift(out, exp_sym).like(self.graph) + out = op.clip(out, min=-pos, max=pos).like(self.graph) else: - out = op.right_shift(out, exp_sym).like(self) + out = op.right_shift(out, exp_sym).like(self.graph) return out def process(self): @@ -114,7 +116,7 @@ def process(self): if G.use_int_dtype: G.use_round = True - out = self + out = self.graph if self.is_param() and G.use_round: data = np.round(self.numpy()) if G.use_int_dtype: @@ -123,7 +125,7 @@ def process(self): pos = self.int_max() if self.is_op(REQUANT): - if G.use_int_requant and (not self.args[0].is_input()): + if G.use_int_requant and (not self.from_symbol(self.args[0]).is_input()): out = self.map_int_requant() else: # use float multipy to map requant rescale = self.parsed.rescale @@ -154,10 +156,10 @@ def process(self): def __call__(self, **kw): if not self.precision_defined: logger.warning(f"symbol: {self.name} is ignored without precision defined.") - return self + return self.graph self.validate_precision() - out = self.process().like(self, extra_attrs=self.extra_attrs) + out = self.process().like(self.graph, extra_attrs=self.extra_attrs) # TODO: add precision int max validate # if self.is_param(): # absmax = np.abs(out.numpy()).max() @@ -166,7 +168,7 @@ def __call__(self, **kw): @dataclass(repr=False) class Simulator(QuantInfo): - def round(self, out: Transformer): + def round(self, out: Symbol): # data_0_5 = self.from_const_data(0.5) # out = op.add(out, data_0_5) # out = op.ceil(out) @@ -176,7 +178,7 @@ def round(self, out: Transformer): return out def __call__(self, with_clip=False, with_round=False, **kw): - out: Transformer = self + out: Symbol = self.graph if self.is_input(): """ input is the original float data, skip. 
""" return out @@ -198,7 +200,7 @@ def __call__(self, with_clip=False, with_round=False, **kw): out = op.clip(out, min=-pos, max=pos) # print(out) # sys.exit() - return out.like(self) + return out.like(self.graph) @dataclass(repr=False) diff --git a/python/mrt/quantization/fuse.py b/python/mrt/quantization/fuse.py index 31360a5..364f7a9 100644 --- a/python/mrt/quantization/fuse.py +++ b/python/mrt/quantization/fuse.py @@ -29,32 +29,32 @@ def np_is_zero(self, data) -> float: return np.abs(data).max() < self.threshold def __call__(self: Transformer, **kw): - if self.is_operator() and all([c.is_param() for c in self.args]): + if self.is_operator() and all([self.from_symbol(c).is_param() for c in self.args]): data = inference.run_single( - self, [a.numpy() for a in self.args]) + self.graph, [self.from_symbol(a).numpy() for a in self.args]) return self.as_parameter(data) elif self.is_op(ADD, SUB): # , BIAS_ADD): strips = [] for arg in self.args: - if arg.is_param() and self.np_is_zero(arg.numpy()): + if self.from_symbol(arg).is_param() and self.np_is_zero(self.from_symbol(arg).numpy()): # if arg.is_param() and np.abs(arg.numpy()).max() == 0: strips.append(arg) args = [a for a in self.args if a not in strips] if len(args) == 1: return args[0] elif self.is_op(SLICE_LIKE): - if not self.args[0].is_param(): + if not self.from_symbol(self.args[0]).is_param(): return a, b = self.args arg1 = np.zeros(b.shape, b.dtype) data = inference.run_single( - self, [a.numpy(), np.zeros(b.shape, b.dtype)]) + self.graph, [self.from_symbol(a).numpy(), np.zeros(b.shape, b.dtype)]) return self.as_parameter(data) elif self.is_op(REQUANT): if self.parsed.rescale == 1: return self.args[0] elif self.is_op(ZEROS_LIKE, ONES_LIKE): - data = inference.run_single(self, []) + data = inference.run_single(self.graph, []) return self.as_parameter(data) @@ -62,10 +62,11 @@ class FuseBatchNorm(Transformer): @filter_operators(BATCH_NORM) def __call__(self, **kw): X, gamma, beta, mean, var = self.args + X = self.from_symbol(X) parsed: BatchNormAttrs = self.parsed - gamma, beta = gamma.numpy(), beta.numpy() - mean, var = mean.numpy(), var.numpy() + gamma, beta = self.from_symbol(gamma).numpy(), self.from_symbol(beta).numpy() + mean, var = self.from_symbol(mean).numpy(), self.from_symbol(var).numpy() # print(gamma.shape, beta.shape, mean.shape, var.shape) assert parsed.axis == 1 @@ -90,8 +91,8 @@ def __call__(self, **kw): # (A * W) * gamma + bias # A * (W * gamma) + bias - W_data = W.numpy() * gamma.reshape(K, 1, 1, 1) - W_sym = W.from_np_data(W_data) + W_data = self.from_symbol(W).numpy() * gamma.reshape(K, 1, 1, 1) + W_sym = self.from_symbol(W).from_np_data(W_data) out = op.nn_conv2d(A, W_sym, **X.attrs) elif X.is_op(DENSE): A, W = X.args @@ -99,27 +100,27 @@ def __call__(self, **kw): # (A * W) * gamma + bias # A * (W * gamma) + bias - W_data = W.numpy() * gamma.reshape(K, 1) - W_sym = W.from_np_data(W_data) + W_data = self.from_symbol(W).numpy() * gamma.reshape(K, 1) + W_sym = self.from_symbol(W).from_np_data(W_data) out = op.nn_dense(A, W_sym, **X.attrs) else: reshp = [s if i == parsed.axis else 1 \ for i, s in enumerate(X.shape)] W = X.from_np_data(gamma.reshape(reshp)) - out = op.mul(X, W) + out = op.mul(X.graph, W) bias = bias.reshape([s if i == parsed.axis else 1 \ for i, s in enumerate(out.shape)]) - B = out.like(self).from_np_data(bias) + B = self.from_symbol(out.like(self.graph)).from_np_data(bias) out = op.add(out, B) # out = op.bias_add(out, B, axis=parsed.axis) - return out.like(self) + return out.like(self.graph) class 
FuseTupleGetItem(Transformer): @filter_operators(TUPLE_GET_ITEM) def __call__(self, **kw): X: Symbol = self.args[0] - if X.is_op(BATCH_NORM, DROP_OUT): + if self.from_symbol(X).is_op(BATCH_NORM, DROP_OUT): return X # assert X.is_op(BATCH_NORM, DROP_OUT), X.name # assert self.parsed.index == 0 @@ -133,7 +134,7 @@ def __call__(self, **kw): @filter_operators(AVG_POOL2D) def _fuse_avg_pool2d(self): - X: Transformer = self.args[0] + X: Symbol = self.args[0] parsed: AvgPool2DAttrs = self.parsed assert parsed.layout == "NCHW" # TODO: ignore for unstrict mode @@ -148,15 +149,15 @@ def _fuse_avg_pool2d(self): "channels": X.shape[1], } W_shape = (X.shape[1], 1, *parsed.pool_size) - W = X.from_np_data(np.full( + W = self.from_symbol(X).from_np_data(np.full( W_shape, 1 / product(parsed.pool_size))) out = op.nn_conv2d(X, W, **attrs) - return out.like(self) + return out.like(self.graph) @filter_operators(ADAPTIVE_AVG_POOL2D) def _fuse_adaptive_avg_pool2d(self): - X: Transformer = self.args[0] + X: Symbol = self.args[0] parsed: AdaptiveAvgPool2DAttrs = self.parsed assert parsed.layout == "NCHW" ins = X.shape[2:] @@ -170,12 +171,12 @@ def _fuse_adaptive_avg_pool2d(self): scale = np.array(1 / np.prod(X.shape[-2:])) out = op.sum(X, axis=list(range(4))[-2:], keepdims=True) scale = self.from_np_data(scale.astype(X.dtype)) - return op.mul(out, scale).like(self) + return op.mul(out, scale).like(self.graph) elif ous[0] > ins[0] or ous[1] > ins[1]: assert all([s == 1 for s in ins]) out = op.repeat(X, repeats=ous[0], axis=-2) out = op.repeat(out, repeats=ous[1], axis=-1) - return out.like(self) + return out.like(self.graph) # calculate the attributes refers to: # https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work @@ -191,23 +192,23 @@ def _fuse_adaptive_avg_pool2d(self): "channels": X.shape[1], } W_shape = (X.shape[1], 1, *kernel) - W = X.from_np_data(np.full(W_shape, 1 / product(kernel))) + W = self.from_symbol(X).from_np_data(np.full(W_shape, 1 / product(kernel))) out = op.nn_conv2d(X, W, **attrs) - return out.like(self) + return out.like(self.graph) class FuseNaiveSoftmax(Transformer): def __call__(self, **kw): - return self # not fuse pass + return self.graph # not fuse pass if self.is_op(SOFTMAX, LOG_SOFTMAX): return self.args[0] - assert self.is_variable() or not self.args[0].is_op(SOFTMAX, LOG_SOFTMAX) - return self + assert self.is_variable() or not self.from_symbol(self.args[0]).is_op(SOFTMAX, LOG_SOFTMAX) + return self.graph class FuseMean(Transformer): @filter_operators(MEAN) def __call__(self, **kw): - X: Transformer = self.args[0] + X: Symbol = self.args[0] # max_axis = len(X.shape) # axis = X.attrs.get("axis", None) # axis = axis or [i for i in range(max_axis)] @@ -221,7 +222,7 @@ def __call__(self, **kw): scale = self.from_np_data(np.array( 1. * product(out.shape) / product(X.shape))) out = op.mul(out, scale) - return out.like(self) + return out.like(self.graph) class FuseLeakyReLU(Transformer): @filter_operators(LEAKY_RELU) @@ -234,22 +235,22 @@ def __call__(self, **kw): LeakyReLU(X) = relu(X) - slope*relu(-X) """ alpha = self.from_const_data(self.parsed.alpha) - X: Transformer = self.args[0] + X: Symbol = self.args[0] out = op.nn_relu(op.negative(X)) out = op.mul(alpha, out) out = op.sub(op.nn_relu(X), out) - return out.like(self) + return out.like(self.graph) class FuseDivide(Transformer): @filter_operators(DIV) def __call__(self, **kw): """ Transform div to mul if possible. 
""" - A: Transformer = self.args[0] - B: Transformer = self.args[1] - assert B.is_param(), B - B = B.from_np_data(1. / B.numpy()) - return op.mul(A, B).like(self) + A: Symbol = self.args[0] + B: Symbol = self.args[1] + assert self.from_symbol(B).is_param(), B + B = self.from_symbol(B).from_np_data(1. / self.from_symbol(B).numpy()) + return op.mul(A, B).like(self.graph) # move to fuse constant # class FuseNaiveMathmatic(Transformer): diff --git a/python/mrt/quantization/precision.py b/python/mrt/quantization/precision.py index b933b6a..faf074d 100644 --- a/python/mrt/quantization/precision.py +++ b/python/mrt/quantization/precision.py @@ -14,14 +14,16 @@ number_to_bits, count_to_bits, bits_to_number from mrt.common.types import ParametersT -from .transform import Transformer +from .transform import SymbolBridge, Transformer __ALL__ = [ "WithPrecision", "InferPrecision", "QuantizedInfo", ] @dataclass(repr=False) -class WithPrecision(Symbol): +#class WithPrecision(Symbol): +class WithPrecision(SymbolBridge): + #class WithPrecision(Transformer): MAX_BIT: typing.ClassVar[int] = 32 @classmethod @@ -106,13 +108,13 @@ def _add_rules(f: RulesFuncT): return f return _add_rules -_infer_mul: RulesFuncT = lambda s: sum([c.precision for c in s.args[:2]]) +_infer_mul: RulesFuncT = lambda s: sum([c.extra_attrs.get("precision", -1) for c in s.args[:2]]) """ conv2d may has 3-args, use prefix-2. """ -_infer_max: RulesFuncT = lambda s: max([c.precision for c in s.args]) +_infer_max: RulesFuncT = lambda s: max([c.extra_attrs.get("precision", -1) for c in s.args]) def _infer_index(s: WithPrecision, index: int): - return s.args[index].precision + return s.args[index].extra_attrs.get("precision", -1) prec_rules(TUPLE)(_infer_max) prec_rules(MAX_AXIS)(_infer_max) @@ -166,7 +168,7 @@ def _infer_right_shift(s: WithPrecision): A, B = s.args[0], s.args[1] assert B.is_param() b_prec = InferPrecision.bind(B) - return A.precision - b_prec + return A.extra_attrs.get("precision", -1) - b_prec @prec_rules(REQUANT, PCLIP, RS_PCLIP) def _infer_attr_prec(s: WithPrecision): @@ -178,7 +180,7 @@ class PrecisionRevisor(WithPrecision, Transformer): def __call__(self, **kw): out = self if out.is_input(): - return + return out.graph elif out.is_op(REQUANT, PCLIP): assert out.precision == out.parsed.precision, f"{out.name} out_prec:{out.precision}, out_parsed_prec:{out.parsed.precision}" elif out.is_op(RS_PCLIP): @@ -202,12 +204,12 @@ def __call__(self, **kw): # print("infered prec:", oprec) if out.precision_defined and oprec > out.precision: out.precision, oprec = oprec, out.precision - out = op.pclip(out, precision=oprec).like( - out, extra_attrs=out.extra_attrs) + out = out.from_symbol(op.pclip(out.graph, precision=oprec).like( + out.graph, extra_attrs=out.extra_attrs)) out.precision = oprec out.validate_precision() - return out + return out.graph # def cvm_infer_single_precision( # symbol: WithPrecision, params: ParametersT) -> int: diff --git a/python/mrt/quantization/scaler.py b/python/mrt/quantization/scaler.py index 8cd058d..0ad2f76 100644 --- a/python/mrt/quantization/scaler.py +++ b/python/mrt/quantization/scaler.py @@ -7,8 +7,11 @@ from mrt.mir.opns import * from mrt.mir.symbol import * +from .transform import SymbolBridge + @dataclass(repr=False) -class WithScale(Symbol): +#class WithScale(Symbol): +class WithScale(SymbolBridge): @classmethod def _validate_scale(cls, scale, msg=None): if isinstance(scale, (list, tuple)): @@ -55,13 +58,13 @@ def _add_rules(f: ScaleRulesT): return _add_rules def scale_index(s: WithScale, 
index: int): - return s.args[index].scale + return s.args[index].extra_attrs.get("scale", -1) def scale_nn(s: WithScale): - return s.args[0].scale * s.args[1].scale + return s.args[0].extra_attrs.get("scale", -1) * s.args[1].extra_attrs.get("scale", -1) def scale_identity(s: WithScale): - return s.args[0].scale + return s.args[0].extra_attrs.get("scale", -1) def infer_scale(symbol: WithScale): def _infer(sym: Symbol): diff --git a/python/mrt/quantization/segement.py b/python/mrt/quantization/segement.py index 38d3f5b..c49b90d 100644 --- a/python/mrt/quantization/segement.py +++ b/python/mrt/quantization/segement.py @@ -34,16 +34,16 @@ def __call__(self, **kwargs): refs = { self.name: 1 } # add refs for root symbol def _collect_refs(sym: Spliter): refs.setdefault(sym.name, 0) - if sym.is_variable(): + if self.from_symbol(sym).is_variable(): return for a in sym.args: refs.setdefault(a.name, 0) refs[a.name] += 1 - visit(self, _collect_refs) + visit(self.graph, _collect_refs) sym_map = {} sym_status = {} - heads = [self] + heads = [self.graph] """ status code: 1 means current symbol has been scaned and sub childs have been added into scan list. @@ -102,7 +102,7 @@ def _collect_refs(sym: Spliter): def _split(sym: Spliter): return op.as_variable(sym) \ if sym.name in self.seg_names else sym - head = transform(self, _split) + head = transform(self.graph, _split) self.head = dump_json(head) self.head_params = {} @@ -114,21 +114,30 @@ def _update_params(sym: Symbol): # helper.format_print(head, self.head_params) - return op.Tuple(*outs).like(self) + kwargs['pointer']["seg_names"] = self.seg_names + kwargs['pointer']["head"] = self.head + kwargs['pointer']["head_params"] = self.head_params + + return op.Tuple(*outs).like(self.graph) @dataclass(repr=False) class Merger(WithScale, RunOnce): - def __call__(self, spliter: Spliter, **kw): + def __call__(self, spliter: Symbol, **kw): assert self.op_name == opns.TUPLE - tail_outs = dict(zip(spliter.seg_names, self.args)) + + head = kw['pointer']["head"] + head_params = kw['pointer']["head_params"] + seg_names = kw['pointer']["seg_names"] + + tail_outs = dict(zip(seg_names, self.args)) # print(spliter.seg_names) - assert spliter.head is not None + assert head is not None head_params = {k: to_ndarray(v) \ - for k, v in spliter.head_params.items()} + for k, v in head_params.items()} # head_params.update(self.params) - head = load_json(spliter.head, params=head_params) + head = load_json(head, params=head_params) # helper.format_print(head, head_params) @@ -139,6 +148,7 @@ def _merge(sym: Symbol): return sym out = transform(head, _merge) - return out.like(self, params={ **head_params, **self.params }) + self.params = { **head_params, **self.params } + return out.like(self.graph) diff --git a/python/mrt/quantization/transform.py b/python/mrt/quantization/transform.py index 8fa7e47..fb15319 100644 --- a/python/mrt/quantization/transform.py +++ b/python/mrt/quantization/transform.py @@ -14,8 +14,69 @@ from mrt.common.utils import N @dataclass(repr=False) -class WithParameters(Symbol): - parsed: _BaseAttrs = field(repr=False) +class SymbolBridge: # SymbolManipulator / Pass + graph: Symbol + + def __init__(self, symbol: Symbol): + self.graph = symbol + + @classmethod + def base(cls, symbol: Symbol): + return cls(symbol) + + def __repr__(self, **attrs): + return self.graph.__repr__(**attrs) + + def from_symbol(self, sym: Symbol) -> typing.Self: + return type(self)(sym) + + @property + def parsed(self)-> _BaseAttrs: + return parse_attrs(self.graph.op_name, 
self.graph.attrs) +        # return self.graph.attrs + +    """Member Symbol Start +    """ +    def is_op(self, *op_names) -> bool: +        """ Check current symbol is in the op name list. """ +        assert len(op_names) > 0 +        return self.graph.op_name in op_names +    def is_near(self, *names, check_args: bool = True) -> bool: +        return self.graph.is_near(*names, check_args=check_args) +    def to_dict(self): +        return self.graph.to_dict() +    @classmethod +    def from_dict(cls, d: dict, **kwargs) -> WithParameters: +        return cls(Symbol.from_dict(d, **kwargs), {}) +    @property +    def args(self): +        return self.graph.args +    @property +    def op_name(self): +        return self.graph.op_name +    @property +    def name(self): +        return self.graph.name +    @property +    def shape(self): +        return self.graph.shape +    @property +    def dtype(self): +        return self.graph.dtype +    @property +    def attrs(self): +        return self.graph.attrs +    @property +    def extra_attrs(self): +        return self.graph.extra_attrs +    def set_extra_attrs(self, **kwargs): +        return self.graph.extra_attrs.update(kwargs) +    """Member Symbol End +    """ + +@dataclass(repr=False) +class WithParameters(SymbolBridge): # SymbolManipulator / Pass +    graph: Symbol     params: ParametersT = field(repr=False)     """ Parameters should not be changed in transformer, use copy mode instead to avoid possible errors. @@ -23,31 +84,34 @@ class WithParameters(Symbol):         deep copy params in trace `checkpoint_run` api.     """ +    def __init__(self, symbol: Symbol, params: ParametersT): +        self.graph = symbol +        self.params = params +     @classmethod -    def update_dict(cls, data_dict: dict, **kwargs) -> dict: -        data_dict.update(kwargs) -        parsed = parse_attrs( -            data_dict["op_name"], data_dict["attrs"]) -        return super().update_dict(data_dict, parsed=parsed) +    def base(cls, symbol: Symbol, params: ParametersT): +        return cls(symbol, params)      def __repr__(self, **attrs):         if self.is_param():             attrs["absmax"] = np.abs(self.numpy()).max(initial=0)         return super().__repr__(**attrs)  -    def ndarray(self) -> OpOutputT: -        return to_ndarray(self.numpy()) +    @property +    def parsed(self)-> _BaseAttrs: +        return parse_attrs(self.graph.op_name, self.graph.attrs) +        # attrs = self.graph.attrs +        # return attrs +     def numpy(self) -> OpNumpyT: -        assert self.is_param(), f"{self.name} is not parameter." -        data = self.params[self.name] +        assert self.is_param(), f"{self.graph.name} is not parameter." 
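A minimal usage sketch of the SymbolBridge wrapper introduced above (illustrative only, not part of this patch; `sym` stands for any pre-built Symbol): the bridge keeps the Symbol in `.graph`, proxies its read-only members, and re-wraps related symbols through `from_symbol` so a pass keeps its own wrapper type.

    from mrt.quantization.transform import SymbolBridge

    bridge = SymbolBridge(sym)              # wrap without copying: bridge.graph is sym
    print(bridge.op_name, bridge.attrs)     # proxied through bridge.graph
    bridge.set_extra_attrs(scale=1.0)       # updates sym.extra_attrs in place
    arg0 = bridge.from_symbol(sym.args[0])  # same wrapper type around an argument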
+ data = self.params[self.graph.name] assert isinstance(data, (tuple, list, np.ndarray)), \ - f"param:{self.name} not OpNumpyT, get {type(data)}" + f"param:{self.graph.name} not OpNumpyT, get {type(data)}" return data - return to_numpy(self.ndarray()) - - def as_parameter(self, data: OpNumpyT): + def as_parameter(self, data: OpNumpyT) -> Symbol: def _f(data, dtype): if isinstance(data, list): assert len(data) == len(dtype) @@ -55,27 +119,44 @@ def _f(data, dtype): assert isinstance(data, np.ndarray), type(data) return data.astype(dtype) - self.params[self.name] = _f(data, self.dtype) - return op.as_variable(self) + self.params[self.graph.name] = _f(data, self.graph.dtype) + return op.as_variable(self.graph) - def from_const_data(self, data: typing.Union[int, float]) -> WithParameters: + def from_const_data(self, data: typing.Union[int, float]) -> Symbol: return self.from_np_data(data) - def from_np_data(self, data: np.ndarray, prefix="%") -> Symbol: + def from_symbol(self, sym: Symbol) -> typing.Type[WithParameters]: #TODO + return type(self)(sym, self.params) + + def from_np_data(self, data: np.ndarray | typing.Union[int, float], prefix="%") -> Symbol: + """ out = Return Symbol + out = op.add(out, B) + self: WithParameter + self.graph: Symbol + self.from_symbol(out).from_np_data() + + out = Return WithParameter + out.from_np_data() + + op.add(out.graph, B) + + graph: Symbol + """ name = N.n(prefix=prefix) # some data is np.float/int type, use np.array to wrap it. data = np.array(data) - self.params[name] = data.astype(self.dtype) - return op.variable(name, data.shape, self.dtype).like(self) + self.params[name] = data.astype(self.graph.dtype) + ## return type(self). # Mark! + return op.variable(name, data.shape, self.graph.dtype).like(self.graph) def is_input(self) -> bool: - return op.is_input(self, self.params) + return op.is_input(self.graph, self.params) def is_param(self) -> bool: - return op.is_param(self, self.params) + return op.is_param(self.graph, self.params) def is_variable(self) -> bool: - return op.is_variable(self, self.params) + return op.is_variable(self.graph, self.params) def is_operator(self) -> bool: - return op.is_operator(self, self.params) + return op.is_operator(self.graph, self.params) TransformerT = typing.Callable[[Graph], Graph] """ Transformer Callback Function Type, @@ -87,16 +168,9 @@ class Transformer(WithParameters): """ Symbol Transformer """ RUN_ONCE: typing.ClassVar[bool] =False - """ whether to run callback once? """ - # def to_dict(self, **kwargs): - # """ override to dict, since transformer may want to - # access the previous tfm. Thus, the next - # update_dict has the `origin` key by default. - # """ - # data = super().to_dict(**kwargs) - # data["extra_attrs"]["origin"] = self - # return data + def __init__(self, *args): + super().__init__(*args) @classmethod def get_transformer(cls, name: typing.Optional[str] = None): @@ -106,9 +180,9 @@ def _run(sym: Symbol): # use current cls to apply transform, this # may loss some information from origin # symbol, so record as `origin` in call. 
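A companion sketch for the params-aware helpers of WithParameters shown above (illustrative only; `sym` and `params` are assumed placeholders). The wrapper resolves parameter data against `params` and hands plain Symbols back to the graph:

    import numpy as np
    from mrt.quantization.transform import WithParameters

    tfm = WithParameters(sym, params)
    if tfm.is_param():                          # name looked up in tfm.params
        data = tfm.numpy()
    w = tfm.from_np_data(np.ones((4, 4)))       # registers a fresh "%"-prefixed parameter, returns a Symbol
    v = tfm.as_parameter(np.zeros(tfm.shape))   # binds data under tfm.name, returns a variable Symbol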
- out = cls.base(sym, params=params) - out = out(origin=sym, **kwargs) or out - assert isinstance(out, cls), ( + out = cls.base(sym, params) # Type as Transformer + out = out(origin=sym, **kwargs) or sym # Type as Symbol + assert isinstance(out, Symbol), ( "transform output type should be {}," " but get {}" ).format(cls, type(out)) @@ -148,3 +222,6 @@ def __call__(self, *args, **kw) -> typing.Optional[Transformer]: class RunOnce(Transformer): RUN_ONCE: typing.ClassVar[bool] = True + def __init__(self, *args): # symbol: Symbol, params: ParametersT):#, parsed: _BaseAttrs=None): + super().__init__(*args) + diff --git a/python/mrt/runtime/inference.py b/python/mrt/runtime/inference.py index f520018..ec62979 100644 --- a/python/mrt/runtime/inference.py +++ b/python/mrt/runtime/inference.py @@ -16,10 +16,10 @@ def run_single( sym = op.retrieve_operator(sym) if sym.is_op(TUPLE_GET_ITEM): - return args_data[0][sym.parsed.index] + return args_data[0][sym.attrs['index']] elif sym.is_op(REQUANT): # it's type is np.float32/64, use np.array to wrap. - return np.array(sym.parsed.rescale * args_data[0]) + return np.array(sym.attrs['rescale'] * args_data[0]) elif sym.is_op(ARANGE): args = [a.numpy().item() for a in args_data] return np.arange(*args, **sym.attrs) From 1b83b76672683b7ce854f952fa37badcf6d902ea Mon Sep 17 00:00:00 2001 From: corlfj Date: Thu, 27 Nov 2025 16:21:02 +0800 Subject: [PATCH 09/12] [mir]: Pass for resnet18 --- python/mrt/api.py | 10 +- python/mrt/frontend/api.py | 1 + python/mrt/frontend/expr.py | 14 +- python/mrt/frontend/pytorch/converter.py | 25 +- python/mrt/frontend/pytorch/vm.py | 1 + python/mrt/mir/mhsymbol.py | 36 ++ python/mrt/mir/op.py | 129 ++--- python/mrt/mir/opclass.py | 657 ++++++++--------------- python/mrt/mir/opns.py | 2 +- python/mrt/mir/simple_pass.py | 345 ------------ python/mrt/mir/symbol.py | 78 +-- python/mrt/quantization/calibrate.py | 28 +- python/mrt/quantization/discrete.py | 49 +- python/mrt/quantization/fixed_point.py | 50 +- python/mrt/quantization/fuse.py | 43 +- python/mrt/quantization/precision.py | 7 +- python/mrt/quantization/segement.py | 4 +- python/mrt/quantization/transform.py | 5 +- python/mrt/runtime/inference.py | 1 + tests/mir/test.op_create.py | 6 + tests/test.pytorch.py | 7 +- 21 files changed, 520 insertions(+), 978 deletions(-) create mode 100644 python/mrt/mir/mhsymbol.py delete mode 100644 python/mrt/mir/simple_pass.py diff --git a/python/mrt/api.py b/python/mrt/api.py index 34d0355..b282f6f 100644 --- a/python/mrt/api.py +++ b/python/mrt/api.py @@ -12,7 +12,7 @@ from .runtime.analysis import * from .mir import op, helper -# from .mir.model import MultiHeadSymbol +from .mir.mhsymbol import MultiHeadSymbol from .mir.symbol import * from .dataset.base import Dataset @@ -253,6 +253,7 @@ def fuse(self, **kwargs) -> Trace: fuse.FuseDropout.get_transformer(), fuse.FuseMean.get_transformer(), fuse.FuseNaiveSoftmax.get_transformer(), + fuse.FuseIdentity.get_transformer(), fuse.FuseConstant.get_transformer(), **kwargs, ) @@ -260,23 +261,18 @@ def fuse(self, **kwargs) -> Trace: def calibrate(self, repeats: int = 1, **kwargs) -> Trace: assert self._dataset is not None tr_name = kwargs.pop("tr_name", "calibrate") - raw_data: typing.Dict[str, OpOutputT] = {} - out_data: typing.List[OpNumpyT] = [] out = self for i in range(repeats): - kwargs["raw_data"] = raw_data - kwargs["out_data"] = out_data data, _ = self._dataset.next() out = out.checkpoint_run( calib.Calibrator.get_transformer(), data = data, - # tr_name = tr_name, tr_name = 
f"{tr_name}_run_{i}", **kwargs) out = out.checkpoint_run( calib.SymmetricMinMaxSampling.get_transformer(), - tr_name = "%s_sampling" % tr_name, **{'origin_data': raw_data}) + tr_name = "%s_sampling" % tr_name) return out def quantize(self, **kwargs): diff --git a/python/mrt/frontend/api.py b/python/mrt/frontend/api.py index 0b55096..c9c9692 100644 --- a/python/mrt/frontend/api.py +++ b/python/mrt/frontend/api.py @@ -4,6 +4,7 @@ from functools import wraps from mrt.mir.symbol import * +from mrt.mir.mhsymbol import MultiHeadSymbol, Graph from mrt.common.types import * from mrt.common.config import MRTConfig diff --git a/python/mrt/frontend/expr.py b/python/mrt/frontend/expr.py index 9515db4..0341c58 100644 --- a/python/mrt/frontend/expr.py +++ b/python/mrt/frontend/expr.py @@ -14,6 +14,7 @@ from ..symbol import * from ..types import * from .. import op +from .. import opclass __ALL__ = [ "expr2symbol", "symbol2expr", "tvm_type_infer" ] @@ -62,7 +63,7 @@ def _cast_expr(node: RelayExpr): if isinstance(node, relay.expr.Constant): name = N.n("const_") params[name] = node.data - symbol_map[node] = op.variable(name, + symbol_map[node] = opclass.var(name, node.data.shape, node.data.dtype) return @@ -85,11 +86,11 @@ def _cast_expr(node: RelayExpr): if isinstance(node, relay.expr.Var): name = node.name_hint or N.n(prefix="input_") - symbol_map[node] = op.variable(name, shape, dtype) + symbol_map[node] = opclass.var(name, shape, dtype) elif isinstance(node, relay.expr.If): args = [ node.cond, node.true_branch, node.false_branch ] args = [symbol_map[i] for i in args] - symbol_map[node] = op._new_op(IF, *args, **attrs) + symbol_map[node] = opclass.extern_op_func(IF)(*args, **attrs) elif isinstance(node, relay.expr.Call): op_name = node.op.name if op_name in [CONCAT, ADV_INDEX]: @@ -108,15 +109,14 @@ def _cast_expr(node: RelayExpr): attrs.pop("dtype") elif op_name == GET_VALID_COUNT: attrs.pop("score_threshold") - symbol_map[node] = op._new_op(op_name, *args, **attrs) + symbol_map[node] = opclass.extern_op_func(op_name)(*args, **attrs) elif isinstance(node, relay.TupleGetItem): args = [ symbol_map[node.tuple_value], ] attrs['index'] = node.index - symbol_map[node] = op._new_op( - TUPLE_GET_ITEM, *args, **attrs) + symbol_map[node] = opclass.extern_op_func(TUPLE_GET_ITEM)(*args, **attrs) elif isinstance(node, relay.Tuple): args = [ symbol_map[f] for f in node.fields ] - symbol_map[node] = op._new_op(TUPLE, *args, **attrs) + symbol_map[node] = opclass.extern_op_func(TUPLE)(*args, **attrs) else: raise RuntimeError( "MRT not support expr type:{}".format(type(node))) diff --git a/python/mrt/frontend/pytorch/converter.py b/python/mrt/frontend/pytorch/converter.py index 2318fc6..0aa744d 100644 --- a/python/mrt/frontend/pytorch/converter.py +++ b/python/mrt/frontend/pytorch/converter.py @@ -9,8 +9,9 @@ import torch.nn.functional as F import sys -from mrt.mir.symbol import Symbol, MultiHeadSymbol, sym2list, transform -from mrt.mir import op +from mrt.mir.symbol import Symbol, sym2list, transform +from mrt.mir.mhsymbol import MultiHeadSymbol +from mrt.mir import op, opclass from mrt.mir.opns import * from mrt.common.types import ParametersT from mrt.common.utils import N @@ -46,7 +47,7 @@ class _T: "adaptive_avg_pool2d.default": _T(ADAPTIVE_AVG_POOL2D, 1, [ Attr("output_size", (1,1)) ]), "max_pool2d.default": _T(MAX_POOL2D, 1, [ - Attr("kernel_size", (1,1)), Attr("strides", (1,1)), Attr("padding", (0,0)) ]), + Attr("kernel_size", (1,1)), Attr("strides", (1,1)), Attr("padding", (0,0)), Attr("dilation", (1,1)), 
Attr("ceil_mode", False) ]), "mean.dim": _T(MEAN, 1, [ Attr("dim", None), Attr("keepdim", False) ]), "add.Tensor": _T(ADD, 2), "add_.Tensor": _T(ADD, 2), @@ -60,7 +61,7 @@ class _T: "cat.default": _T(CONCAT, 1, [ Attr("dim", 0) ]), "view.default": _T(RESHAPE, 1, [ Attr("shape", ()) ]), "transpose.int": _T(TRANSPOSE, 1, [ Attr("dim0", 0), Attr("dim1", 0) ]), - "contiguous.default": _T(PASS, 1), + "contiguous.default": _T(IDENTITY, 1), "chunk.default": _T(SPLIT, 1, [ Attr("chunks", 1), Attr("dim", 0) ]), "getitem": _T(TUPLE_GET_ITEM, 1, [ Attr("index", 0) ]), @@ -100,7 +101,7 @@ class _T: ), RESHAPE: torch.reshape, TRANSPOSE: torch.transpose, - PASS: lambda x: x, + IDENTITY: lambda x: x, SPLIT: torch.chunk, ADD: torch.add, @@ -156,7 +157,7 @@ def create_parameters(ep: torch.export.ExportedProgram): dshape = data_to_mrt(torch_shape) dtype = data_to_mrt(torch_dtype) - out = op.variable(name_hint, dshape, dtype) + out = opclass.var(name=name_hint, shape=dshape, dtype=dtype) params[name_hint] = to_bind_parameters[spec.target].detach().numpy().astype(dtype) assert dshape == list(params[name_hint].shape) # print(">> vars: ", out) @@ -207,7 +208,7 @@ def _retrieve_args(node): continue if node.name not in param_vars: # input - env[node] = op.variable(node.name, shape, dtype) + env[node] = opclass.var(name=node.name, shape=shape, dtype=dtype) else: env[node] = param_vars[node.name] elif node.op == "output": # [[ out1, out2, out3 ]] @@ -234,12 +235,16 @@ def _retrieve_args(node): if mapper.op_name == CONCAT: args = args[0] + if mapper.op_name == SPLIT: + shape = data_to_mrt([ t.shape for t in node.meta['val']]) + dtype = data_to_mrt([ t.dtype for t in node.meta['val']]) + if mapper.op_name == TUPLE_GET_ITEM and args[0].op_name == BATCH_NORM: out = args[0] else: - out = op._new_op( - mapper.op_name, *args, - name=node.name, extra_attrs={ "shape": shape, "dtype": dtype }, + out = Symbol(*args, + name=node.name, op_name=mapper.op_name, + extra_attrs={ "shape": shape, "dtype": dtype }, **attrs) env[node] = out else: diff --git a/python/mrt/frontend/pytorch/vm.py b/python/mrt/frontend/pytorch/vm.py index 31c44e8..bf353a2 100644 --- a/python/mrt/frontend/pytorch/vm.py +++ b/python/mrt/frontend/pytorch/vm.py @@ -6,6 +6,7 @@ from .types import * from mrt.mir.symbol import * +from mrt.mir.mhsymbol import MultiHeadSymbol from mrt.common.types import * Executor = namedtuple("Executor", ["vm", "device"]) diff --git a/python/mrt/mir/mhsymbol.py b/python/mrt/mir/mhsymbol.py new file mode 100644 index 0000000..bff35de --- /dev/null +++ b/python/mrt/mir/mhsymbol.py @@ -0,0 +1,36 @@ +import typing + +from mrt.common.utils import * +from mrt.common.types import * + +from . import opns, opclass, optype +from . import symbol + +#from mrt.mir.mhsymbol import MultiHeadSymbol, Graph +class MultiHeadSymbol(dict): + """ { "main": F(X) } """ + origin: typing.Optional[symbol.Symbol] = None + + @classmethod + def from_symbol(cls, symbol: symbol.Symbol, name: str = "main"): + return MultiHeadSymbol({ name: symbol }) + + def as_tuple(self) -> typing.Tuple[typing.List[str], symbol.Symbol]: + from . 
import op + # args = list(self.values()) + # sym_type = type(args[0]) if args else Symbol + mhs = self.origin or optype.infer_single(opclass.MRT_OP_MAP[opns.TUPLE](*list(self.values()))) + return list(self.keys()), mhs + + @classmethod + def from_tuple(cls, tuple_names, symbol): + assert symbol.is_op(opns.TUPLE), symbol + mhs = cls(zip(tuple_names, symbol.args)) + mhs.origin = symbol + return mhs + +Graph = typing.Union[symbol.Symbol, MultiHeadSymbol] +""" Notice that Symbol and MultiHeadSymbol can both + be regarded as a model Graph. +""" + diff --git a/python/mrt/mir/op.py b/python/mrt/mir/op.py index 1eab62b..84f9498 100644 --- a/python/mrt/mir/op.py +++ b/python/mrt/mir/op.py @@ -32,7 +32,8 @@ def variable(name, shape, dtype) -> Symbol: def as_variable(symbol: Symbol, shape=None, dtype=None) -> Symbol: """ inherit extra attrs """ - out = symbol.copy(op_name=VAR, args=[], attrs={}) + # out = symbol.copy(op_name=VAR, args=[], attrs={}) + out = symbol.as_variable() out.shape = shape or out.shape out.dtype = dtype or out.dtype return out @@ -40,69 +41,69 @@ def as_variable(symbol: Symbol, shape=None, dtype=None) -> Symbol: def retrieve_operator(symbol: Symbol) -> Symbol: return symbol.copy(args=[as_variable(c) for c in symbol.args]) -def _new_op(op_name, *args, extra_attrs=None, **attrs) -> Symbol: - name = attrs.pop("name", N.n()) - return Symbol.from_dict({}, - name=name, op_name=op_name, - args=args or [], attrs=attrs or {}, - extra_attrs=extra_attrs or {}) - -def _register_op(op_name): - def _op(*args, **attrs) -> Symbol: - op = _new_op(op_name, *args, **attrs) - from . import optype - out = optype.infer_single(op) - return out - return _op - -Tuple = _register_op(TUPLE) -TupleGetItem = _register_op(TUPLE_GET_ITEM) - -# class Conv2D(Symbol): -# strides: - -# TODO: define op function -# def conv2d(X, weight, bias, strides=(1,1)...): -# return Symbol(args=[X, weight, bias], -# attrs={ "strides": strides }) -nn_conv2d = _register_op(CONV2D) -nn_dense = _register_op(DENSE) -nn_batch_norm = _register_op(BATCH_NORM) -# bias_add = _register_op(BIAS_ADD) - -nn_relu = _register_op(RELU) - -sum = _register_op(SUM) -# mean = _register_op(MEAN) -clip = _register_op(CLIP) -ceil = _register_op(CEIL) -right_shift = _register_op(RIGHT_SHIFT) -# relax api from cast to astype -# astype = _register_op(AS_TYPE) -cast = _register_op(AS_TYPE) -# flatten = _register_op(FLATTEN) -adv_index = _register_op(ADV_INDEX) -zeros_like = _register_op(ZEROS_LIKE) - -repeat = _register_op(REPEAT) -reshape = _register_op(RESHAPE) - -add = _register_op(ADD) -sub = _register_op(SUB) -max_axis = _register_op(MAX_AXIS) -mul = _register_op(MUL) -div = _register_op(DIV) -matmul = _register_op(MATMUL) -exp = _register_op(EXP) -negative = _register_op(NEGATIVE) - -sigmoid = _register_op(SIGMOID) -softmax = _register_op(SOFTMAX) - -requant = _register_op(REQUANT) -pclip = _register_op(PCLIP) -rs_pclip = _register_op(RS_PCLIP) -lut = _register_op(LUT) +# def _new_op(*args, op_name='', extra_attrs=None, **attrs) -> Symbol: +# name = attrs.pop("name", N.n()) +# return Symbol(*args, +# name=name, op_name=op_name, +# extra_attrs=extra_attrs or {}, +# **attrs) +# +# def _register_op(op_name): +# def _op(*args, **attrs) -> Symbol: +# op = _new_op(*args, op_name=op_name, **attrs) +# from . 
import optype +# out = optype.infer_single(op) +# return out +# return _op +# +# Tuple = _register_op(TUPLE) +# TupleGetItem = _register_op(TUPLE_GET_ITEM) +# +# # class Conv2D(Symbol): +# # strides: +# +# # TODO: define op function +# # def conv2d(X, weight, bias, strides=(1,1)...): +# # return Symbol(args=[X, weight, bias], +# # attrs={ "strides": strides }) +# nn_conv2d = _register_op(CONV2D) +# nn_dense = _register_op(DENSE) +# nn_batch_norm = _register_op(BATCH_NORM) +# # bias_add = _register_op(BIAS_ADD) +# +# nn_relu = _register_op(RELU) +# +# sum = _register_op(SUM) +# # mean = _register_op(MEAN) +# clip = _register_op(CLIP) +# ceil = _register_op(CEIL) +# right_shift = _register_op(RIGHT_SHIFT) +# # relax api from cast to astype +# # astype = _register_op(AS_TYPE) +# cast = _register_op(AS_TYPE) +# # flatten = _register_op(FLATTEN) +# adv_index = _register_op(ADV_INDEX) +# zeros_like = _register_op(ZEROS_LIKE) +# +# repeat = _register_op(REPEAT) +# reshape = _register_op(RESHAPE) +# +# add = _register_op(ADD) +# sub = _register_op(SUB) +# max_axis = _register_op(MAX_AXIS) +# mul = _register_op(MUL) +# div = _register_op(DIV) +# matmul = _register_op(MATMUL) +# exp = _register_op(EXP) +# negative = _register_op(NEGATIVE) +# +# sigmoid = _register_op(SIGMOID) +# softmax = _register_op(SOFTMAX) +# +# requant = _register_op(REQUANT) +# pclip = _register_op(PCLIP) +# rs_pclip = _register_op(RS_PCLIP) +# lut = _register_op(LUT) def is_operator(symbol: Symbol, params: ParametersT = {}): return symbol.op_name != VAR diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index 02fb929..eb8972d 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -1,16 +1,11 @@ import typing import numpy as np -from dataclasses import dataclass from mrt.common.utils import N from . import opns from . 
import symbol -from .symbol import SelfSymbol -#SelfSymbol = typing.TypeVar("SelfSymbol", bound="Symbol") - -SymbolCreator = typing.Union[typing.Callable[[typing.Any, ...], typing.Type[symbol.Symbol]], SelfSymbol] -#SymbolCreator = typing.Union[typing.Callable[[...], symbol.Symbol], SelfSymbol] +SymbolCreator = typing.Union[typing.Callable[[typing.Any, ...], typing.Type[symbol.Symbol]], symbol.SelfSymbol] MRT_OP_MAP: typing.Dict[str, SymbolCreator] = {} @@ -27,119 +22,93 @@ def _wrapper(clss: SymbolCreator = None) -> SymbolCreator: # OPs from external (not in MRT op), using custom op_name with default op_func -#y = extern_opfunc("tanh")(X) +# y = extern_opfunc("tanh")(X) def extern_opfunc(op_name: str): - def op_func(name, args, attrs, extra_attrs): - #return symbol.Symbol(op_name=op_name, *args, **attrs) - return symbol.Symbol(name, op_name, args, attrs, extra_attrs) + def op_func(*args, name=None, extra_attrs=None, **kwargs): + return symbol.Symbol(*args, name=name or N.n(), op_name=op_name, extra_attrs=extra_attrs or {}, **kwargs) return op_func - def _from_dict_attrs(cls, d: dict, attr_keys:typing.List[str]=[], **kwargs): data = cls.default_dict() data.update(d) data.update(kwargs) data = cls.update_dict(data) - basedata = {k: data[k] for k in data if k in ['name', 'op_name', 'extra_attrs']} + basedata = {k: data[k] for k in data if k in ['name', 'extra_attrs']} attrsdata = {k: data['attrs'][k] for k in data['attrs'] if k in attr_keys} try: - out = cls(*data['args'], **attrsdata, **basedata) + out = cls(*data['args'], **basedata, **attrsdata) except Exception as e: raise e return out # OPs without attrs, just register function (funcName should be lower case) -def var(name=None, op_name=None, shape=(), dtype=float) -> symbol.Symbol: - op_name = op_name or opns.VAR - assert op_name == opns.VAR - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[], attrs={}, extra_attrs={'shape': shape or (), 'dtype': dtype or float}) +def var(name=None, shape=(), dtype=float) -> symbol.Symbol: + return symbol.Symbol(name=name or N.n(), op_name=opns.VAR, extra_attrs={'shape': shape or (), 'dtype': dtype or float}) #def _return_func_single_arg(op_name: op_name): -def relu(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.RELU - assert op_name == opns.RELU - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def relu(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.RELU, extra_attrs=extra_attrs or {}) -def silu(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.SILU - assert op_name == opns.SILU - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def silu(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.SILU, extra_attrs=extra_attrs or {}) -@dataclass(init=False) class Conv2D(symbol.Symbol): op_name = opns.CONV2D @property def strides(self) -> typing.Tuple[int, int]: - default_val = (1,1) - return self.attrs['strides'] if 'strides' in self.attrs else default_val + return self.attrs['strides'] @property def padding(self) -> typing.Tuple[int, int, int, int]: - default_val = (0,0,0,0) - return self.attrs['padding'] if 'padding' in self.attrs else default_val + return self.attrs['padding'] @property def groups(self) -> int: - default_val = 1 - return self.attrs['groups'] if 
'groups' in self.attrs else default_val +        return self.attrs['groups'] @property def dilation(self) -> typing.Tuple[int, int]: -        default_val = (1,1) -        return self.attrs['dilation'] if 'dilation' in self.attrs else default_val +        return self.attrs['dilation'] - @property - def kernel_size(self) -> typing.Tuple[int, int]: - assert 'kernel_size' in self.attrs - return self.attrs['kernel_size'] - - @property - def kernel_layout(self) -> str: - default_val = 'OIHW' - return self.attrs['kernel_layout'] if 'kernel_layout' in self.attrs else default_val # Follows (*args, name, **attrs) - def __init__(self, X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None): - op_name = op_name or opns.CONV2D - assert op_name == opns.CONV2D + def __init__(self, X, W, name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None): assert len(W.shape) == 4, f'Wrong Weight Shape for Conv2D: {W.shape}' kernel_size = (W.shape[2], W.shape[3]) - super().__init__(name=name or N.n(), op_name=op_name, args=[X,W], attrs={'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size, 'kernel_layout': kernel_layout}, extra_attrs=extra_attrs or {}) + #attrs = {'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation, 'kernel_size':kernel_size, 'kernel_layout': kernel_layout} + attrs = {'strides':strides, 'padding':padding, 'groups':groups, 'dilation':dilation} + super().__init__(X, W, name=name or N.n(), op_name=opns.CONV2D, extra_attrs=extra_attrs or {}, **attrs) @classmethod def from_dict(cls, d: dict, **kwargs): # Auto inferred 'kernel_size' - return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation', 'kernel_layout'], **kwargs) + return _from_dict_attrs(cls, d, ['strides', 'padding', 'groups', 'dilation'], **kwargs) -def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None): -    return Conv2D(X, W, name, op_name, strides, padding, groups, dilation, kernel_layout, extra_attrs) +def conv2d(*args, **kwargs): +    #def conv2d(X, W, name=None, op_name=None, strides=(1,1), padding=(0,0,0,0), groups=1, dilation=(1,1), kernel_layout='OIHW', extra_attrs=None): +    return Conv2D(*args, **kwargs) #X, W, name, op_name, strides, padding, groups, dilation, kernel_layout, extra_attrs) -@dataclass(init=False) class Dropout(symbol.Symbol): op_name = opns.DROP_OUT @property def p(self) -> float: -        default_val = 0.5 -        return self.attrs['p'] if 'p' in self.attrs else default_val +        return self.attrs['p'] - def __init__(self, X, name=None, op_name=None, p:float = 0.5, extra_attrs=None): - op_name = op_name or opns.DROP_OUT - assert op_name == opns.DROP_OUT - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'p': p}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, p:float = 0.5, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.DROP_OUT, extra_attrs=extra_attrs or {}, **{'p': p}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['p'], **kwargs) -def dropout(X, name=None, op_name=None, p:float = 0.5, extra_attrs=None): -    return Dropout(X, name, op_name, p, extra_attrs) +def dropout(*args, **kwargs): +    return Dropout(*args, **kwargs) -@dataclass(init=False) class Clip(symbol.Symbol): op_name = opns.CLIP @@ -153,139 +122,112 @@ def max(self) -> float: assert 'max' in self.attrs 
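The constructor pattern above (typed Symbol subclasses plus thin lower-case factories) can be exercised directly; a hedged sketch, assuming Symbol.shape resolves to the extra_attrs["shape"] recorded by var(), with placeholder names:

    from mrt.mir.opclass import var, conv2d, Dropout

    x = var("data", shape=(1, 3, 32, 32), dtype="float32")
    w = var("conv0_weight", shape=(16, 3, 3, 3), dtype="float32")  # must be 4-D, checked in Conv2D.__init__
    y = conv2d(x, w, strides=(1, 1), padding=(1, 1, 1, 1))
    y = Dropout(y, p=0.5)                                          # equivalently dropout(y, p=0.5)

Clip below follows the same pattern: attributes go in as keyword arguments and come back out through read-only properties.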
return self.attrs['max'] - def __init__(self, X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None): - op_name = op_name or opns.CLIP - assert op_name == opns.CLIP - assert min_ != np.nan - assert max_ != np.nan - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min': min_, 'max': max_}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, min:float = np.nan, max:float = np.nan, extra_attrs=None): + assert min != np.nan + assert max != np.nan + super().__init__(X, name=name or N.n(), op_name=opns.CLIP, extra_attrs=extra_attrs or {}, **{'min': min, 'max': max}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['min', 'max'], **kwargs) -def clip(X, name=None, op_name=None, min_:float = np.nan, max_:float = np.nan, extra_attrs=None): - return Clip(X, name, op_name, min_, max_, extra_attrs) +def clip(*args, **kwargs): + return Clip(*args, **kwargs) - -@dataclass(init=False) class BatchNorm(symbol.Symbol): op_name = opns.BATCH_NORM @property def axis(self) -> int: - default_val = 1 - return self.attrs['axis'] if 'axis' in self.attrs else default_val + return self.attrs['axis'] @property def epsilon(self) -> float: - default_val = 1e-5 - return self.attrs['epsilon'] if 'epsilon' in self.attrs else default_val + return self.attrs['epsilon'] @property def momentum(self) -> float: - default_val = 0.1 - return self.attrs['momentum'] if 'momentum' in self.attrs else default_val + return self.attrs['momentum'] @property def center(self) -> bool: - default_val = True - return self.attrs['center'] if 'center' in self.attrs else default_val + return self.attrs['center'] @property def scale(self) -> bool: - default_val = True - return self.attrs['scale'] if 'scale' in self.attrs else default_val + return self.attrs['scale'] - def __init__(self, X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): - op_name = op_name or opns.BATCH_NORM - assert op_name == opns.BATCH_NORM - super().__init__(name=name or N.n(), op_name=op_name, args=[X, Gamma, Beta, Mean, Var], attrs={'axis': axis, 'epsilon': epsilon, 'momentum': momentum, 'center': center, 'scale': scale}, extra_attrs=extra_attrs or {}) + def __init__(self, X, Gamma, Beta, Mean, Var, name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): + super().__init__(*[X, Gamma, Beta, Mean, Var], name=name or N.n(), op_name=opns.BATCH_NORM, extra_attrs=extra_attrs or {}, **{'axis': axis, 'epsilon': epsilon, 'momentum': momentum, 'center': center, 'scale': scale}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis', 'epsilon', 'momentum', 'center', 'scale'], **kwargs) -def batch_norm(X, Gamma, Beta, Mean, Var, name=None, op_name=None, axis:int = 1, epsilon:float = 1e-5, momentum:float = 0.1, center=True, scale=True, extra_attrs=None): - return BatchNorm(X, Gamma, Beta, Mean, Var, name, op_name, axis, epsilon, momentum, center, scale, extra_attrs) +def batch_norm(*args, **kwargs): + return BatchNorm(*args, **kwargs) -@dataclass(init=False) class TupleGetItem(symbol.Symbol): op_name = opns.TUPLE_GET_ITEM @property def index(self) -> float: - default_val = 0 - return self.attrs['index'] if 'index' in self.attrs else default_val + return self.attrs['index'] - def __init__(self, X, name=None, op_name=None, index:int = 0, extra_attrs=None): - op_name = op_name or 
opns.TUPLE_GET_ITEM - assert op_name == opns.TUPLE_GET_ITEM - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'index': index}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, index:int = 0, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.TUPLE_GET_ITEM, extra_attrs=extra_attrs or {}, **{'index': index}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['index'], **kwargs) -def tuple_get_item(X, name=None, op_name=None, index:int = 0, extra_attrs=None): - return TupleGetItem(X, name, op_name, index, extra_attrs) +def tuple_get_item(*args, **kwargs): + return TupleGetItem(*args, **kwargs) -@dataclass(init=False) class LeakyRelu(symbol.Symbol): op_name = opns.LEAKY_RELU @property def negative_slope(self) -> float: - default_val = 1e-2 - return self.attrs['negative_slope'] if 'negative_slope' in self.attrs else default_val + return self.attrs['negative_slope'] - def __init__(self, X, name=None, op_name=None, negative_slope:float = 1e-2, extra_attrs=None): - op_name = op_name or opns.LEAKY_RELU - assert op_name == opns.LEAKY_RELU - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'negative_slope': negative_slope}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, negative_slope:float = 1e-2, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.LEAKY_RELU, extra_attrs=extra_attrs or {}, **{'negative_slope': negative_slope}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['negative_slope'], **kwargs) -def leaky_relu(X, name=None, op_name=None, negative_slope:float = 1e-2, extra_attrs=None): - return LeakyRelu(X, name, op_name, negative_slope, extra_attrs) +def leaky_relu(*args, **kwargs): + return LeakyRelu(*args, **kwargs) -def dense(X, W, B, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.DENSE - assert op_name == opns.DENSE - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, W, B], attrs={}, extra_attrs=extra_attrs or {}) +def dense(X, W, B, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(*[X, W, B], name=name or N.n(), op_name=opns.DENSE, extra_attrs=extra_attrs or {}) -@dataclass(init=False) class Hardtanh(symbol.Symbol): op_name = opns.HARDTANH @property def min_val(self) -> float: - default_val = -1.0 - return self.attrs['min_val'] if 'min_val' in self.attrs else default_val + return self.attrs['min_val'] @property def max_val(self) -> float: - default_val = 1.0 - return self.attrs['max_val'] if 'max_val' in self.attrs else default_val + return self.attrs['max_val'] - def __init__(self, X, name=None, op_name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): - op_name = op_name or opns.HARDTANH - assert op_name == opns.HARDTANH - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'min_val': min_val, 'max_val':max_val}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.HARDTANH, extra_attrs=extra_attrs or {}, **{'min_val': min_val, 'max_val':max_val}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['min_val', 'max_val'], **kwargs) -def hard_tanh(X, name=None, op_name=None, min_val:float = -1.0, max_val:float = 1.0, extra_attrs=None): - return Hardtanh(X, name, op_name, min_val, max_val, extra_attrs) +def 
hard_tanh(*args, **kwargs): + return Hardtanh(*args, **kwargs) -@dataclass(init=False) class AdaptiveAvgPool2D(symbol.Symbol): op_name = opns.ADAPTIVE_AVG_POOL2D @@ -294,20 +236,17 @@ def output_size(self) -> typing.Union[int, typing.Tuple[int, int]]: assert 'output_size' in self.attrs return self.attrs['output_size'] - def __init__(self, X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=None, extra_attrs=None): - op_name = op_name or opns.ADAPTIVE_AVG_POOL2D - assert op_name == opns.ADAPTIVE_AVG_POOL2D + def __init__(self, X, name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=None, extra_attrs=None): assert output_size != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'output_size': output_size}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.ADAPTIVE_AVG_POOL2D, extra_attrs=extra_attrs or {}, **{'output_size': output_size}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['output_size'], **kwargs) -def adaptive_avg_pool2d(X, name=None, op_name=None, output_size:typing.Union[int, typing.Tuple[int, int]]=0, extra_attrs=None): - return AdaptiveAvgPool2D(X, name, op_name, output_size, extra_attrs) +def adaptive_avg_pool2d(*args, **kwargs): + return AdaptiveAvgPool2D(*args, **kwargs) -@dataclass(init=False) class AvgPool2D(symbol.Symbol): op_name = opns.AVG_POOL2D @@ -317,44 +256,34 @@ def pool_size(self) -> typing.Tuple[int, int]: return self.attrs['pool_size'] @property def strides(self) -> typing.Tuple[int, int]: - default_val = (0, 0) - return self.attrs['strides'] if 'strides' in self.attrs else default_val + return self.attrs['strides'] @property def dilation(self) -> typing.Tuple[int, int]: - default_val = (1, 1) - return self.attrs['dilation'] if 'dilation' in self.attrs else default_val + return self.attrs['dilation'] @property def padding(self) -> typing.Tuple[int, int, int, int]: - default_val = (0, 0, 0, 0) - return self.attrs['padding'] if 'padding' in self.attrs else default_val + return self.attrs['padding'] @property def ceil_mode(self) -> bool: - default_val = False - return self.attrs['ceil_mode'] if 'ceil_mode' in self.attrs else default_val + return self.attrs['ceil_mode'] @property def layout(self) -> str: - default_val = 'NCHW' - return self.attrs['layout'] if 'layout' in self.attrs else default_val + return self.attrs['layout'] @property def count_include_pad(self) -> bool: - default_val = True - return self.attrs['count_include_pad'] if 'count_include_pad' in self.attrs else default_val + return self.attrs['count_include_pad'] - def __init__(self, X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): - op_name = op_name or opns.AVG_POOL2D - assert op_name == opns.AVG_POOL2D + def __init__(self, X, name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): assert pool_size != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.AVG_POOL2D, extra_attrs=extra_attrs or {}, **{'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 
'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout, 'count_include_pad':count_include_pad}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout', 'count_include_pad'], **kwargs) -def avg_pool2d(X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', count_include_pad=True, extra_attrs=None): - return AvgPool2D(X, name, op_name, pool_size, dilation, strides, padding, ceil_mode, layout, count_include_pad, extra_attrs) - +def avg_pool2d(*args, **kwargs): + return AvgPool2D(*args, **kwargs) -@dataclass(init=False) class MaxPool2D(symbol.Symbol): op_name = opns.MAX_POOL2D @@ -364,237 +293,187 @@ def pool_size(self) -> typing.Tuple[int, int]: return self.attrs['pool_size'] @property def strides(self) -> typing.Tuple[int, int]: - default_val = (0, 0) - return self.attrs['strides'] if 'strides' in self.attrs else default_val + return self.attrs['strides'] @property def dilation(self) -> typing.Tuple[int, int]: - default_val = (1, 1) - return self.attrs['dilation'] if 'dilation' in self.attrs else default_val + return self.attrs['dilation'] @property def padding(self) -> typing.Tuple[int, int, int, int]: - default_val = (0, 0, 0, 0) - return self.attrs['padding'] if 'padding' in self.attrs else default_val + return self.attrs['padding'] @property def ceil_mode(self) -> bool: - default_val = False - return self.attrs['ceil_mode'] if 'ceil_mode' in self.attrs else default_val + return self.attrs['ceil_mode'] @property def layout(self) -> str: - default_val = 'NCHW' - return self.attrs['layout'] if 'layout' in self.attrs else default_val + return self.attrs['layout'] - def __init__(self, X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): - op_name = op_name or opns.MAX_POOL2D - assert op_name == opns.MAX_POOL2D + def __init__(self, X, name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): assert pool_size != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.MAX_POOL2D, extra_attrs=extra_attrs or {}, **{'pool_size':pool_size, 'dilation':dilation, 'strides':strides, 'padding':padding, 'ceil_mode':ceil_mode, 'layout':layout}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['pool_size', 'dilation', 'strides', 'padding', 'ceil_mode', 'layout'], **kwargs) -def max_pool2d(X, name=None, op_name=None, pool_size=None, dilation=(1,1), strides=(0,0), padding=(0,0,0,0), ceil_mode=False, layout='NCHW', extra_attrs=None): - return MaxPool2D(X, name, op_name, pool_size, dilation, strides, padding, ceil_mode, layout, extra_attrs) +def max_pool2d(*args, **kwargs): + return MaxPool2D(*args, **kwargs) -@dataclass(init=False) class Softmax(symbol.Symbol): op_name = opns.SOFTMAX @property def axis(self) -> typing.Optional[int]: - default_val = None - return self.attrs['axis'] if 'axis' in self.attrs else default_val + return self.attrs['axis'] - def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): - op_name = op_name or opns.SOFTMAX - assert op_name == opns.SOFTMAX - 
super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis':axis}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, axis=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.SOFTMAX, extra_attrs=extra_attrs or {}, **{'axis':axis}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) -def softmax(X, name=None, op_name=None, axis=None, extra_attrs=None): - return Softmax(X, name, op_name, axis, extra_attrs) +def softmax(*args, **kwargs): + return Softmax(*args, **kwargs) -@dataclass(init=False) class LogSoftmax(symbol.Symbol): op_name = opns.LOG_SOFTMAX @property def axis(self) -> typing.Optional[int]: - default_val = None - return self.attrs['axis'] if 'axis' in self.attrs else default_val + return self.attrs['axis'] - def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): - op_name = op_name or opns.LOG_SOFTMAX - assert op_name == opns.LOG_SOFTMAX - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis':axis}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, axis=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.LOG_SOFTMAX, extra_attrs=extra_attrs or {}, **{'axis':axis}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) -def log_softmax(X, name=None, op_name=None, axis=None, extra_attrs=None): - return LogSoftmax(X, name, op_name, axis, extra_attrs) +def log_softmax(*args, **kwargs): + return LogSoftmax(*args, **kwargs) -def exp(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.EXP - assert op_name == opns.EXP - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def exp(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.EXP, extra_attrs=extra_attrs or {}) -def sigmoid(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.SIGMOID - assert op_name == opns.SIGMOID - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def sigmoid(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.SIGMOID, extra_attrs=extra_attrs or {}) -@dataclass(init=False) class Sum(symbol.Symbol): op_name = opns.SUM @property def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: - default_val = None - return self.attrs['dim'] if 'dim' in self.attrs else default_val + return self.attrs['dim'] @property def keepdim(self) -> typing.Optional[bool]: - default_val = None - return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + return self.attrs['keepdim'] - def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): - op_name = op_name or opns.SUM - assert op_name == opns.SUM - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, dim=None, keepdim=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.SUM, extra_attrs=extra_attrs or {}, **{'dim': dim, 'keepdim': keepdim}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) -def sum(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): - return 
Sum(X, name, op_name, dim, keepdim, extra_attrs) +def sum(*args, **kwargs): + return Sum(*args, **kwargs) -@dataclass(init=False) class Mean(symbol.Symbol): op_name = opns.MEAN @property def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: - default_val = None - return self.attrs['dim'] if 'dim' in self.attrs else default_val + return self.attrs['dim'] @property def keepdim(self) -> typing.Optional[bool]: - default_val = None - return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + return self.attrs['keepdim'] - def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): - op_name = op_name or opns.MEAN - assert op_name == opns.MEAN - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, dim=None, keepdim=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.MEAN, extra_attrs=extra_attrs or {}, **{'dim': dim, 'keepdim': keepdim}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) -def mean(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): - return Mean(X, name, op_name, dim, keepdim, extra_attrs) +def mean(*args, **kwargs): + return Mean(*args, **kwargs) -@dataclass(init=False) class MaxAxis(symbol.Symbol): op_name = opns.MAX_AXIS @property def dim(self) -> typing.Optional[typing.Tuple[int, ...]]: - default_val = None - return self.attrs['dim'] if 'dim' in self.attrs else default_val + return self.attrs['dim'] @property def keepdim(self) -> typing.Optional[bool]: - default_val = None - return self.attrs['keepdim'] if 'keepdim' in self.attrs else default_val + return self.attrs['keepdim'] - def __init__(self, X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): - op_name = op_name or opns.MAX_AXIS - assert op_name == opns.MAX_AXIS - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim, 'keepdim': keepdim}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, dim=None, keepdim=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.MAX_AXIS, extra_attrs=extra_attrs or {}, **{'dim': dim, 'keepdim': keepdim}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim', 'keepdim'], **kwargs) -def max_axis(X, name=None, op_name=None, dim=None, keepdim=None, extra_attrs=None): - return MaxAxis(X, name, op_name, dim, keepdim, extra_attrs) +def max_axis(*args, **kwargs): + return MaxAxis(*args, **kwargs) -def maximum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.MAXIMUM - assert op_name == opns.MAXIMUM - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def maximum(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.MAXIMUM, extra_attrs=extra_attrs or {}) -def minimum(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.MINIMUM - assert op_name == opns.MINIMUM - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def minimum(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.MINIMUM, extra_attrs=extra_attrs or {}) -def repeat(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name 
= op_name or opns.REPEAT - assert op_name == opns.REPEAT - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def repeat(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.REPEAT, extra_attrs=extra_attrs or {}) -@dataclass(init=False) class Squeeze(symbol.Symbol): op_name = opns.SQUEEZE @property def dim(self) -> typing.Optional[int]: - default_val = None - return self.attrs['dim'] if 'dim' in self.attrs else default_val + return self.attrs['dim'] - def __init__(self, X, name=None, op_name=None, dim=None, extra_attrs=None): - op_name = op_name or opns.SQUEEZE - assert op_name == opns.SQUEEZE - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim': dim}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, dim=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.SQUEEZE, extra_attrs=extra_attrs or {}, **{'dim': dim}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim'], **kwargs) -def squeeze(X, name=None, op_name=None, dim=None, extra_attrs=None): - return Squeeze(X, name, op_name, dim, extra_attrs) +def squeeze(*args, **kwargs): + return Squeeze(*args, **kwargs) -@dataclass(init=False) class Flatten(symbol.Symbol): op_name = opns.FLATTEN @property def start_dim(self) -> int: - default_val = 0 - return self.attrs['start_dim'] if 'start_dim' in self.attrs else default_val + return self.attrs['start_dim'] @property def end_dim(self) -> int: - default_val = -1 - return self.attrs['end_dim'] if 'end_dim' in self.attrs else default_val + return self.attrs['end_dim'] - def __init__(self, X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_attrs=None): - op_name = op_name or opns.FLATTEN - assert op_name == opns.FLATTEN - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'start_dim': start_dim, 'end_dim':end_dim}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, start_dim=0, end_dim=-1, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.FLATTEN, extra_attrs=extra_attrs or {}, **{'start_dim': start_dim, 'end_dim':end_dim}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['start_dim', 'end_dim'], **kwargs) -def flatten(X, name=None, op_name=None, start_dim=0, end_dim=-1, extra_attrs=None): - return Flatten(X, name, op_name, start_dim, end_dim, extra_attrs) +def flatten(*args, **kwargs): + return Flatten(*args, **kwargs) -@dataclass(init=False) class Reshape(symbol.Symbol): op_name = opns.RESHAPE @@ -603,41 +482,34 @@ def newshape(self) -> typing.Tuple[int,...]: assert 'newshape' in self.attrs return self.attrs['newshape'] - def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): - op_name = op_name or opns.RESHAPE - assert op_name == opns.RESHAPE + def __init__(self, X, name=None, newshape=None, extra_attrs=None): assert newshape != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.RESHAPE, extra_attrs=extra_attrs or {}, **{'newshape': newshape}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['newshape'], **kwargs) -def reshape(X, name=None, op_name=None, newshape=None, extra_attrs=None): - return Reshape(X, name, op_name, newshape, extra_attrs) +def reshape(*args, **kwargs): + return 
Reshape(*args, **kwargs) -@dataclass(init=False) class Concat(symbol.Symbol): op_name = opns.CONCAT @property def axis(self) -> int: - default_val = 0 - return self.attrs['axis'] if 'axis' in self.attrs else default_val + return self.attrs['axis'] - def __init__(self, X, name=None, op_name=None, axis=None, extra_attrs=None): - op_name = op_name or opns.CONCAT - assert op_name == opns.CONCAT - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'axis': axis}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, axis=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.CONCAT, extra_attrs=extra_attrs or {}, **{'axis': axis}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['axis'], **kwargs) -def concat(X, name=None, op_name=None, axis=None, extra_attrs=None): - return Concat(X, name, op_name, axis, extra_attrs) +def concat(*args, **kwargs): + return Concat(*args, **kwargs) -@dataclass(init=False) class Split(symbol.Symbol): op_name = opns.SPLIT @@ -648,24 +520,20 @@ def split_size(self) -> typing.List[int]: @property def dim(self) -> int: - default_val = 0 - return self.attrs['dim'] if 'dim' in self.attrs else default_val + return self.attrs['dim'] - def __init__(self, X, name=None, op_name=None, split_size=None, dim=0, extra_attrs=None): - op_name = op_name or opns.SPLIT - assert op_name == opns.SPLIT + def __init__(self, X, name=None, split_size=None, dim=0, extra_attrs=None): assert split_size != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'split_size': split_size, 'dim': dim}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.SPLIT, extra_attrs=extra_attrs or {}, **{'split_size': split_size, 'dim': dim}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['split_size', 'dim'], **kwargs) -def split(X, name=None, op_name=None, split_size=[], dim=0, extra_attrs=None): - return Split(X, name, op_name, split_size, dim, extra_attrs) +def split(*args, **kwargs): + return Split(*args, **kwargs) -@dataclass(init=False) class Transpose(symbol.Symbol): op_name = opns.TRANSPOSE @@ -679,22 +547,19 @@ def dim1(self) -> int: assert 'dim1' in self.attrs return self.attrs['dim1'] - def __init__(self, X, name=None, op_name=None, dim0=None, dim1=None, extra_attrs=None): - op_name = op_name or opns.TRANSPOSE - assert op_name == opns.TRANSPOSE + def __init__(self, X, name=None, dim0=None, dim1=None, extra_attrs=None): assert dim0 != None assert dim1 != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dim0': dim0, 'dim1': dim1}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.TRANSPOSE, extra_attrs=extra_attrs or {}, **{'dim0': dim0, 'dim1': dim1}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dim0', 'dim1'], **kwargs) -def transpose(X, name=None, op_name=None, dim0=None, dim1=None, extra_attrs=None): - return Transpose(X, name, op_name, dim0, dim1, extra_attrs) +def transpose(*args, **kwargs): + return Transpose(*args, **kwargs) -@dataclass(init=False) class BroadcastTo(symbol.Symbol): op_name = opns.BROADCAST_TO @@ -703,20 +568,17 @@ def newshape(self) -> typing.Tuple[int,...]: assert 'newshape' in self.attrs return self.attrs['newshape'] - def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): - op_name = op_name or opns.BROADCAST_TO - assert op_name == opns.BROADCAST_TO + def 
__init__(self, X, name=None, newshape=None, extra_attrs=None): assert newshape != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.BROADCAST_TO, extra_attrs=extra_attrs or {}, **{'newshape': newshape}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['newshape'], **kwargs) -def broadcast_to(X, name=None, op_name=None, newshape=None, extra_attrs=None): - return BroadcastTo(X, name, op_name, newshape, extra_attrs) +def broadcast_to(*args, **kwargs): + return BroadcastTo(*args, **kwargs) -@dataclass(init=False) class ExpandDims(symbol.Symbol): op_name = opns.EXPAND_DIMS @@ -725,20 +587,17 @@ def newshape(self) -> typing.Tuple[int,...]: assert 'newshape' in self.attrs return self.attrs['newshape'] - def __init__(self, X, name=None, op_name=None, newshape=None, extra_attrs=None): - op_name = op_name or opns.EXPAND_DIMS - assert op_name == opns.EXPAND_DIMS + def __init__(self, X, name=None, newshape=None, extra_attrs=None): assert newshape != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'newshape': newshape}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.EXPAND_DIMS, extra_attrs=extra_attrs or {}, **{'newshape': newshape}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['newshape'], **kwargs) -def expand_dims(X, name=None, op_name=None, newshape=None, extra_attrs=None): - return ExpandDims(X, name, op_name, newshape, extra_attrs) +def expand_dims(*args, **kwargs): + return ExpandDims(*args, **kwargs) -@dataclass(init=False) class Tile(symbol.Symbol): op_name = opns.TILE @@ -747,172 +606,126 @@ def dims(self) -> typing.Tuple[int,...]: assert 'dims' in self.attrs return self.attrs['dims'] - def __init__(self, X, name=None, op_name=None, dims=None, extra_attrs=None): - op_name = op_name or opns.TILE - assert op_name == opns.TILE + def __init__(self, X, name=None, dims=None, extra_attrs=None): assert dims != None - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'dims': dims}, extra_attrs=extra_attrs or {}) + super().__init__(X, name=name or N.n(), op_name=opns.TILE, extra_attrs=extra_attrs or {}, **{'dims': dims}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dims'], **kwargs) -def tile(X, name=None, op_name=None, dims=None, extra_attrs=None): - return Tile(X, name, op_name, dims, extra_attrs) +def tile(*args, **kwargs): + return Tile(*args, **kwargs) -def where(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.WHERE - assert op_name == opns.WHERE - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def where(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.WHERE, extra_attrs=extra_attrs or {}) -def greater(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.GREATER - assert op_name == opns.GREATER - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) +def greater(X, Y, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(*[X,Y], name=name or N.n(), op_name=opns.GREATER, extra_attrs=extra_attrs or {}) -@dataclass(init=False) class NonMaxSuppression(symbol.Symbol): op_name = 
opns.NON_MAX_SUPRESSION @property def iou_threshold(self) -> float: - default_val = 0.5 - return self.attrs['iou_threshold'] if 'iou_threshold' in self.attrs else default_val + return self.attrs['iou_threshold'] @property def score_threshold(self) -> typing.Optional[float]: - default_val = None - return self.attrs['score_threshold'] if 'score_threshold' in self.attrs else default_val + return self.attrs['score_threshold'] - def __init__(self, X, name=None, op_name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None): - op_name = op_name or opns.NON_MAX_SUPRESSION - assert op_name == opns.NON_MAX_SUPRESSION - super().__init__(name=name or N.n(), op_name=op_name, args=[X], attrs={'iou_threshold': iou_threshold,'score_threshold':score_threshold}, extra_attrs=extra_attrs or {}) + def __init__(self, X, name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.NON_MAX_SUPRESSION, extra_attrs=extra_attrs or {}, **{'iou_threshold': iou_threshold,'score_threshold':score_threshold}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['dims'], **kwargs) -def non_max_suppression(X, name=None, op_name=None, iou_threshold=0.5, score_threshold=None, extra_attrs=None): - return NonMaxSuppression(X, name, op_name, iou_threshold, score_threshold, extra_attrs) +def non_max_suppression(*args, **kwargs): + return NonMaxSuppression(*args, **kwargs) -def ceil(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.CEIL - assert op_name == opns.CEIL - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def ceil(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.CEIL, extra_attrs=extra_attrs or {}) -def right_shift(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.RIGHT_SHIFT - assert op_name == opns.RIGHT_SHIFT - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) +def right_shift(X, Y, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(*[X, Y], name=name or N.n(), op_name=opns.RIGHT_SHIFT, extra_attrs=extra_attrs or {}) -@dataclass(init=False) class Add(symbol.Symbol): op_name = opns.ADD @property def alpha(self) -> int: - default_val = 1 - return self.attrs['alpha'] if 'alpha' in self.attrs else default_val + return self.attrs['alpha'] - def __init__(self, X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): - op_name = op_name or opns.ADD - assert op_name == opns.ADD - super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'alpha': alpha}, extra_attrs=extra_attrs or {}) + def __init__(self, X, Y, name=None, alpha=1, extra_attrs=None): + super().__init__(*[X, Y], name=name or N.n(), op_name=opns.ADD, extra_attrs=extra_attrs or {}, **{'alpha': alpha}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['alpha'], **kwargs) -def add(X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): - return Add(X, Y, name, op_name, alpha, extra_attrs) +def add(*args, **kwargs): + return Add(*args, **kwargs) -@dataclass(init=False) class Sub(symbol.Symbol): op_name = opns.SUB @property def alpha(self) -> int: - default_val = 1 - return self.attrs['alpha'] if 'alpha' in self.attrs else default_val + return self.attrs['alpha'] - def __init__(self, X, Y, name=None, op_name=None, alpha=1, 
extra_attrs=None): - op_name = op_name or opns.SUB - assert op_name == opns.SUB - super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'alpha': alpha}, extra_attrs=extra_attrs or {}) + def __init__(self, X, Y, name=None, alpha=1, extra_attrs=None): + super().__init__(*[X, Y], name=name or N.n(), op_name=opns.SUB, extra_attrs=extra_attrs or {}, **{'alpha': alpha}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['alpha'], **kwargs) -def sub(X, Y, name=None, op_name=None, alpha=1, extra_attrs=None): - return Sub(X, Y, name, op_name, alpha, extra_attrs) +def sub(*args, **kwargs): + return Sub(*args, **kwargs) +def mul(X, Y, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(*[X, Y], name=name or N.n(), op_name=opns.MUL, extra_attrs=extra_attrs or {}) -def mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.MUL - assert op_name == opns.MUL - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) +def mat_mul(X, Y, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(*[X, Y], name=name or N.n(), op_name=opns.MATMUL, extra_attrs=extra_attrs or {}) -def mat_mul(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.MATMUL - assert op_name == opns.MATMUL - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) - -@dataclass(init=False) class Div(symbol.Symbol): op_name = opns.DIV @property def rounding_mode(self) -> typing.Optional[str]: - default_val = None - return self.attrs['rounding_mode'] if 'rounding_mode' in self.attrs else default_val + return self.attrs['rounding_mode'] - def __init__(self, X, Y, name=None, op_name=None, rounding_mode=None, extra_attrs=None): - op_name = op_name or opns.DIV - assert op_name == opns.DIV - super().__init__(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={'rounding_mode': rounding_mode}, extra_attrs=extra_attrs or {}) + def __init__(self, X, Y, name=None, rounding_mode=None, extra_attrs=None): + super().__init__(*[X, Y], name=name or N.n(), op_name=opns.DIV, extra_attrs=extra_attrs or {}, **{'rounding_mode': rounding_mode}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['rounding_mode'], **kwargs) -def div(X, Y, name=None, op_name=None, rounding_mode=None, extra_attrs=None): - return Div(X, Y, name, op_name, rounding_mode, extra_attrs) - +def div(*args, **kwargs): + return Div(*args, **kwargs) -def negative(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.NEGATIVE - assert op_name == opns.NEGATIVE - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def negative(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.NEGATIVE, extra_attrs=extra_attrs or {}) -def abs(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.ABS - assert op_name == opns.ABS - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def abs(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.ABS, extra_attrs=extra_attrs or {}) -def log(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.LOG - assert op_name == 
opns.LOG - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def log(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.LOG, extra_attrs=extra_attrs or {}) -def sqrt(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.SQRT - assert op_name == opns.SQRT - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def sqrt(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.SQRT, extra_attrs=extra_attrs or {}) -def pow(X, Y, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.POW - assert op_name == opns.POW - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X, Y], attrs={}, extra_attrs=extra_attrs or {}) +def pow(X, Y, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(*[X, Y], name=name or N.n(), op_name=opns.POW, extra_attrs=extra_attrs or {}) -def pass_(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.PASS - assert op_name == opns.PASS - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def identity(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.IDENTITY, extra_attrs=extra_attrs or {}) -@dataclass(init=False) class Arange(symbol.Symbol): op_name = opns.ARANGE @@ -923,37 +736,29 @@ def end(self) -> int: @property def start(self) -> int: - default_val = 0 - return self.attrs['start'] if 'start' in self.attrs else default_val + return self.attrs['start'] @property def step(self) -> int: - default_val = 1 - return self.attrs['step'] if 'step' in self.attrs else default_val + return self.attrs['step'] - def __init__(self, name=None, op_name=None, end=None, start=0, step=1, extra_attrs=None): - op_name = op_name or opns.ARANGE - assert op_name == opns.ARANGE + def __init__(self, name=None, end=None, start=0, step=1, extra_attrs=None): assert end != None - super().__init__(name=name or N.n(), op_name=op_name, args=[], attrs={'end': end, 'start': start, 'step': step}, extra_attrs=extra_attrs or {}) + super().__init__(name=name or N.n(), op_name=opns.ARANGE, extra_attrs=extra_attrs or {}, **{'end': end, 'start': start, 'step': step}) @classmethod def from_dict(cls, d: dict, **kwargs): return _from_dict_attrs(cls, d, ['end', 'start', 'step'], **kwargs) -def arange(name=None, op_name=None, end=None, start=0, step=1, extra_attrs=None): - return Arange(name, op_name, end, start, step, extra_attrs) +def arange(*args, **kwargs): + return Arange(*args, **kwargs) -def zeros_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.ZEROS_LIKE - assert op_name == opns.ZEROS_LIKE - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def zeros_like(X, name=None, extra_attrs=None) -> symbol.Symbol: + return symbol.Symbol(X, name=name or N.n(), op_name=opns.ZEROS_LIKE, extra_attrs=extra_attrs or {}) -def ones_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol: - op_name = op_name or opns.ONES_LIKE - assert op_name == opns.ONES_LIKE - return symbol.Symbol(name=name or N.n(), op_name=op_name, args=[X], attrs={}, extra_attrs=extra_attrs or {}) +def ones_like(X, name=None, extra_attrs=None) -> symbol.Symbol: + 
return symbol.Symbol(X, name=name or N.n(), op_name=opns.ONES_LIKE, extra_attrs=extra_attrs or {})
 _register_op_map(opns.VAR)(var)
@@ -1011,7 +816,7 @@ def ones_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol:
 _register_op_map(opns.LOG)(log)
 _register_op_map(opns.SQRT)(sqrt)
 _register_op_map(opns.POW)(pow)
-_register_op_map(opns.PASS)(pass_)
+_register_op_map(opns.IDENTITY)(identity)
 _register_op_map(opns.ARANGE)(Arange)
 _register_op_map(opns.ZEROS_LIKE)(zeros_like)
 _register_op_map(opns.ONES_LIKE)(ones_like)
@@ -1024,14 +829,14 @@ def ones_like(X, name=None, op_name=None, extra_attrs=None) -> symbol.Symbol:
 _register_op_map(opns.CALL_TIR)(extern_opfunc(opns.CALL_TIR))
 _register_op_map(opns.CALL_DPS_PACKED)(extern_opfunc(opns.CALL_DPS_PACKED))
-_register_op_map(opns.IF)(symbol.Symbol)
-_register_op_map(opns.ARGWHERE)(symbol.Symbol)
-_register_op_map(opns.REQUANT)(symbol.Symbol)
-_register_op_map(opns.PCLIP)(symbol.Symbol)
-_register_op_map(opns.RS_PCLIP)(symbol.Symbol)
-_register_op_map(opns.LUT)(symbol.Symbol)
-
-_register_op_map(opns.BATCH_FLATTEN)(symbol.Symbol)
-_register_op_map(opns.STRIDED_SLICE)(symbol.Symbol)
-_register_op_map(opns.SLICE_LIKE)(symbol.Symbol)
-_register_op_map(opns.GET_VALID_COUNT)(symbol.Symbol)
+_register_op_map(opns.IF)(extern_opfunc(opns.IF))
+_register_op_map(opns.ARGWHERE)(extern_opfunc(opns.ARGWHERE))
+_register_op_map(opns.REQUANT)(extern_opfunc(opns.REQUANT))
+_register_op_map(opns.PCLIP)(extern_opfunc(opns.PCLIP))
+_register_op_map(opns.RS_PCLIP)(extern_opfunc(opns.RS_PCLIP))
+_register_op_map(opns.LUT)(extern_opfunc(opns.LUT))
+
+_register_op_map(opns.BATCH_FLATTEN)(extern_opfunc(opns.BATCH_FLATTEN))
+_register_op_map(opns.STRIDED_SLICE)(extern_opfunc(opns.STRIDED_SLICE))
+_register_op_map(opns.SLICE_LIKE)(extern_opfunc(opns.SLICE_LIKE))
+_register_op_map(opns.GET_VALID_COUNT)(extern_opfunc(opns.GET_VALID_COUNT))
diff --git a/python/mrt/mir/opns.py b/python/mrt/mir/opns.py
index 31da253..cec6427 100644
--- a/python/mrt/mir/opns.py
+++ b/python/mrt/mir/opns.py
@@ -80,7 +80,7 @@
 SQRT = "sqrt"
 POW = "pow"
-PASS = "pass"
+IDENTITY = "identity" # original PASS
 # ======= auto generate op =========
 ARANGE = "arange"
 ZEROS_LIKE = "zeros_like"
diff --git a/python/mrt/mir/simple_pass.py b/python/mrt/mir/simple_pass.py
deleted file mode 100644
index 302da1b..0000000
--- a/python/mrt/mir/simple_pass.py
+++ /dev/null
@@ -1,345 +0,0 @@
-from __future__ import annotations
-import typing
-
-from functools import wraps
-from dataclasses import dataclass
-
-from mrt.common import config
-#from mrt.runtime import inference
-from mrt.common.utils import *
-from mrt.common.types import *
-
-from . import op, opns, opclass
-from . import symbol as _symbol
-
-
-# mrt op visits
-@dataclass
-class SimplePass:
-    symbol: _symbol.Symbol
-
-    """op-level visit of graph
-    infer different visit function with different op_name
-    return: head symbol processed
-    """
-    def graph_visits(self) -> _symbol.Symbol:
-        env: typing.Dict[str, _symbol.Symbol] = {}
-        for sym in _symbol.sym2list(self.symbol):
-            assert sym.name not in env, f'{sym.name} NotIn env!'
- - # Updating args as passed symbol in env_dict - sym = sym.copy(args = [env[arg_sym.name] for arg_sym in sym.args]) - assert isinstance(sym, _symbol.Symbol), sym - out = getattr(self, f"visit_{opns.Opname2Funcname(sym.op_name)}")(sym) - out = out or sym - assert isinstance(out, _symbol.Symbol), out - env[sym.name] = out - return env[self.symbol.name] - - def _default_visit_op(self, op: _symbol.Symbol) -> _symbol.Symbol: - return op - - """custom visit of graph - calling custom_func for all op_name - return: head symbol processed - """ - def custom_visits(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol: - with N(name): - if once: - return custom_run(self.symbol) - return _symbol.transform(self.symbol, custom_run) - - -# mrt op visits with params, variables -@dataclass -class InferPass(SimplePass): - params: ParametersT - - def is_input(self, op_: _symbol.Symbol) -> bool: - return op.is_input(op_, self.params) - def is_variable(self, op_: _symbol.Symbol) -> bool: - return op.is_variable(op_, self.params) - def is_operator(self, op_: _symbol.Symbol) -> bool: - return op.is_operator(op_, self.params) - def is_param(self, op_: _symbol.Symbol) -> bool: - return op_.op_name == opns.VAR and op_.name in self.params - - def get_param(self, op_: _symbol.Symbol) -> OpNumpyT: - return self.params[op_.name] if self.is_param(op_) else [] - def get_as_numpy(self, op_: _symbol.Symbol) -> OpNumpyT: - assert self.is_param(op_), f"{op_.name} is not parameter." - data = self.params[op_.name] - assert isinstance(data, (tuple, list, np.ndarray)), \ - f"param:{op_.name} not OpNumpyT, get {type(data)}" - return data - - """custom visit of graph - calling custom_func for all op_name - according to how custom_run implemented, params is from argument or class_property - return: head symbol processed - """ - def custom_visits_with_params(self, custom_run: _symbol._TransformerParamT, name: str = "", once: bool = False) -> _symbol.Symbol: - with N(name): - if once: - return custom_run(self.symbol, self.params) - return _symbol.transform(self.symbol, custom_run, params=self.params) - - # From original quantization.Transformer - def as_parameter(self, data: OpNumpyT, name:str, dtype): - def _f(data, dtype): - if isinstance(data, list): - assert len(data) == len(dtype) - return [_f(d, t) for d, t in zip(data, dtype)] - assert isinstance(data, np.ndarray), type(data) - return data.astype(dtype) - array = _f(data, dtype) - shape = np.array(array).shape - self.params[name] = array - return opclass.var(array, shape=shape, dtype=dtype) - - def from_np_data(self, sym:_symbol.Symbol, data: np.ndarray, dtype, prefix=None) -> _symbol.Symbol: - name = N.n(prefix=prefix) - # some data is np.float/int type, use np.array to wrap it. 
- data = np.array(data) - self.params[name] = data.astype(dtype) - return opclass.var(name, shape=data.shape, dtype=dtype).like(sym) - - def from_const_data(self, sym:_symbol.Symbol, data: typing.Union[int, float], dtype) -> _symbol.Symbol: - return self.from_np_data(sym, data, dtype) - - -# Register MRT all op's default_visit_op function -for op_name in opclass.MRT_OP_MAP.keys(): - funcSuffix = opns.Opname2Funcname(op_name) - setattr(SimplePass, f"visit_{funcSuffix}", SimplePass._default_visit_op) - #print(f"visit_, {op_name} => {funcSuffix}", getattr(SimplePass, f"visit_{funcSuffix}")) - - -# mrt symbol simple pass -class FuseDropoutPass(SimplePass): - def visit_nn_dropout(self, sym: _symbol.Symbol) -> _symbol.Symbol: - # make sure op fit again - if sym.op_name == opns.DROP_OUT: - return sym.args[0] - return sym - - -class FuseTupleGetItemPass(SimplePass): - def visit_TupleGetItem(self, sym: opclass.TupleGetItem) -> _symbol.Symbol: - #if sym.op_name == opns.TUPLE_GET_ITEM: - # assert sym.index == 0 - # return sym.args[0] - return sym - - -class FuseNaiveSoftmaxPass(SimplePass): - def visit_nn_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.SOFTMAX: - return sym.args[0] - return sym - - def visit_nn_log_softmax(self, sym: _symbol.Symbol) -> _symbol.Symbol: - if sym.op_name == opns.LOG_SOFTMAX: - return sym.args[0] - return sym - - -class FuseMeanPass(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - if sym.op_name == opns.MEAN: - X = sym.args[0] - out = opclass.Sum(X, **sym.attrs).like(sym) - scale = self.from_np_data(sym, np.array( - 1. * product(out.shape) / product(X.shape)), dtype=out.dtype) - out = opclass.mul(out, scale) - return out - return sym - return custom_run - - -class FuseConstantPass(InferPass): - threshold: typing.ClassVar[float] = 1e-5 - - def np_is_zero(self, data) -> float: - return np.abs(data).max() < self.threshold - - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - if self.is_operator(sym) and all([self.is_param(arg) for arg in sym.args]): - data = inference.run_single_params( - sym, [self.get_as_numpy(a) for a in sym.args]) - return self.as_parameter(data, name=sym.name, dtype=sym.dtype) - elif sym.is_op(opns.ADD, opns.SUB): # , BIAS_ADD): - strips = [] - for arg in sym.args: - if self.is_param(arg) and self.np_is_zero(self.get_as_numpy(arg)): - strips.append(arg) - args = [a for a in sym.args if a not in strips] - if len(args) == 1: - return args[0] - elif sym.is_op(opns.SLICE_LIKE): - if not self.is_param(sym.args[0]): - return sym - a, b = sym.args - data = inference.run_single_params( - sym, [self.get_as_numpy(a), np.zeros(b.shape, b.dtype)]) - return self.as_parameter(data, name=sym.name, dtype=sym.dtype) - elif sym.is_op(opns.REQUANT): - if sym.rescale == 1: - return sym.args[0] - elif sym.is_op(opns.ZEROS_LIKE, opns.ONES_LIKE): - data = inference.run_single_params(sym, []) - return self.as_parameter(data, name=sym.name, dtype=sym.dtype) - return sym - return custom_run - - -class FuseBatchNormPass(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: opclass.BatchNorm, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - if sym.op_name == opns.BATCH_NORM: - X, Gamma, Beta, Mean, Var = sym.args - Gamma = self.get_param(Gamma) - Beta = self.get_param(Beta) - Mean = 
self.get_param(Mean) - Var = self.get_param(Var) - - assert sym.axis == 1 - Beta = Beta if sym.center else 0 - Gamma = Gamma if sym.scale else 1 - - # (x - mean) / sqrt(var + epsilon) * gamma + beta - Gamma = Gamma / np.sqrt(Var + sym.epsilon) - # (x - mean) * gamma + beta - # x * gamma + (beta - mean * gamma) - bias: np.ndarray = (Beta - Mean * Gamma) - K = Gamma.shape[0] - - if X.is_op(opns.CONV2D): - A, W = X.args - assert X.kernel_layout == "OIHW" - assert W.shape[0] == K - # (A * W) * gamma + bias - # A * (W * gamma) + bias - W_data = self.get_as_numpy(W) * Gamma.reshape(K, 1, 1, 1) - W_sym = self.from_np_data(W, W_data, W.dtype) - out = op.nn_conv2d(A, W_sym, **X.attrs) - elif X.is_op(opns.DENSE): - A, W = X.args - # (A * W) * gamma + bias - # A * (W * gamma) + bias - W_data = self.get_as_numpy(W) * Gamma.reshape(K, 1) - W_sym = self.from_np_data(W, W_data, W.dtype) - out = op.nn_dense(A, W_sym, **X.attrs) - else: - reshp = [s if i == sym.axis else 1 \ - for i, s in enumerate(X.shape)] - W = self.from_np_data(X, Gamma.reshape(reshp), X.dtype) - out = opclass.mul(X, W) - - bias = bias.reshape([s if i == sym.axis else 1 \ - for i, s in enumerate(out.shape)]) - B = out.like(sym) - B = self.from_np_data(B, bias, dtype=B.dtype) - return opclass.add(out, B).like(sym) - - return sym - return custom_run - - -class FuseDividePass(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - if sym.op_name == opns.DIV: - argA = sym.args[0] - argB = sym.args[1] - assert self.is_param(argB), f'NotParam: {argB}' - argB = self.from_np_data(sym, 1. / self.get_as_numpy(argB), dtype=argB.dtype) - out = opclass.mul(argA, argB) - return out.like(sym) - return sym - return custom_run - - -class FuseLeakyReLU(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - if sym.op_name == opns.LEAKY_RELU: - alpha = self.from_const_data(sym, sym.alpha, dtype=float) - X = sym.args[0] - out = opclass.relu(opclass.negative(X)) - out = opclass.mul(alpha, out) - return opclass.sub(opclass.relu(X), out) - return sym - return custom_run - -class FuseAdaptiveAvgPool2D(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - if sym.op_name == opns.ADAPTIVE_AVG_POOL2D: - X = sym.args[0] - assert sym.layout == "NCHW" - inp_shap = X.shape[2:] - out_size = sym.output_size or inp_shap - if not isinstance(out_size, (list, tuple)): - out_size = (out_size, out_size) - sym.output_size = out_size - - assert len(X.shape) == 4 - if all([s == 1 for s in sym.output_size]): - scale = np.array(1 / np.prod(X.shape[-2:])) - out = opclass.Sum(X, dim=list(range(4))[-2:], keepdims=True) - scale = self.from_np_data(sym, scale.astype(X.dtype)) - return opclass.mul(out, scale).like(self) - elif out_size[0] > inp_shap[0] or out_size[1] > inp_shap[1]: - assert all([s == 1 for s in inp_shap]) - # TODO: fix opclass repeat - out = opclass.repeat(X, repeats=out_size[0], axis=-2) - out = opclass.repeat(out, repeats=out_size[1], axis=-1) - return out.like(self) - - # calculate the attributes refers to: - # https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work - strides = [i // o for i, o in zip(inp_shap, out_size)] - kernel = [i-(o-1)*s for i, o, s in zip(inp_shap, out_size, strides)] - attrs = { - 
"kernel_size": kernel, - "strides": strides, - "padding": (0, 0), - "dilation": (1, 1), - "data_layout": sym.layout, - "groups": X.shape[1], - "channels": X.shape[1], - } - W_shape = (X.shape[1], 1, *kernel) - W = self.from_np_data(X, np.full(W_shape, 1 / product(kernel)), dtype=X.dtype) - out = opclass.Conv2D(X, W, **attrs) - return out.like(sym) - return sym - return custom_run - - -class FuseAvgPool2D(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - return sym - return custom_run - -class Spliter(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - return sym - return custom_run - -class Merger(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - return sym - return custom_run - -class Calibrator(InferPass): - def get_run(self) -> _symbol._TransformerParamT: - def custom_run(sym: _symbol.Symbol, params: typing.Optional[ParametersT] = None) -> _symbol.Symbol: - return sym - return custom_run diff --git a/python/mrt/mir/symbol.py b/python/mrt/mir/symbol.py index 973e92b..cadb48e 100644 --- a/python/mrt/mir/symbol.py +++ b/python/mrt/mir/symbol.py @@ -192,7 +192,6 @@ def _uniform(n: str, max_size: int) -> str: _format_printer(oattrs)) -@dataclass class Symbol(_BaseSymbol): """ Uniform Symbol Representation for RelayExpr @@ -209,6 +208,15 @@ class Symbol(_BaseSymbol): for the user's config about quantization layers. """ + def __init__(self, *args, name:str=None, op_name:str=None, extra_attrs:dict=None, **attrs): + assert name != None + assert op_name != None + self.name = name + self.op_name = op_name + self.args = [arg for arg in args] + self.attrs = attrs + self.extra_attrs = extra_attrs or {} + # Overridable Methods, inheritted from _BaseSymbol # to support multi-inherit design. @classmethod @@ -220,12 +228,43 @@ def set_extra_attrs(self, **kwargs): def base(cls, symbol: Symbol, **kwargs) -> Symbol: return super().base(symbol, **kwargs) def like(self, other: Symbol, **kwargs) -> Symbol: - return super().like(other, **kwargs) + """ cast current symbol to child class. """ + assert isinstance(other, Symbol) + data = other.to_dict() + data_new = self.to_dict() + data.update(data_new) + data["extra_attrs"] = other.extra_attrs if self.extra_attrs == {} else data["extra_attrs"] + # copy extra attrs by default. 
+ # data["extra_attrs"] = other.extra_attrs + #return type(other).from_dict(data, **kwargs) + return Symbol.from_dict(data, **kwargs) + def as_variable(self, **kwargs) -> Symbol: + sym = Symbol.from_dict(self.to_dict(), **kwargs) # kwargs override self + sym.op_name = opns.VAR + sym.args = [] + sym.attrs = {} + return sym def copy(self, **kwargs) -> Symbol: return super().copy(**kwargs) @classmethod def from_dict(cls, d: dict, **kwargs): - return super().from_dict(d, **kwargs) + data = cls.default_dict() + data.update(d) + data.update(kwargs) + data = cls.update_dict(data) + fnames = [f.name for f in fields(cls)] + data = {k: data[k] for k in data if k in fnames} + args = data['args'] or [] + attrs = data['attrs'] or {} + try: + out = cls(*args, name=data['name'], op_name=data['op_name'], extra_attrs=data['extra_attrs'], **attrs) + except Exception as _: + raise RuntimeError(( + "Error for type:{} create from dict, " + "expected: {}, but get {}" + ).format(get_class_name(cls), + fnames, data.keys())) + return out @classmethod def default_dict(cls, **kwargs) -> dict: kwargs.setdefault("args", []) @@ -349,6 +388,7 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: Only the return value indicates mutation, while changing attributes in parameter passed in args does nothing. """ + assert isinstance(symbol.args, list), f"Symbol_Args_Wrong_type: {type(symbol.args)}" sym_map: typing.Dict = {} C = config.LogConfig.G() for sym in sym2list(symbol): @@ -360,7 +400,11 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: if callback.__name__ in C.log_vot_cbs: config.log(callback.__name__, f"<< {sym}") - out = callback(sym) or sym + # Skipping transform output symbol in trace-Exporter + if callback.__name__.find("Exporter")>=0 and sym.name == symbol.name: + out = sym + else: + out = callback(sym) or sym assert isinstance(out, Symbol), out # default const_ prefix symbol means parameters assert sym.name not in sym_map, sym.name @@ -452,27 +496,6 @@ def transform(symbol: Symbol, callback: _TransformerT) -> Symbol: # name: str = "main") -> MultiHeadSymbol: # return MultiHeadSymbol(**{ name: symbol }) -class MultiHeadSymbol(dict): - """ { "main": F(X) } """ - origin: typing.Optional[Symbol] = None - - @classmethod - def from_symbol(cls, symbol: Symbol, name: str = "main"): - return MultiHeadSymbol({ name: symbol }) - - def as_tuple(self) -> typing.Tuple[typing.List[str], Symbol]: - from . import op - # args = list(self.values()) - # sym_type = type(args[0]) if args else Symbol - mhs = self.origin or op.Tuple(*list(self.values())) - return list(self.keys()), mhs - - @classmethod - def from_tuple(cls, tuple_names, symbol): - assert symbol.is_op(opns.TUPLE), symbol - mhs = cls(zip(tuple_names, symbol.args)) - mhs.origin = symbol - return mhs # MultiHeadSymbol = typing.Dict[str, Symbol] @@ -523,11 +546,6 @@ def from_tuple(cls, tuple_names, symbol): # return {k: load_json(v) for k, v in data} # ^^^^^^^^^^^^^^^ MultiHeadSymbol API ^^^^^^^^^^^^^^^^^^ - -Graph = typing.Union[Symbol, MultiHeadSymbol] -""" Notice that Symbol and MultiHeadSymbol can both - be regarded as a model Graph. 
-""" # def graph_visit(graph: Graph, callback: _VisitorT): # return visit(graph, callback) # # visit_func = visit if isinstance(graph, Symbol) else mhs_visit diff --git a/python/mrt/quantization/calibrate.py b/python/mrt/quantization/calibrate.py index 7961ce7..b1cc081 100644 --- a/python/mrt/quantization/calibrate.py +++ b/python/mrt/quantization/calibrate.py @@ -17,10 +17,19 @@ @dataclass(repr=False) class Calibrator(Transformer): - """ skip dump, and restore from np_data. """ - raw_data: typing.Dict[str, OpOutputT] = field(repr=False, default_factory=dict) - """ calibrate may be processed multi-times """ - data: typing.List[OpNumpyT] = field(default_factory=list) + @property + def raw_data(self) -> OpOutputT | None: + return self.extra_attrs.get("raw_data", None) + @raw_data.setter + def raw_data(self, val): + self.set_extra_attrs(raw_data=val) + + @property + def data(self) -> typing.List[OpNumpyT]: + return self.extra_attrs.get("data", []) + @data.setter + def data(self, val): + self.set_extra_attrs(data=val) def _rand_data(self, enabled: bool = False, @@ -43,8 +52,6 @@ def __call__(self, sampling_func: SamplingFuncT = None, **kwargs): kwargs.pop("origin", None) - self.raw_data = kwargs.pop("raw_data", {}) - self.data = kwargs.pop("out_data", []) if self.is_input(): out = data_dict.get(self.name, data) @@ -56,7 +63,7 @@ def __call__(self, single_op = op.retrieve_operator(self.graph) out = inference.run_single( single_op, - [self.raw_data[a.name] for a in self.args], + [self.from_symbol(a).raw_data for a in self.args], **kwargs) assert isinstance(out, (np.ndarray, list)), type(out) @@ -67,7 +74,7 @@ def __call__(self, self._assert([o.dtype.name for o in out], self.dtype) self._assert([o.shape for o in out], self.shape) - self.raw_data[self.name] = out + self.raw_data = out if sampling_func is not None: out = sampling_func(out) self.data.append(out) @@ -107,15 +114,12 @@ def sampling(cls, np_data: np.ndarray) -> typing.Any: def __call__(self, origin: Symbol, **kw): print(type(origin), origin) - origin_data = kw.pop('origin_data', []) - origin_data = origin_data[self.name] - if self.is_op(opns.CLIP): # TODO: remove clip if threshold is less than a_max a_min, a_max = self.parsed.a_min, self.parsed.a_max self.data = max(abs(a_min), abs(a_max)) else: - self.data = self.sampling(origin_data) + self.data = self.sampling(origin.extra_attrs.get("raw_data")) return self.graph @dataclass(repr=False) diff --git a/python/mrt/quantization/discrete.py b/python/mrt/quantization/discrete.py index 5ca5574..ba9a528 100644 --- a/python/mrt/quantization/discrete.py +++ b/python/mrt/quantization/discrete.py @@ -4,7 +4,10 @@ import math from dataclasses import dataclass, field -from mrt.mir import op +from mrt.mir import op, opclass +from mrt.mir.optype import infer_single +from mrt.mir.opclass import MRT_OP_MAP + from mrt.mir.opns import * from mrt.mir.symbol import * @@ -74,10 +77,10 @@ def rescale(self, info: DiscreteInfo): if info not in self.requant_ops: curr_scale = self.scale if self.scale_defined else 1 #TODO: add pass to check rescale=1 and duplicate requant - out = op.requant( + out = infer_single(MRT_OP_MAP[REQUANT]( self.graph, rescale=scale/curr_scale, - precision=precision, + precision=precision) ).like(self.graph) out.set_extra_attrs( data=self.data, scale=scale, precision=precision) @@ -253,9 +256,9 @@ def op_lut_rules(s: QuantInfo): X = s.args[0] offset = s.from_np_data(np.array(alpha, "int")) - indices = op.add(X, offset).like(X) - indices = op.clip(indices, a_min=0, a_max=2*alpha).like(X) 
#a_max=alpha+1) - indices = op.cast(indices, dtype="int32") + indices = infer_single(opclass.add(X, offset)).like(X) + indices = infer_single(opclass.clip(indices, a_min=0, a_max=2*alpha)).like(X) #a_max=alpha+1) + indices = infer_single(MRT_OP_MAP[AS_TYPE](indices, dtype="int32")) # arg_min, arg_max = -s.data, s.data # if s.is_op(EXP): @@ -267,7 +270,7 @@ def op_lut_rules(s: QuantInfo): # table = np.reshape(table, (-1, 1)) oscale = s.precision_to_scale(LUT_OUT_PREC) weight = s.from_np_data(table * oscale) - out = op.adv_index(weight, indices).like(s) + out = infer_single(MRT_OP_MAP[ADV_INDEX](weight, indices)).like(s) # out.scale = s.precision_to_scale(LUT_INP_PREC) return out @@ -290,16 +293,16 @@ def op_softmax_rules(s: QuantInfo): alpha = int(lambd * Xs) var = s.from_np_data(np.array(alpha, "int")) - max_axis = op.max_axis(X, axis = axis, keepdims=True) - offset = op.sub(max_axis, var) - offset = op.pclip(offset, precision=Xp) + max_axis = infer_single(opclass.max_axis(X, dim=axis, keepdim=True)) + offset = infer_single(opclass.sub(max_axis, var)) + offset = infer_single(MRT_OP_MAP[PCLIP](offset, precision=Xp)) offset.set_extra_attrs(precision=Xp) - norm = op.sub(X, offset) - norm = op.nn_relu(norm) - norm = op.pclip(norm, precision=Xp) + norm = infer_single(opclass.sub(X, offset)) + norm = infer_single(opclass.relu(norm)) + norm = infer_single(MRT_OP_MAP[PCLIP](norm, precision=Xp)) norm.set_extra_attrs(precision=Xp) # TODO: norm = op.cast(norm, dtype="int32") - norm = op.cast(norm, dtype="int32") + norm = infer_single(MRT_OP_MAP[AS_TYPE](norm, dtype="int32")) op_inp = np.arange(0, alpha+1) / Xs table = np.exp(op_inp) @@ -308,21 +311,21 @@ def op_softmax_rules(s: QuantInfo): weight = np.round(table) # weight = np.transpose(weight, (1, 0)) weight = s.from_np_data(weight) - out_lut = op.adv_index(weight, norm).like(s) - sum_lut = op.sum(out_lut, axis=axis, keepdims=True).like(out_lut) + out_lut = infer_single(MRT_OP_MAP[ADV_INDEX](weight, norm)).like(s) + sum_lut = infer_single(opclass.sum(out_lut, dim=axis, keepdim=True)).like(out_lut) oprec = min(SOFTMAX_PREC, 31 - tprec) oscale = bits_to_number(oprec) nd_oscale = s.from_np_data(np.array(oscale, "int")) - prob = op.mul(out_lut, nd_oscale) + prob = infer_single(opclass.mul(out_lut, nd_oscale)) - half_lut = op.rs_pclip(sum_lut, s.from_const_data(1), precision=31) + half_lut = infer_single(MRT_OP_MAP[RS_PCLIP](sum_lut, s.from_const_data(1), precision=31)) half_lut.set_extra_attrs(precision=31) - prob = op.add(prob, half_lut) - out = op.div(prob, sum_lut) - out = op.cast(out, dtype="int32") - out = op.cast(out, dtype="float32") - out = op.pclip(out, precision=oprec) + prob = infer_single(opclass.add(prob, half_lut)) + out = infer_single(opclass.div(prob, sum_lut)) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype="int32")) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype="float32")) + out = infer_single(MRT_OP_MAP[PCLIP](out, precision=oprec)) out.set_extra_attrs(scale=oscale, precision=oprec) return out diff --git a/python/mrt/quantization/fixed_point.py b/python/mrt/quantization/fixed_point.py index 20291f4..61450a0 100644 --- a/python/mrt/quantization/fixed_point.py +++ b/python/mrt/quantization/fixed_point.py @@ -4,7 +4,9 @@ import numpy as np from dataclasses import dataclass -from mrt.mir import op +from mrt.mir import op, opclass +from mrt.mir.optype import infer_single +from mrt.mir.opclass import MRT_OP_MAP from mrt.mir.opns import * from mrt.mir.symbol import filter_operators from mrt.mir.attrs import PClipAttrs, 
RequantAttrs @@ -86,27 +88,26 @@ def map_int_requant(self): exp = exp + (X.precision - anno_bit) rs_bit = X.from_const_data(X.precision - anno_bit) - X_op = op.right_shift(X.graph, rs_bit).like(self.graph) - X = self.from_symbol(X_op) + X_op = infer_single(opclass.right_shift(X.graph, rs_bit)).like(self.graph) + X = self.from_symbol(X_op) X.precision = anno_bit assert frac >= 1 assert exp <= 0 frac_sym = X.from_const_data(frac) - out = op.mul(X.graph, frac_sym).like(self.graph) + out = infer_single(opclass.mul(X.graph, frac_sym)).like(self.graph) exp_sym = self.from_symbol(out).from_const_data(-exp) if ExporterConfig.G().use_clip: if ExporterConfig.G().use_pclip: - out = op.rs_pclip(out, exp_sym, - precision=self.precision) + out = infer_single(MRT_OP_MAP[RS_PCLIP](out, exp_sym, precision=self.precision)) else: pos = self.int_max() - out = op.right_shift(out, exp_sym).like(self.graph) - out = op.clip(out, min=-pos, max=pos).like(self.graph) + out = infer_single(opclass.right_shift(out, exp_sym)).like(self.graph) + out = infer_single(opclass.clip(out, min=-pos, max=pos)).like(self.graph) else: - out = op.right_shift(out, exp_sym).like(self.graph) + out = infer_single(opclass.right_shift(out, exp_sym)).like(self.graph) return out def process(self): @@ -130,26 +131,26 @@ def process(self): else: # use float multipy to map requant rescale = self.parsed.rescale rescale = self.from_const_data(rescale) - out = op.mul(self.args[0], rescale) + out = infer_single(opclass.mul(self.args[0], rescale)) if G.use_clip: - out = op.clip(out, min=-pos, max=pos) + out = infer_single(opclass.clip(out, min=-pos, max=pos)) if not G.use_int_dtype and G.use_round: orig_dtype = out.dtype - out = op.cast(out, dtype="int32") - out = op.cast(out, dtype=orig_dtype) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype="int32")) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype=orig_dtype)) if not G.use_clip: if self.is_op(PCLIP): out = self.args[0] elif self.is_op(RS_PCLIP): - out = op.right_shift(*self.args) + out = infer_single(opclass.right_shift(*self.args)) elif not G.use_pclip: if self.is_op(PCLIP): out = self.args[0] elif self.is_op(RS_PCLIP): - out = op.right_shift(*self.args) - out = op.clip(out, min=-pos, max=pos) + out = infer_single(opclass.right_shift(*self.args)) + out = infer_single(opclass.clip(out, min=-pos, max=pos)) return out @@ -173,8 +174,8 @@ def round(self, out: Symbol): # out = op.add(out, data_0_5) # out = op.ceil(out) orig_dtype = out.dtype - out = op.cast(out, dtype="int32") - out = op.cast(out, dtype=orig_dtype) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype="int32")) + out = infer_single(MRT_OP_MAP[AS_TYPE](out, dtype=orig_dtype)) return out def __call__(self, with_clip=False, with_round=False, **kw): @@ -191,13 +192,13 @@ def __call__(self, with_clip=False, with_round=False, **kw): if self.is_op(REQUANT): rescale = self.parsed.rescale rescale = self.from_const_data(rescale) - out = op.mul(out, rescale) + out = infer_single(opclass.mul(out, rescale)) if with_round: out = self.round(out) if with_clip: pos = self.int_max() # relax api from a_min/a_max to min/max - out = op.clip(out, min=-pos, max=pos) + out = infer_single(opclass.clip(out, min=-pos, max=pos)) # print(out) # sys.exit() return out.like(self.graph) @@ -215,18 +216,17 @@ def map_requant(self) -> FixPoint: anno_bit = WithPrecision.MAX_BIT // 2 if X.precision > anno_bit: rs_bit = X.from_const_data(X.precision - anno_bit) - X = op.right_shift(X, rs_bit).like(self) + X = infer_single(opclass.right_shift(X, rs_bit).like(self)) 
X.precision = anno_bit frac, exp = cvm_float(self.parsed.rescale, anno_bit) assert frac >= 1 assert exp <= 0 frac_sym = X.from_const_data(frac) - out = op.mul(X, frac_sym).like(self) + out = infer_single(opclass.mul(X, frac_sym)).like(self) exp_sym = out.from_const_data(-exp) - out = op.rs_pclip(out, exp_sym, - precision=self.precision) + out = infer_single(MRT_OP_MAP[RS_PCLIP](out, exp_sym, precision=self.precision)) # pos = self.int_max() # out = op.right_shift(out, exp_sym).like(self) # out = op.clip(out, a_min=-pos, a_max=pos).like(self) @@ -237,7 +237,7 @@ def map_pclip(self) -> FixPoint: X: FixPoint = self.args[0] pos = self.int_max() out = X - out = op.pclip(X, precision=self.precision).like(self) + out = infer_single(MRT_OP_MAP[PCLIP](X, precision=self.precision)).like(self) # out = op.clip(X, a_min=-pos, a_max=pos).like(self) return out diff --git a/python/mrt/quantization/fuse.py b/python/mrt/quantization/fuse.py index 364f7a9..d6f110d 100644 --- a/python/mrt/quantization/fuse.py +++ b/python/mrt/quantization/fuse.py @@ -3,7 +3,7 @@ import numpy as np -from mrt.mir import op +from mrt.mir import opclass, optype from mrt.mir.opns import * from mrt.mir.symbol import * from mrt.mir.attrs import * @@ -21,6 +21,10 @@ class FuseDropout(Transformer): @filter_operators(DROP_OUT) def __call__(self, **kwargs): return self.args[0] +class FuseIdentity(Transformer): + @filter_operators(IDENTITY) + def __call__(self, **kwargs): + return self.args[0] class FuseConstant(Transformer): threshold: typing.ClassVar[float] = 1e-5 @@ -93,7 +97,7 @@ def __call__(self, **kw): # A * (W * gamma) + bias W_data = self.from_symbol(W).numpy() * gamma.reshape(K, 1, 1, 1) W_sym = self.from_symbol(W).from_np_data(W_data) - out = op.nn_conv2d(A, W_sym, **X.attrs) + out = optype.infer_single(opclass.conv2d(A, W_sym, **X.attrs)) elif X.is_op(DENSE): A, W = X.args dense_parsed: DenseAttrs = X.parsed @@ -102,18 +106,18 @@ def __call__(self, **kw): # A * (W * gamma) + bias W_data = self.from_symbol(W).numpy() * gamma.reshape(K, 1) W_sym = self.from_symbol(W).from_np_data(W_data) - out = op.nn_dense(A, W_sym, **X.attrs) + out = optype.infer_single(opclass.dense(A, W_sym, **X.attrs)) else: reshp = [s if i == parsed.axis else 1 \ for i, s in enumerate(X.shape)] W = X.from_np_data(gamma.reshape(reshp)) - out = op.mul(X.graph, W) + out = optype.infer_single(opclass.mul(X.graph, W)) bias = bias.reshape([s if i == parsed.axis else 1 \ for i, s in enumerate(out.shape)]) B = self.from_symbol(out.like(self.graph)).from_np_data(bias) - out = op.add(out, B) - # out = op.bias_add(out, B, axis=parsed.axis) + out = opclass.add(out, B) + out = optype.infer_single(out) return out.like(self.graph) class FuseTupleGetItem(Transformer): @@ -151,7 +155,7 @@ def _fuse_avg_pool2d(self): W_shape = (X.shape[1], 1, *parsed.pool_size) W = self.from_symbol(X).from_np_data(np.full( W_shape, 1 / product(parsed.pool_size))) - out = op.nn_conv2d(X, W, **attrs) + out = optype.infer_single(opclass.conv2d(X, W, **attrs)) return out.like(self.graph) @@ -169,13 +173,14 @@ def _fuse_adaptive_avg_pool2d(self): assert len(X.shape) == 4 if all([s == 1 for s in parsed.output_size]): scale = np.array(1 / np.prod(X.shape[-2:])) - out = op.sum(X, axis=list(range(4))[-2:], keepdims=True) + out = optype.infer_single(opclass.sum(X, dim=list(range(4))[-2:], keepdim=True)) scale = self.from_np_data(scale.astype(X.dtype)) - return op.mul(out, scale).like(self.graph) + out = optype.infer_single(opclass.mul(out, scale)) + return out.like(self.graph) elif ous[0] > ins[0] 
or ous[1] > ins[1]: assert all([s == 1 for s in ins]) - out = op.repeat(X, repeats=ous[0], axis=-2) - out = op.repeat(out, repeats=ous[1], axis=-1) + out = optype.infer_single(opclass.repeat(X, repeats=ous[0], axis=-2)) + out = optype.infer_single(opclass.repeat(out, repeats=ous[1], axis=-1)) return out.like(self.graph) # calculate the attributes refers to: @@ -193,7 +198,7 @@ def _fuse_adaptive_avg_pool2d(self): } W_shape = (X.shape[1], 1, *kernel) W = self.from_symbol(X).from_np_data(np.full(W_shape, 1 / product(kernel))) - out = op.nn_conv2d(X, W, **attrs) + out = optype.infer_single(opclass.conv2d(X, W, **attrs)) return out.like(self.graph) class FuseNaiveSoftmax(Transformer): @@ -218,10 +223,10 @@ def __call__(self, **kw): # axis = [a for a in range(max_axis) if a not in axis] # axis_len = product([X.shape[a] for a in axis]) - out = op.sum(X, **self.attrs) + out = optype.infer_single(opclass.sum(X, **self.attrs)) scale = self.from_np_data(np.array( 1. * product(out.shape) / product(X.shape))) - out = op.mul(out, scale) + out = optype.infer_single(opclass.mul(out, scale)) return out.like(self.graph) class FuseLeakyReLU(Transformer): @@ -236,9 +241,10 @@ def __call__(self, **kw): """ alpha = self.from_const_data(self.parsed.alpha) X: Symbol = self.args[0] - out = op.nn_relu(op.negative(X)) - out = op.mul(alpha, out) - out = op.sub(op.nn_relu(X), out) + out = optype.infer_single(opclass.negative(X)) + out = optype.infer_single(opclass.relu(out)) + out = optype.infer_single(opclass.mul(alpha, out)) + out = optype.infer_single(opclass.sub(optype.infer_single(opclass.relu(X)), out)) return out.like(self.graph) @@ -250,7 +256,8 @@ def __call__(self, **kw): B: Symbol = self.args[1] assert self.from_symbol(B).is_param(), B B = self.from_symbol(B).from_np_data(1. 
/ self.from_symbol(B).numpy()) - return op.mul(A, B).like(self.graph) + out = optype.infer_single(opclass.mul(A, B)) + return out.like(self.graph) # move to fuse constant # class FuseNaiveMathmatic(Transformer): diff --git a/python/mrt/quantization/precision.py b/python/mrt/quantization/precision.py index faf074d..e479223 100644 --- a/python/mrt/quantization/precision.py +++ b/python/mrt/quantization/precision.py @@ -6,7 +6,7 @@ import math import numpy as np -from mrt.mir import op +from mrt.mir import op, optype, opclass, opns from mrt.mir.opns import * from mrt.mir.symbol import Symbol, visit, transform @@ -204,8 +204,9 @@ def __call__(self, **kw): # print("infered prec:", oprec) if out.precision_defined and oprec > out.precision: out.precision, oprec = oprec, out.precision - out = out.from_symbol(op.pclip(out.graph, precision=oprec).like( - out.graph, extra_attrs=out.extra_attrs)) + out = out.from_symbol(optype.infer_single(opclass.MRT_OP_MAP[opns.PCLIP]( + out.graph, precision=oprec)).like( + out.graph, extra_attrs=out.extra_attrs)) out.precision = oprec out.validate_precision() diff --git a/python/mrt/quantization/segement.py b/python/mrt/quantization/segement.py index c49b90d..a4635ca 100644 --- a/python/mrt/quantization/segement.py +++ b/python/mrt/quantization/segement.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, field from mrt.mir.symbol import * -from mrt.mir import op, opns, helper +from mrt.mir import op, opns, helper, opclass from .scaler import WithScale from .transform import RunOnce @@ -118,7 +118,7 @@ def _update_params(sym: Symbol): kwargs['pointer']["head"] = self.head kwargs['pointer']["head_params"] = self.head_params - return op.Tuple(*outs).like(self.graph) + return opclass.MRT_OP_MAP[opns.TUPLE](*outs).like(self.graph) @dataclass(repr=False) class Merger(WithScale, RunOnce): diff --git a/python/mrt/quantization/transform.py b/python/mrt/quantization/transform.py index fb15319..8cd51b4 100644 --- a/python/mrt/quantization/transform.py +++ b/python/mrt/quantization/transform.py @@ -7,8 +7,9 @@ import numpy as np from mrt.mir.symbol import * +from mrt.mir.mhsymbol import Graph -from mrt.mir import op, opns +from mrt.mir import op, opns, opclass from mrt.mir.attrs import _BaseAttrs, parse_attrs from mrt.common.utils import N @@ -147,7 +148,7 @@ def from_np_data(self, data: np.ndarray | typing.Union[int, float], prefix="%") data = np.array(data) self.params[name] = data.astype(self.graph.dtype) ## return type(self). # Mark! 
- return op.variable(name, data.shape, self.graph.dtype).like(self.graph) + return opclass.var(name, data.shape, self.graph.dtype).like(self.graph) def is_input(self) -> bool: return op.is_input(self.graph, self.params) diff --git a/python/mrt/runtime/inference.py b/python/mrt/runtime/inference.py index ec62979..3b48afe 100644 --- a/python/mrt/runtime/inference.py +++ b/python/mrt/runtime/inference.py @@ -2,6 +2,7 @@ import numpy as np from mrt.mir.symbol import * +from mrt.mir.mhsymbol import MultiHeadSymbol from mrt.mir.opns import * from mrt import frontend as ft diff --git a/tests/mir/test.op_create.py b/tests/mir/test.op_create.py index 7707afb..004cb5a 100644 --- a/tests/mir/test.op_create.py +++ b/tests/mir/test.op_create.py @@ -52,6 +52,11 @@ def test_create_conv2d_op(): assert conv2d_a.attrs != None assert conv2d_a.strides != None + args = [X, W] + attrs = {'strides':(3,3)} + conv2d_f = opclass.conv2d(*args, **attrs) + assert isinstance(conv2d_f, opclass.Conv2D), 'conv2d_f isnot a Conv2D' + print(f'Got {conv2d_a.name} strides: {conv2d_a.strides}') print(f'Got {conv2d_a.name} padding: {conv2d_a.padding}') print(f'Show {conv2d_a.name} {conv2d_a}') @@ -76,6 +81,7 @@ def test_create_conv2d_op(): args = [X1, W] attrs = {'strides':(3,3)} + # Symbol Compatible Init conv2d_d = opclass.Conv2D(*args, name='conv2d_d', **attrs) conv2d_e = opclass.Conv2D(*args, **attrs) diff --git a/tests/test.pytorch.py b/tests/test.pytorch.py index 499f66b..8af997b 100644 --- a/tests/test.pytorch.py +++ b/tests/test.pytorch.py @@ -40,8 +40,9 @@ # model inference context, like cpu, gpu, etc. config = { - "device": "cuda:0", - "target": "" } + #"device": "cuda:0", + "device": "cpu", + "target": ""} # TODO: load the model from torchvision model_name = "resnet18" # passed @@ -95,7 +96,7 @@ calibrate_repeats=16, force_run_from_trcb="Discretor", log_after_all=True, - # log_before_tr_or_cbs=[ "calibrate_run_0", ], + log_before_tr_or_cbs=[ "PrecisionRevisor", ], ): dis_tr = tr.discrete() From 8e7a44d64533a3f81f3b5d6d773d74a0661f394a Mon Sep 17 00:00:00 2001 From: corlfj Date: Thu, 27 Nov 2025 17:20:33 +0800 Subject: [PATCH 10/12] [mir]: rename transformer --- python/mrt/api.py | 6 ++-- python/mrt/frontend/expr.py | 8 +++--- python/mrt/frontend/pytorch/converter.py | 6 ++-- python/mrt/mir/op.py | 1 - .../transform.py => mir/symbol_pass.py} | 26 ++++++++--------- python/mrt/quantization/calibrate.py | 6 ++-- python/mrt/quantization/discrete.py | 1 - python/mrt/quantization/fixed_point.py | 1 - python/mrt/quantization/fuse.py | 28 +++++++++---------- python/mrt/quantization/precision.py | 6 ++-- python/mrt/quantization/scaler.py | 2 +- python/mrt/quantization/segement.py | 6 ++-- python/mrt/runtime/inference.py | 4 +-- 13 files changed, 47 insertions(+), 54 deletions(-) rename python/mrt/{quantization/transform.py => mir/symbol_pass.py} (91%) diff --git a/python/mrt/api.py b/python/mrt/api.py index b282f6f..7e55898 100644 --- a/python/mrt/api.py +++ b/python/mrt/api.py @@ -26,7 +26,7 @@ from .quantization.discrete import Discretor from .quantization.precision import PrecisionRevisor -from .quantization.transform import TransformerT +from .mir.symbol_pass import SymTransformerT @dataclass class TraceConfig(config._BaseConfig): @@ -174,7 +174,7 @@ def _new(self, tr_name: str, _stat_type = self._stat_type) def checkpoint_run(self, - *callbacks: typing.List[TransformerT], + *callbacks: typing.List[SymTransformerT], tr_name: typing.Optional[str] = None, **kwargs) -> Trace: C = TraceConfig.G() @@ -200,7 +200,7 @@ def 
checkpoint_run(self, for cb in callbacks: # deep copy params to avoid conflict status params = {k: v for k, v in out.params.items()} - print("Apply Trace: {:25} Transformer: {}".format( + print("Apply Trace: {:25} SymbolTransformer: {}".format( tr_name, cb.__name__)) if cb.__name__ in C.log_before_tr_or_cbs: diff --git a/python/mrt/frontend/expr.py b/python/mrt/frontend/expr.py index 0341c58..ff8d089 100644 --- a/python/mrt/frontend/expr.py +++ b/python/mrt/frontend/expr.py @@ -90,7 +90,7 @@ def _cast_expr(node: RelayExpr): elif isinstance(node, relay.expr.If): args = [ node.cond, node.true_branch, node.false_branch ] args = [symbol_map[i] for i in args] - symbol_map[node] = opclass.extern_op_func(IF)(*args, **attrs) + symbol_map[node] = opclass.extern_opfunc(IF)(*args, **attrs) elif isinstance(node, relay.expr.Call): op_name = node.op.name if op_name in [CONCAT, ADV_INDEX]: @@ -109,14 +109,14 @@ def _cast_expr(node: RelayExpr): attrs.pop("dtype") elif op_name == GET_VALID_COUNT: attrs.pop("score_threshold") - symbol_map[node] = opclass.extern_op_func(op_name)(*args, **attrs) + symbol_map[node] = opclass.extern_opfunc(op_name)(*args, **attrs) elif isinstance(node, relay.TupleGetItem): args = [ symbol_map[node.tuple_value], ] attrs['index'] = node.index - symbol_map[node] = opclass.extern_op_func(TUPLE_GET_ITEM)(*args, **attrs) + symbol_map[node] = opclass.extern_opfunc(TUPLE_GET_ITEM)(*args, **attrs) elif isinstance(node, relay.Tuple): args = [ symbol_map[f] for f in node.fields ] - symbol_map[node] = opclass.extern_op_func(TUPLE)(*args, **attrs) + symbol_map[node] = opclass.extern_opfunc(TUPLE)(*args, **attrs) else: raise RuntimeError( "MRT not support expr type:{}".format(type(node))) diff --git a/python/mrt/frontend/pytorch/converter.py b/python/mrt/frontend/pytorch/converter.py index 0aa744d..8c5ddc8 100644 --- a/python/mrt/frontend/pytorch/converter.py +++ b/python/mrt/frontend/pytorch/converter.py @@ -242,10 +242,8 @@ def _retrieve_args(node): if mapper.op_name == TUPLE_GET_ITEM and args[0].op_name == BATCH_NORM: out = args[0] else: - out = Symbol(*args, - name=node.name, op_name=mapper.op_name, - extra_attrs={ "shape": shape, "dtype": dtype }, - **attrs) + out = opclass.extern_opfunc(mapper.op_name)(*args, name=node.name, + extra_attrs={"shape": shape, "dtype": dtype}, **attrs) env[node] = out else: raise ValueError(f"Unsupported op {node.op}") diff --git a/python/mrt/mir/op.py b/python/mrt/mir/op.py index 84f9498..ef2cab3 100644 --- a/python/mrt/mir/op.py +++ b/python/mrt/mir/op.py @@ -62,7 +62,6 @@ def retrieve_operator(symbol: Symbol) -> Symbol: # # class Conv2D(Symbol): # # strides: # -# # TODO: define op function # # def conv2d(X, weight, bias, strides=(1,1)...): # # return Symbol(args=[X, weight, bias], # # attrs={ "strides": strides }) diff --git a/python/mrt/quantization/transform.py b/python/mrt/mir/symbol_pass.py similarity index 91% rename from python/mrt/quantization/transform.py rename to python/mrt/mir/symbol_pass.py index 8cd51b4..e92de3a 100644 --- a/python/mrt/quantization/transform.py +++ b/python/mrt/mir/symbol_pass.py @@ -47,7 +47,7 @@ def is_near(self, *names, check_args: bool = True) -> bool: def to_dict(self): return self.graph.to_dict() @classmethod - def from_dict(cls, d: dict, **kwargs) -> WithParameters: + def from_dict(cls, d: dict, **kwargs) -> SymbolParameters: return cls(Symbol.from_dict(d, **kwargs), {}) @property def args(self): @@ -76,7 +76,7 @@ def set_extra_attrs(self, **kwargs): """ @dataclass(repr=False) -class WithParameters(SymbolBridge): # 
SymbolManipulator / Pass +class SymbolParameters(SymbolBridge): graph: Symbol params: ParametersT = field(repr=False) """ Parameters should not be changed in transformer, @@ -126,17 +126,17 @@ def _f(data, dtype): def from_const_data(self, data: typing.Union[int, float]) -> Symbol: return self.from_np_data(data) - def from_symbol(self, sym: Symbol) -> typing.Type[WithParameters]: #TODO + def from_symbol(self, sym: Symbol) -> typing.Type[SymbolParameters]: return type(self)(sym, self.params) def from_np_data(self, data: np.ndarray | typing.Union[int, float], prefix="%") -> Symbol: """ out = Return Symbol out = op.add(out, B) - self: WithParameter + self: SymbolParameter self.graph: Symbol self.from_symbol(out).from_np_data() - out = Return WithParameter + out = Return Symbol out.from_np_data() op.add(out.graph, B) @@ -159,14 +159,14 @@ def is_variable(self) -> bool: def is_operator(self) -> bool: return op.is_operator(self.graph, self.params) -TransformerT = typing.Callable[[Graph], Graph] -""" Transformer Callback Function Type, - inherited from WithParameters. +SymTransformerT = typing.Callable[[Graph], Graph] +""" Symbol-Transformer Callback Function Type, + inherited from SymbolParameters. """ @dataclass(repr=False) -class Transformer(WithParameters): - """ Symbol Transformer """ +class SymbolTransformer(SymbolParameters): + """ Symbol Transformer(Manipulator) """ RUN_ONCE: typing.ClassVar[bool] =False @@ -181,7 +181,7 @@ def _run(sym: Symbol): # use current cls to apply transform, this # may loss some information from origin # symbol, so record as `origin` in call. - out = cls.base(sym, params) # Type as Transformer + out = cls.base(sym, params) # Type as SymbolTransformer out = out(origin=sym, **kwargs) or sym # Type as Symbol assert isinstance(out, Symbol), ( "transform output type should be {}," @@ -212,7 +212,7 @@ def _run(sym: Symbol): # _tfm.__name__ = cls.__name__ # return _tfm - def __call__(self, *args, **kw) -> typing.Optional[Transformer]: + def __call__(self, *args, **kw) -> typing.Optional[SymbolTransformer]: """ Parameters: origin: original symbol passed from last transformer. 
@@ -220,7 +220,7 @@ def __call__(self, *args, **kw) -> typing.Optional[Transformer]: raise NotImplementedError() @dataclass(repr=False) -class RunOnce(Transformer): +class RunOnce(SymbolTransformer): RUN_ONCE: typing.ClassVar[bool] = True def __init__(self, *args): # symbol: Symbol, params: ParametersT):#, parsed: _BaseAttrs=None): diff --git a/python/mrt/quantization/calibrate.py b/python/mrt/quantization/calibrate.py index b1cc081..97b0cca 100644 --- a/python/mrt/quantization/calibrate.py +++ b/python/mrt/quantization/calibrate.py @@ -10,13 +10,13 @@ from mrt.mir.symbol import * from mrt.runtime import inference -from .transform import Transformer +from mrt.mir.symbol_pass import SymbolTransformer SamplingFuncT = typing.Callable[ [typing.Union[OpNumpyT, float]], typing.Any] @dataclass(repr=False) -class Calibrator(Transformer): +class Calibrator(SymbolTransformer): @property def raw_data(self) -> OpOutputT | None: return self.extra_attrs.get("raw_data", None) @@ -100,7 +100,7 @@ def _assert(self, val, expect): @dataclass(repr=False) -class Sampling(Transformer): +class Sampling(SymbolTransformer): @property def data(self) -> typing.Any: return self.extra_attrs.get("data", None) diff --git a/python/mrt/quantization/discrete.py b/python/mrt/quantization/discrete.py index ba9a528..811d56e 100644 --- a/python/mrt/quantization/discrete.py +++ b/python/mrt/quantization/discrete.py @@ -17,7 +17,6 @@ from .scaler import * from .calibrate import Sampling -from .transform import Transformer from .precision import WithPrecision __ALL__ = [ diff --git a/python/mrt/quantization/fixed_point.py b/python/mrt/quantization/fixed_point.py index 61450a0..07cf4f6 100644 --- a/python/mrt/quantization/fixed_point.py +++ b/python/mrt/quantization/fixed_point.py @@ -17,7 +17,6 @@ from mrt.common.config import _BaseConfig from mrt.common.utils import number_to_bits -from .transform import Transformer logger = logging.getLogger("exporter") diff --git a/python/mrt/quantization/fuse.py b/python/mrt/quantization/fuse.py index d6f110d..b664522 100644 --- a/python/mrt/quantization/fuse.py +++ b/python/mrt/quantization/fuse.py @@ -11,28 +11,28 @@ from mrt.runtime import inference from mrt.common.utils import N, product -from .transform import Transformer +from mrt.mir.symbol_pass import SymbolTransformer -# TODO: add op pass register map. 
-class FuseDropout(Transformer): +class FuseDropout(SymbolTransformer): #out = filter_operators(DROP_OUT)(__call__) # def out(): @filter_operators(DROP_OUT) def __call__(self, **kwargs): return self.args[0] -class FuseIdentity(Transformer): + +class FuseIdentity(SymbolTransformer): @filter_operators(IDENTITY) def __call__(self, **kwargs): return self.args[0] -class FuseConstant(Transformer): +class FuseConstant(SymbolTransformer): threshold: typing.ClassVar[float] = 1e-5 def np_is_zero(self, data) -> float: return np.abs(data).max() < self.threshold - def __call__(self: Transformer, **kw): + def __call__(self: SymbolTransformer, **kw): if self.is_operator() and all([self.from_symbol(c).is_param() for c in self.args]): data = inference.run_single( self.graph, [self.from_symbol(a).numpy() for a in self.args]) @@ -62,7 +62,7 @@ def __call__(self: Transformer, **kw): return self.as_parameter(data) -class FuseBatchNorm(Transformer): +class FuseBatchNorm(SymbolTransformer): @filter_operators(BATCH_NORM) def __call__(self, **kw): X, gamma, beta, mean, var = self.args @@ -120,7 +120,7 @@ def __call__(self, **kw): out = optype.infer_single(out) return out.like(self.graph) -class FuseTupleGetItem(Transformer): +class FuseTupleGetItem(SymbolTransformer): @filter_operators(TUPLE_GET_ITEM) def __call__(self, **kw): X: Symbol = self.args[0] @@ -130,7 +130,7 @@ def __call__(self, **kw): # assert self.parsed.index == 0 # return X -class FuseAvgPool2D(Transformer): +class FuseAvgPool2D(SymbolTransformer): def __call__(self, **kw): out = self._fuse_adaptive_avg_pool2d() out = out or self._fuse_avg_pool2d() @@ -201,7 +201,7 @@ def _fuse_adaptive_avg_pool2d(self): out = optype.infer_single(opclass.conv2d(X, W, **attrs)) return out.like(self.graph) -class FuseNaiveSoftmax(Transformer): +class FuseNaiveSoftmax(SymbolTransformer): def __call__(self, **kw): return self.graph # not fuse pass @@ -210,7 +210,7 @@ def __call__(self, **kw): assert self.is_variable() or not self.from_symbol(self.args[0]).is_op(SOFTMAX, LOG_SOFTMAX) return self.graph -class FuseMean(Transformer): +class FuseMean(SymbolTransformer): @filter_operators(MEAN) def __call__(self, **kw): X: Symbol = self.args[0] @@ -229,7 +229,7 @@ def __call__(self, **kw): out = optype.infer_single(opclass.mul(out, scale)) return out.like(self.graph) -class FuseLeakyReLU(Transformer): +class FuseLeakyReLU(SymbolTransformer): @filter_operators(LEAKY_RELU) def __call__(self, **kw): """ Customized rewrite pass Introduction. @@ -248,7 +248,7 @@ def __call__(self, **kw): return out.like(self.graph) -class FuseDivide(Transformer): +class FuseDivide(SymbolTransformer): @filter_operators(DIV) def __call__(self, **kw): """ Transform div to mul if possible. 
""" @@ -260,7 +260,7 @@ def __call__(self, **kw): return out.like(self.graph) # move to fuse constant -# class FuseNaiveMathmatic(Transformer): +# class FuseNaiveMathmatic(SymbolTransformer): # def __call__(self): # if self.is_op(BIAS_ADD): # X, B = self.args diff --git a/python/mrt/quantization/precision.py b/python/mrt/quantization/precision.py index e479223..7645ac1 100644 --- a/python/mrt/quantization/precision.py +++ b/python/mrt/quantization/precision.py @@ -14,16 +14,14 @@ number_to_bits, count_to_bits, bits_to_number from mrt.common.types import ParametersT -from .transform import SymbolBridge, Transformer +from mrt.mir.symbol_pass import SymbolBridge, SymbolTransformer __ALL__ = [ "WithPrecision", "InferPrecision", "QuantizedInfo", ] @dataclass(repr=False) -#class WithPrecision(Symbol): class WithPrecision(SymbolBridge): - #class WithPrecision(Transformer): MAX_BIT: typing.ClassVar[int] = 32 @classmethod @@ -176,7 +174,7 @@ def _infer_attr_prec(s: WithPrecision): return s.parsed.precision @dataclass(repr=False) -class PrecisionRevisor(WithPrecision, Transformer): +class PrecisionRevisor(WithPrecision, SymbolTransformer): def __call__(self, **kw): out = self if out.is_input(): diff --git a/python/mrt/quantization/scaler.py b/python/mrt/quantization/scaler.py index 0ad2f76..bcf847b 100644 --- a/python/mrt/quantization/scaler.py +++ b/python/mrt/quantization/scaler.py @@ -7,7 +7,7 @@ from mrt.mir.opns import * from mrt.mir.symbol import * -from .transform import SymbolBridge +from mrt.mir.symbol_pass import SymbolBridge @dataclass(repr=False) #class WithScale(Symbol): diff --git a/python/mrt/quantization/segement.py b/python/mrt/quantization/segement.py index a4635ca..20f7f68 100644 --- a/python/mrt/quantization/segement.py +++ b/python/mrt/quantization/segement.py @@ -3,10 +3,10 @@ from dataclasses import dataclass, field from mrt.mir.symbol import * -from mrt.mir import op, opns, helper, opclass +from mrt.mir import op, opns, helper, optype, opclass from .scaler import WithScale -from .transform import RunOnce +from mrt.mir.symbol_pass import RunOnce _SCALE_CONSTANT_OPS = [ opns.VAR, @@ -118,7 +118,7 @@ def _update_params(sym: Symbol): kwargs['pointer']["head"] = self.head kwargs['pointer']["head_params"] = self.head_params - return opclass.MRT_OP_MAP[opns.TUPLE](*outs).like(self.graph) + return optype.infer_single(opclass.MRT_OP_MAP[opns.TUPLE](*outs)).like(self.graph) @dataclass(repr=False) class Merger(WithScale, RunOnce): diff --git a/python/mrt/runtime/inference.py b/python/mrt/runtime/inference.py index 3b48afe..e2c76fd 100644 --- a/python/mrt/runtime/inference.py +++ b/python/mrt/runtime/inference.py @@ -6,11 +6,11 @@ from mrt.mir.opns import * from mrt import frontend as ft -from mrt.quantization.transform import WithParameters +from mrt.mir.symbol_pass import SymbolParameters from mrt.mir import op def run_single( - sym: WithParameters, + sym: SymbolParameters, args_data: typing.List[OpNumpyT], **kwargs) -> OpNumpyT: assert op.is_operator(sym), sym From 101d4143d060690cb33f4985db209e010e093c75 Mon Sep 17 00:00:00 2001 From: corlfj Date: Fri, 28 Nov 2025 11:58:53 +0800 Subject: [PATCH 11/12] [mir]: spliter/merger trace load --- python/mrt/api.py | 6 ++++-- python/mrt/quantization/segement.py | 19 ++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/python/mrt/api.py b/python/mrt/api.py index 7e55898..f095be0 100644 --- a/python/mrt/api.py +++ b/python/mrt/api.py @@ -227,8 +227,10 @@ def discrete(self) -> Trace: """Must pass params inside a 
dict, Cause it will be unfolded separately """ - kwargs_seg = {"pointer": {"head": {}, "head_params": {}, "seg_names": []}} - seg_tr = fuse_tr.checkpoint_run(seg.Spliter.get_transformer(), **kwargs_seg) + seg_tr = fuse_tr.checkpoint_run(seg.Spliter.get_transformer()) + kwargs_seg = {"ptr": {"head": seg_tr.symbol.extra_attrs.get("head"), + "head_params": seg_tr.symbol.extra_attrs.get("head_params"), + "seg_names": seg_tr.symbol.extra_attrs.get("seg_names")}} C = TraceConfig.G() calib_tr = seg_tr.calibrate( diff --git a/python/mrt/quantization/segement.py b/python/mrt/quantization/segement.py index 20f7f68..46332ee 100644 --- a/python/mrt/quantization/segement.py +++ b/python/mrt/quantization/segement.py @@ -114,20 +114,21 @@ def _update_params(sym: Symbol): # helper.format_print(head, self.head_params) - kwargs['pointer']["seg_names"] = self.seg_names - kwargs['pointer']["head"] = self.head - kwargs['pointer']["head_params"] = self.head_params - - return optype.infer_single(opclass.MRT_OP_MAP[opns.TUPLE](*outs)).like(self.graph) + # export to symbol_op Spliter_%N + out = optype.infer_single(opclass.MRT_OP_MAP[opns.TUPLE](*outs)).like(self.graph) + out.set_extra_attrs(seg_names=self.seg_names) + out.set_extra_attrs(head=self.head) + out.set_extra_attrs(head_params=self.head_params) + return out @dataclass(repr=False) class Merger(WithScale, RunOnce): - def __call__(self, spliter: Symbol, **kw): + def __call__(self, spliter: Symbol, **kwargs): assert self.op_name == opns.TUPLE - head = kw['pointer']["head"] - head_params = kw['pointer']["head_params"] - seg_names = kw['pointer']["seg_names"] + head = kwargs['ptr']["head"] + head_params = kwargs['ptr']["head_params"] + seg_names = kwargs['ptr']["seg_names"] tail_outs = dict(zip(seg_names, self.args)) From 74aa4f070905e649c02c686892e27daae8887de5 Mon Sep 17 00:00:00 2001 From: corlfj Date: Mon, 1 Dec 2025 20:13:16 +0800 Subject: [PATCH 12/12] [mir]: remove dataclass of SymbolBridge --- python/mrt/mir/mhsymbol.py | 1 - python/mrt/mir/opclass.py | 23 +++++++++++++++++++-- python/mrt/mir/symbol_pass.py | 28 ++++++++++++-------------- python/mrt/quantization/calibrate.py | 17 ++++++++++++---- python/mrt/quantization/discrete.py | 13 +++++++++--- python/mrt/quantization/fixed_point.py | 17 +++++++++++++--- python/mrt/quantization/precision.py | 11 +++++++--- python/mrt/quantization/scaler.py | 6 ++++-- python/mrt/quantization/segement.py | 10 +++++++-- 9 files changed, 91 insertions(+), 35 deletions(-) diff --git a/python/mrt/mir/mhsymbol.py b/python/mrt/mir/mhsymbol.py index bff35de..f2b5ada 100644 --- a/python/mrt/mir/mhsymbol.py +++ b/python/mrt/mir/mhsymbol.py @@ -16,7 +16,6 @@ def from_symbol(cls, symbol: symbol.Symbol, name: str = "main"): return MultiHeadSymbol({ name: symbol }) def as_tuple(self) -> typing.Tuple[typing.List[str], symbol.Symbol]: - from . 
import op # args = list(self.values()) # sym_type = type(args[0]) if args else Symbol mhs = self.origin or optype.infer_single(opclass.MRT_OP_MAP[opns.TUPLE](*list(self.values()))) diff --git a/python/mrt/mir/opclass.py b/python/mrt/mir/opclass.py index eb8972d..67eaa04 100644 --- a/python/mrt/mir/opclass.py +++ b/python/mrt/mir/opclass.py @@ -432,8 +432,27 @@ def maximum(X, name=None, extra_attrs=None) -> symbol.Symbol: def minimum(X, name=None, extra_attrs=None) -> symbol.Symbol: return symbol.Symbol(X, name=name or N.n(), op_name=opns.MINIMUM, extra_attrs=extra_attrs or {}) -def repeat(X, name=None, extra_attrs=None) -> symbol.Symbol: - return symbol.Symbol(X, name=name or N.n(), op_name=opns.REPEAT, extra_attrs=extra_attrs or {}) +#def repeat(X, name=None, extra_attrs=None) -> symbol.Symbol: +# return symbol.Symbol(X, name=name or N.n(), op_name=opns.REPEAT, extra_attrs=extra_attrs or {}) +class Repeat(symbol.Symbol): + op_name = opns.REPEAT + + @property + def repeats(self) -> typing.Optional[int]: + return self.attrs['repeats'] + + @property + def axis(self) -> typing.Optional[int]: + return self.attrs['axis'] + + def __init__(self, X, name=None, repeats=None, axis=None, extra_attrs=None): + super().__init__(X, name=name or N.n(), op_name=opns.REPEAT, extra_attrs=extra_attrs or {}, **{'repeats': repeats, 'axis': axis}) + + @classmethod + def from_dict(cls, d: dict, **kwargs): + return _from_dict_attrs(cls, d, ['repeats', 'axis'], **kwargs) +def repeat(*args, **kwargs): + return Repeat(*args, **kwargs) class Squeeze(symbol.Symbol): op_name = opns.SQUEEZE diff --git a/python/mrt/mir/symbol_pass.py b/python/mrt/mir/symbol_pass.py index e92de3a..49b349d 100644 --- a/python/mrt/mir/symbol_pass.py +++ b/python/mrt/mir/symbol_pass.py @@ -2,7 +2,7 @@ import typing from functools import wraps -from dataclasses import dataclass, field +from dataclasses import field import numpy as np @@ -14,7 +14,7 @@ from mrt.common.utils import N -@dataclass(repr=False) + class SymbolBridge: # SymbolManipulator / Pass graph: Symbol @@ -44,38 +44,37 @@ def is_op(self, *op_names) -> bool: return self.graph.op_name in op_names def is_near(self, *names, check_args: bool = True) -> bool: return self.graph.is_near(*names, check_args) - def to_dict(self): + def to_dict(self) -> dict: return self.graph.to_dict() @classmethod def from_dict(cls, d: dict, **kwargs) -> SymbolParameters: return cls(Symbol.from_dict(d, **kwargs), {}) @property - def args(self): + def args(self) -> list: return self.graph.args @property - def op_name(self): + def op_name(self) -> str: return self.graph.op_name @property - def name(self): + def name(self) -> str: return self.graph.name @property - def shape(self): + def shape(self) -> typing.Optional[ShapeT]: return self.graph.shape @property - def dtype(self): + def dtype(self) -> str: return self.graph.dtype @property - def attrs(self): + def attrs(self) -> dict: return self.graph.attrs @property - def extra_attrs(self): + def extra_attrs(self) -> dict: return self.graph.extra_attrs def set_extra_attrs(self, **kwargs): return self.graph.extra_attrs.update(kwargs) """Member Symbol End """ -@dataclass(repr=False) class SymbolParameters(SymbolBridge): graph: Symbol params: ParametersT = field(repr=False) @@ -164,12 +163,12 @@ def is_operator(self) -> bool: inherited from SymbolParameters. 
""" -@dataclass(repr=False) class SymbolTransformer(SymbolParameters): """ Symbol Transformer(Manipulator) """ - RUN_ONCE: typing.ClassVar[bool] =False + RUN_ONCE: typing.ClassVar[bool] = False + # inherit SymbolParameters __init__ def __init__(self, *args): super().__init__(*args) @@ -219,10 +218,9 @@ def __call__(self, *args, **kw) -> typing.Optional[SymbolTransformer]: """ raise NotImplementedError() -@dataclass(repr=False) class RunOnce(SymbolTransformer): RUN_ONCE: typing.ClassVar[bool] = True def __init__(self, *args): # symbol: Symbol, params: ParametersT):#, parsed: _BaseAttrs=None): - super().__init__(*args) + super().__init__(*args) diff --git a/python/mrt/quantization/calibrate.py b/python/mrt/quantization/calibrate.py index 97b0cca..278e3c1 100644 --- a/python/mrt/quantization/calibrate.py +++ b/python/mrt/quantization/calibrate.py @@ -4,7 +4,7 @@ import numpy as np -from dataclasses import dataclass, field, InitVar +from dataclasses import field, InitVar from mrt.mir import op, opns from mrt.mir.symbol import * @@ -15,7 +15,6 @@ SamplingFuncT = typing.Callable[ [typing.Union[OpNumpyT, float]], typing.Any] -@dataclass(repr=False) class Calibrator(SymbolTransformer): @property def raw_data(self) -> OpOutputT | None: @@ -31,6 +30,10 @@ def data(self) -> typing.List[OpNumpyT]: def data(self, val): self.set_extra_attrs(data=val) + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def _rand_data(self, enabled: bool = False, absmax: float | None = None, @@ -99,7 +102,6 @@ def _assert(self, val, expect): assert val == expect, "{} vs. {}".format(val, expect) -@dataclass(repr=False) class Sampling(SymbolTransformer): @property def data(self) -> typing.Any: @@ -108,6 +110,10 @@ def data(self) -> typing.Any: def data(self, val): self.set_extra_attrs(data=val) + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + @classmethod def sampling(cls, np_data: np.ndarray) -> typing.Any: raise NotImplementedError() @@ -122,10 +128,13 @@ def __call__(self, origin: Symbol, **kw): self.data = self.sampling(origin.extra_attrs.get("raw_data")) return self.graph -@dataclass(repr=False) class SymmetricMinMaxSampling(Sampling): threshold: typing.ClassVar[float] = 1e-5 + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + @classmethod def sampling(cls, data: typing.List[OpNumpyT]) -> float: if isinstance(data, list): diff --git a/python/mrt/quantization/discrete.py b/python/mrt/quantization/discrete.py index 811d56e..a2402b9 100644 --- a/python/mrt/quantization/discrete.py +++ b/python/mrt/quantization/discrete.py @@ -35,9 +35,13 @@ def undefined(self) -> bool: return self.scale is None and self.precision is None -@dataclass(repr=False) class QuantInfo(WithScale, WithPrecision, Sampling): - requant_ops: typing.Dict[DiscreteInfo, Symbol] = field(repr=False, default_factory=dict) + requant_ops: typing.Dict[DiscreteInfo, Symbol] = {} #field(default_factory=dict) + + # inherit SymbolParameters __init__ + def __init__(self, *args): + self.requant_ops = {} + super().__init__(*args) def from_symbol(self, sym: Symbol) -> typing.Self: return type(self)(sym, self.params) @@ -336,7 +340,6 @@ def op_softmax_rules(s: QuantInfo): scale_rule=softmax_scale_rules ) -@dataclass(repr=False) class Discretor(QuantInfo): """ does operation -> out @@ -363,6 +366,10 @@ class Discretor(QuantInfo): output precision <- precision(target) output scale <- scale """ + # inherit SymbolParameters __init__ + def 
__init__(self, *args): + super().__init__(*args) + def __call__(self, **kw): if self.is_variable(): return self.graph diff --git a/python/mrt/quantization/fixed_point.py b/python/mrt/quantization/fixed_point.py index 07cf4f6..6dd07df 100644 --- a/python/mrt/quantization/fixed_point.py +++ b/python/mrt/quantization/fixed_point.py @@ -52,8 +52,11 @@ class ExporterConfig(_BaseConfig): use_int_requant=True, use_int_dtype=True) -@dataclass(repr=False) class Exporter(QuantInfo): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def map_int_requant(self): """ requant(X, rescale) = X * rescale @@ -166,8 +169,12 @@ def __call__(self, **kw): # assert absmax - 0.01 <= out.int_max() return out -@dataclass(repr=False) +# TODO: deprecated? class Simulator(QuantInfo): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def round(self, out: Symbol): # data_0_5 = self.from_const_data(0.5) # out = op.add(out, data_0_5) @@ -203,8 +210,12 @@ def __call__(self, with_clip=False, with_round=False, **kw): return out.like(self.graph) -@dataclass(repr=False) +# TODO: deprecated? class FixPoint(QuantInfo): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def map_requant(self) -> FixPoint: if (self.args[0]).is_input(): return self diff --git a/python/mrt/quantization/precision.py b/python/mrt/quantization/precision.py index 7645ac1..0cbfb47 100644 --- a/python/mrt/quantization/precision.py +++ b/python/mrt/quantization/precision.py @@ -1,7 +1,6 @@ from __future__ import annotations import typing -from dataclasses import dataclass import math import numpy as np @@ -20,10 +19,13 @@ "InferPrecision", "QuantizedInfo", ] -@dataclass(repr=False) class WithPrecision(SymbolBridge): MAX_BIT: typing.ClassVar[int] = 32 + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + @classmethod def _validate_precision(cls, prec, msg=None): assert isinstance(prec, int), self.precision @@ -173,8 +175,11 @@ def _infer_attr_prec(s: WithPrecision): assert s.parsed.precision == s.precision return s.parsed.precision -@dataclass(repr=False) class PrecisionRevisor(WithPrecision, SymbolTransformer): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def __call__(self, **kw): out = self if out.is_input(): diff --git a/python/mrt/quantization/scaler.py b/python/mrt/quantization/scaler.py index bcf847b..0dac575 100644 --- a/python/mrt/quantization/scaler.py +++ b/python/mrt/quantization/scaler.py @@ -9,9 +9,11 @@ from mrt.mir.symbol_pass import SymbolBridge -@dataclass(repr=False) -#class WithScale(Symbol): class WithScale(SymbolBridge): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + @classmethod def _validate_scale(cls, scale, msg=None): if isinstance(scale, (list, tuple)): diff --git a/python/mrt/quantization/segement.py b/python/mrt/quantization/segement.py index 46332ee..f1b3c31 100644 --- a/python/mrt/quantization/segement.py +++ b/python/mrt/quantization/segement.py @@ -23,12 +23,15 @@ opns.CLIP, opns.AS_TYPE, ] -@dataclass(repr=False) class Spliter(RunOnce): head: typing.Optional[dict] = None head_params: typing.Optional[typing.Dict[str, OpNumpyT]] = None seg_names: typing.List[str] = field(default_factory=list) + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def __call__(self, **kwargs): """ Auto split the model. 
""" refs = { self.name: 1 } # add refs for root symbol @@ -121,8 +124,11 @@ def _update_params(sym: Symbol): out.set_extra_attrs(head_params=self.head_params) return out -@dataclass(repr=False) class Merger(WithScale, RunOnce): + # inherit SymbolParameters __init__ + def __init__(self, *args): + super().__init__(*args) + def __call__(self, spliter: Symbol, **kwargs): assert self.op_name == opns.TUPLE