Skip to content

Commit 41820b0

Browse files
romancstubbiali
authored andcommitted
refactor[cartesian]: gtir from backend and type hints for dace backend (GridTools#2103)
## Description This PR brings another two cleanup commits: 1. `gtir` doesn't need to be passed around when calling `BackendCodegen` classes. That info can be derived from the backend (class), which is a member of `BackendCodegen`. 2. Use type hints in `dace_backend` (only for the functions that we aren't about to torch anyway with PR GridTools#2067) ## Requirements - [ ] All fixes and/or new features come with corresponding tests. Covered by existing tests. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/README.md) folder. N/A
1 parent daeda79 commit 41820b0

File tree

5 files changed

+109
-112
lines changed

5 files changed

+109
-112
lines changed

src/gt4py/cartesian/backend/cuda_backend.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,10 @@
2020
bindings_main_template,
2121
pybuffer_to_sid,
2222
)
23-
from gt4py.cartesian.gtc import gtir
2423
from gt4py.cartesian.gtc.common import DataType
2524
from gt4py.cartesian.gtc.cuir import cuir, cuir_codegen, extent_analysis, kernel_fusion
2625
from gt4py.cartesian.gtc.cuir.oir_to_cuir import OIRToCUIR
2726
from gt4py.cartesian.gtc.gtir_to_oir import GTIRToOIR
28-
from gt4py.cartesian.gtc.passes.gtir_pipeline import GtirPipeline
2927
from gt4py.cartesian.gtc.passes.oir_optimizations.caches import FillFlushToLocalKCaches
3028
from gt4py.cartesian.gtc.passes.oir_optimizations.pruning import NoFieldAccessPruning
3129
from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline
@@ -42,9 +40,8 @@ def __init__(self, class_name: str, module_name: str, backend: CudaBackend) -> N
4240
self.module_name = module_name
4341
self.backend = backend
4442

45-
def __call__(self, stencil_ir: gtir.Stencil) -> dict[str, dict[str, str]]:
46-
stencil_ir = GtirPipeline(stencil_ir, self.backend.builder.stencil_id).full()
47-
base_oir = GTIRToOIR().visit(stencil_ir)
43+
def __call__(self) -> dict[str, dict[str, str]]:
44+
base_oir = GTIRToOIR().visit(self.backend.builder.gtir)
4845
oir_pipeline = self.backend.builder.options.backend_opts.get(
4946
"oir_pipeline",
5047
DefaultPipeline(skip=[NoFieldAccessPruning], add_steps=[FillFlushToLocalKCaches]),

0 commit comments

Comments
 (0)